From c05d7c9f9521ad2b1d1534c66423f8ce4fa9f535 Mon Sep 17 00:00:00 2001 From: xiaoxin <718949661@qq.com> Date: Thu, 26 Sep 2024 15:10:55 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E6=96=B0=E5=A2=9E=E3=80=91AI=EF=BC=9A?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E6=B6=88=E6=81=AF=E8=AE=B0=E5=BD=95=E5=8F=AC?= =?UTF-8?q?=E5=9B=9E=E6=AE=B5=E8=90=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dal/dataobject/chat/AiChatMessageDO.java | 18 ++++++- .../chat/AiChatMessageServiceImpl.java | 51 ++++++++++++------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java index 973c593ce..ecd10609f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java @@ -1,13 +1,18 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat; -import com.baomidou.mybatisplus.annotation.TableId; -import org.springframework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import com.baomidou.mybatisplus.annotation.KeySequence; +import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler; import lombok.*; +import org.springframework.ai.chat.messages.MessageType; + +import java.util.List; /** * AI Chat 消息 DO @@ -66,6 +71,15 @@ public class AiChatMessageDO extends BaseDO { */ private Long roleId; + + /** + * 段落编号数组 + * + * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段 + */ + @TableField(typeHandler = JacksonTypeHandler.class) + private List segmentIds; + /** * 模型标志 */ diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 4ef5af8ee..1247ce12d 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -90,13 +90,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); - // 3.2 创建 chat 需要的 Prompt - Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); + // 3.2 召回段落 + List segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); + + // 3.3 创建 chat 需要的 Prompt + Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO); ChatResponse chatResponse = chatModel.call(prompt); - // 3.3 段式返回 + // 3.4 段式返回 String newContent = chatResponse.getResult().getOutput().getContent(); - chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent)); + chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent)); return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)); } @@ -121,11 +124,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); - // 3.2 构建 Prompt,并进行调用 - Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); + + // 3.2 召回段落 + List segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId()); + + // 3.3 构建 Prompt,并进行调用 + Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO); Flux streamResponse = chatModel.stream(prompt); - // 3.3 流式返回 + // 3.4 流式返回 // TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { @@ -138,7 +145,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { }).doOnComplete(() -> { // 忽略租户,因为 Flux 异步无法透传租户 TenantUtils.executeIgnore(() -> - chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString()))); + chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)) + .setContent(contentBuffer.toString()))); }).doOnError(throwable -> { log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable); // 忽略租户,因为 Flux 异步无法透传租户 @@ -147,21 +155,26 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR))); } - private Prompt buildPrompt(AiChatConversationDO conversation, List messages, + private List recallSegment(String content, Long knowledgeId) { + List segmentList = new ArrayList<>(); + if (Objects.nonNull(knowledgeId)) { + segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content)); + } + return segmentList; + } + + private Prompt buildPrompt(AiChatConversationDO conversation, List messages,List segmentList, AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { // 1. 构建 Prompt Message 列表 List chatMessages = new ArrayList<>(); - // 1.1 知识库召回 - if (Objects.nonNull(conversation.getKnowledgeId())) { - List segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(conversation.getKnowledgeId()).setContent(sendReqVO.getContent())); - if (CollUtil.isNotEmpty(segmentList)) { - PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage()); - StringBuilder infoBuilder = StrUtil.builder(); - segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent())); - Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString())); - chatMessages.add(message); - } + // 1.1 召回内容消息构建 + if (CollUtil.isNotEmpty(segmentList)) { + PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage()); + StringBuilder infoBuilder = StrUtil.builder(); + segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent())); + Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString())); + chatMessages.add(message); } // 1.2 system context 角色设定