diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java index 029961bf3..6cb98c562 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java @@ -34,7 +34,12 @@ public enum AiChatRoleEnum { ### 支付宝 ### 微信 除此之外不要任何解释性语句。 - """); + """), + + AI_KNOWLEDGE_ROLE("知识库助手", """ + 给你提供一些数据参考:{info},请回答我的问题。 + 请你跟进数据参考与工具返回结果回复用户的请求。 + """); /** * 角色名 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationCreateMyReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationCreateMyReqVO.java index c13200b6a..84595bea2 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationCreateMyReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationCreateMyReqVO.java @@ -10,4 +10,7 @@ public class AiChatConversationCreateMyReqVO { @Schema(description = "聊天角色编号", example = "666") private Long roleId; + @Schema(description = "知识库编号", example = "1204") + private Long knowledgeId; + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationUpdateMyReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationUpdateMyReqVO.java index f9ce64bae..2b57572c4 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationUpdateMyReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationUpdateMyReqVO.java @@ -21,6 +21,9 @@ public class AiChatConversationUpdateMyReqVO { @Schema(description = "模型编号", example = "1") private Long modelId; + @Schema(description = "知识库编号", example = "1") + private Long knowledgeId; + @Schema(description = "角色设定", example = "一个快乐的程序员") private String systemMessage; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java index 0b7eb0233..7d9625f58 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import com.baomidou.mybatisplus.annotation.KeySequence; @@ -64,6 +65,13 @@ public class AiChatConversationDO extends BaseDO { */ private Long roleId; + /** + * 知识库编号 + *

+ * 关联 {@link AiKnowledgeDO#getId()} + */ + private Long knowledgeId; + /** * 模型编号 * diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java index 83dcd8dff..8f094087f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java @@ -13,6 +13,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper; +import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import jakarta.annotation.Resource; @@ -22,6 +23,7 @@ import org.springframework.validation.annotation.Validated; import java.time.LocalDateTime; import java.util.List; +import java.util.Objects; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; @@ -45,6 +47,8 @@ public class AiChatConversationServiceImpl implements AiChatConversationService private AiChatModelService chatModalService; @Resource private AiChatRoleService chatRoleService; + @Resource + private AiKnowledgeService knowledgeService; @Override public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) { @@ -56,9 +60,14 @@ public class AiChatConversationServiceImpl implements AiChatConversationService Assert.notNull(model, "必须找到默认模型"); validateChatModel(model); + // 1.3 校验知识库 + if (Objects.nonNull(createReqVO.getKnowledgeId())) { + knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId()); + } + // 2. 创建 AiChatConversationDO 聊天对话 AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false) - .setModelId(model.getId()).setModel(model.getModel()) + .setModelId(model.getId()).setModel(model.getModel()).setKnowledgeId(createReqVO.getKnowledgeId()) .setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts()); if (role != null) { conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage()); @@ -82,6 +91,11 @@ public class AiChatConversationServiceImpl implements AiChatConversationService model = chatModalService.validateChatModel(updateReqVO.getModelId()); } + // 1.3 校验知识库是否存在 + if (updateReqVO.getKnowledgeId() != null) { + knowledgeService.validateKnowledgeExists(updateReqVO.getKnowledgeId()); + } + // 2. 更新对话信息 AiChatConversationDO updateObj = BeanUtils.toBean(updateReqVO, AiChatConversationDO.class); if (Boolean.TRUE.equals(updateReqVO.getPinned())) { diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 72fa06a79..4ef5af8ee 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -12,21 +12,29 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; +import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; +import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; +import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; @@ -59,6 +67,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { private AiChatModelService chatModalService; @Resource private AiApiKeyService apiKeyService; + @Resource + private AiKnowledgeSegmentService knowledgeSegmentService; @Transactional(rollbackFor = Exception.class) public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { @@ -141,14 +151,27 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { // 1. 构建 Prompt Message 列表 List chatMessages = new ArrayList<>(); - // 1.1 system context 角色设定 + + // 1.1 知识库召回 + if (Objects.nonNull(conversation.getKnowledgeId())) { + List segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(conversation.getKnowledgeId()).setContent(sendReqVO.getContent())); + if (CollUtil.isNotEmpty(segmentList)) { + PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage()); + StringBuilder infoBuilder = StrUtil.builder(); + segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent())); + Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString())); + chatMessages.add(message); + } + } + + // 1.2 system context 角色设定 if (StrUtil.isNotBlank(conversation.getSystemMessage())) { chatMessages.add(new SystemMessage(conversation.getSystemMessage())); } - // 1.2 history message 历史消息 + // 1.3 history message 历史消息 List contextMessages = filterContextMessages(messages, conversation, sendReqVO); contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); - // 1.3 user message 新发送消息 + // 1.4 user message 新发送消息 chatMessages.add(new UserMessage(sendReqVO.getContent())); // 2. 构建 ChatOptions 对象 @@ -160,12 +183,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { /** * 从历史消息中,获得倒序的 n 组消息作为消息上下文 - * + *

* n 组:指的是 user + assistant 形成一组 * - * @param messages 消息列表 + * @param messages 消息列表 * @param conversation 对话 - * @param sendReqVO 发送请求 + * @param sendReqVO 发送请求 * @return 消息上下文 */ private List filterContextMessages(List messages, @@ -182,7 +205,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } AiChatMessageDO userMessage = CollUtil.get(messages, i - 1); if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId()) - || StrUtil.isEmpty(assistantMessage.getContent())) { + || StrUtil.isEmpty(assistantMessage.getContent())) { continue; } // 由于后续要 reverse 反转,所以先添加 assistantMessage