【新增】AI:对话消息记录召回段落

This commit is contained in:
xiaoxin 2024-09-26 15:10:55 +08:00
parent 6b651baeed
commit c05d7c9f95
2 changed files with 48 additions and 21 deletions

View File

@ -1,13 +1,18 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import com.baomidou.mybatisplus.annotation.TableId;
import org.springframework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.*;
import org.springframework.ai.chat.messages.MessageType;
import java.util.List;
/**
* AI Chat 消息 DO
@ -66,6 +71,15 @@ public class AiChatMessageDO extends BaseDO {
*/
private Long roleId;
/**
* 段落编号数组
*
* 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
*/
@TableField(typeHandler = JacksonTypeHandler.class)
private List<Long> segmentIds;
/**
* 模型标志
*/

View File

@ -90,13 +90,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
// 3.2 召回段落
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
ChatResponse chatResponse = chatModel.call(prompt);
// 3.3 段式返回
// 3.4 段式返回
String newContent = chatResponse.getResult().getOutput().getContent();
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
}
@ -121,11 +124,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
// 3.2 召回段落
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.3 流式返回
// 3.4 流式返回
// TODO 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
@ -138,7 +145,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}).doOnComplete(() -> {
// 忽略租户因为 Flux 异步无法透传租户
TenantUtils.executeIgnore(() ->
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId))
.setContent(contentBuffer.toString())));
}).doOnError(throwable -> {
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
// 忽略租户因为 Flux 异步无法透传租户
@ -147,21 +155,26 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
List<AiKnowledgeSegmentDO> segmentList = new ArrayList<>();
if (Objects.nonNull(knowledgeId)) {
segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
}
return segmentList;
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
// 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>();
// 1.1 知识库召回
if (Objects.nonNull(conversation.getKnowledgeId())) {
List<AiKnowledgeSegmentDO> segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(conversation.getKnowledgeId()).setContent(sendReqVO.getContent()));
if (CollUtil.isNotEmpty(segmentList)) {
PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
StringBuilder infoBuilder = StrUtil.builder();
segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
chatMessages.add(message);
}
// 1.1 召回内容消息构建
if (CollUtil.isNotEmpty(segmentList)) {
PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
StringBuilder infoBuilder = StrUtil.builder();
segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
chatMessages.add(message);
}
// 1.2 system context 角色设定