From 802dee2fc353f41d16fb4c7176f7ededaf6f0fa2 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Tue, 21 May 2024 08:30:51 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E6=96=B0=E5=A2=9E=E3=80=91AI=EF=BC=9A?= =?UTF-8?q?=E5=8F=91=E9=80=81=E6=B6=88=E6=81=AF=E6=97=B6=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../vo/message/AiChatMessageSendReqVO.java | 3 + .../dal/dataobject/chat/AiChatMessageDO.java | 17 ++++- .../ai/service/impl/AiChatServiceImpl.java | 71 ++++++++++++++----- 3 files changed, 72 insertions(+), 19 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendReqVO.java index 9592da347..89a84bcbd 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendReqVO.java @@ -19,4 +19,7 @@ public class AiChatMessageSendReqVO { @NotEmpty(message = "聊天内容不能为空") private String content; + @Schema(description = "是否携带上下文", example = "true") + private Boolean useContext; + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java index c66537673..1c915ed7c 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java @@ -1,5 +1,6 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat; +import com.baomidou.mybatisplus.annotation.TableId; import org.springframework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; @@ -27,14 +28,23 @@ public class AiChatMessageDO extends BaseDO { /** * 编号,作为每条聊天记录的唯一标识符 */ + @TableId private Long id; /** * 会话编号 * - * 关联 {@link AiChatConversationDO#getId()} + * 关联 {@link AiChatConversationDO#getId()} 字段 */ private Long conversationId; + /** + * 回复消息编号 + * + * 关联 {@link #id} 字段 + * + * 大模型回复的消息编号,用于“问答”的关联 + */ + private Long replyId; /** * 消息类型 @@ -75,6 +85,9 @@ public class AiChatMessageDO extends BaseDO { */ private String content; - // TODO 芋艿:是否作为上下文语料?use_context,待定 + /** + * 是否携带上下文 + */ + private Boolean useContext; } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java index 9fd80d1c6..c45a1025e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java @@ -1,5 +1,9 @@ package cn.iocoder.yudao.module.ai.service.impl; +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.collection.ListUtil; +import cn.hutool.core.util.ArrayUtil; +import cn.hutool.core.util.BooleanUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; @@ -109,15 +113,14 @@ public class AiChatServiceImpl implements AiChatService { StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform); // 2. 插入 user 发送消息 - AiChatMessageDO userMessage = createChatMessage(conversation.getId(), model, - userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent()); + AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, + userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext()); // 3.1 插入 assistant 接收消息 - AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), model, - userId, conversation.getRoleId(), MessageType.ASSISTANT, ""); + AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, + userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); // 3.2 创建 chat 需要的 Prompt - // TODO 消息上下文 Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO); Flux streamResponse = chatClient.stream(prompt); @@ -139,32 +142,66 @@ public class AiChatServiceImpl implements AiChatService { } private Prompt buildPrompt(AiChatConversationDO conversation, List messages, AiChatMessageSendReqVO sendReqVO) { - // TODO 芋艿:1)保留 n 个上下文;2)每一轮 token 数量 -// if (conversation.getMaxContexts() != null && messages.size() > conversation.getMaxContexts()) { -// -// } // 1. 构建 Prompt Message 列表 List chatMessages = new ArrayList<>(); // 1.1 system context 角色设定 chatMessages.add(new SystemMessage(conversation.getSystemMessage())); // 1.2 history message 历史消息 - messages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent()))); + List contextMessages = filterContextMessages(messages, conversation, sendReqVO); + contextMessages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent()))); // 1.3 user message 新发送消息 chatMessages.add(new UserMessage(sendReqVO.getContent())); // 2. 构建 ChatOptions 对象 TODO 芋艿:临时注释掉;等文心一言兼容了; + // TODO 每一轮 token 数量 // ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build(); // return new Prompt(chatMessages, null); return new Prompt(chatMessages); } - private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model, - Long userId, Long roleId, - MessageType messageType, String content) { - AiChatMessageDO message = new AiChatMessageDO() - .setConversationId(conversationId).setModel(model.getModel()).setModelId(model.getId()) - .setUserId(userId).setRoleId(roleId) - .setType(messageType.getValue()).setContent(content); + /** + * 从历史消息中,获得倒序的 n 组消息作为消息上下文 + * + * n 组:指的是 user + assistant 形成一组 + * + * @param messages 消息列表 + * @param conversation 会话 + * @param sendReqVO 发送请求 + * @return 消息上下文 + */ + private List filterContextMessages(List messages, AiChatConversationDO conversation, AiChatMessageSendReqVO sendReqVO) { + if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) { + return Collections.emptyList(); + } + List contextMessages = new ArrayList<>(conversation.getMaxContexts() * 2); + for (int i = messages.size() - 1; i >= 0; i--) { + AiChatMessageDO assistantMessage = CollUtil.get(messages, i); + if (assistantMessage == null || assistantMessage.getReplyId() == null) { + continue; + } + AiChatMessageDO userMessage = CollUtil.get(messages, i - 1); + if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId()) + || StrUtil.isEmpty(assistantMessage.getContent())) { + continue; + } + // 由于后续要 reverse 反转,所以先添加 assistantMessage + contextMessages.add(assistantMessage); + contextMessages.add(userMessage); + // 超过最大上下文,结束 + if (contextMessages.size() >= conversation.getMaxContexts() * 2) { + break; + } + } + Collections.reverse(contextMessages); + return contextMessages; + } + + private AiChatMessageDO createChatMessage(Long conversationId, Long replyId, + AiChatModelDO model, Long userId, Long roleId, + MessageType messageType, String content, Boolean useContext) { + AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId) + .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId) + .setType(messageType.getValue()).setContent(content).setUseContext(useContext); message.setCreateTime(LocalDateTime.now()); chatMessageMapper.insert(message); return message;