From 25523fd53ee0d9cd9657c14f85f7f298a6094728 Mon Sep 17 00:00:00 2001 From: cherishsince Date: Thu, 18 Apr 2024 17:10:08 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=20chat=20=E5=92=8C=20steamCh?= =?UTF-8?q?at=20=E8=8E=B7=E5=8F=96=E5=AF=B9=E8=AF=9D=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/service/impl/ChatServiceImpl.java | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java index a2d8e54e2..913944dc0 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java @@ -18,7 +18,9 @@ import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum; import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper; +import cn.iocoder.yudao.module.ai.service.ChatConversationService; import cn.iocoder.yudao.module.ai.service.ChatService; +import cn.iocoder.yudao.module.ai.vo.ChatConversationRes; import cn.iocoder.yudao.module.ai.vo.ChatReq; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -46,6 +48,7 @@ public class ChatServiceImpl implements ChatService { private final AiChatRoleMapper aiChatRoleMapper; private final AiChatMessageMapper aiChatMessageMapper; private final AiChatConversationMapper aiChatConversationMapper; + private final ChatConversationService chatConversationService; /** @@ -59,13 +62,10 @@ public class ChatServiceImpl implements ChatService { Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); // 获取 client 类型 AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); - // 获取 对话类型(新建还是继续) - ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); - AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId); - + // 获取对话信息 + ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId()); // 保存 chat message - saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); - + saveChatMessage(req, conversationRes.getId(), loginUserId); String content = null; try { // 创建 chat 需要的 Prompt @@ -80,7 +80,7 @@ public class ChatServiceImpl implements ChatService { content = ExceptionUtil.getMessage(e); } finally { // 保存 chat message - saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content); + saveSystemChatMessage(req, conversationRes.getId(), loginUserId, content); } return content; } @@ -176,16 +176,15 @@ public class ChatServiceImpl implements ChatService { Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); // 获取 client 类型 AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); - // 获取 对话类型(新建还是继续) - ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); - AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId); + // 获取对话信息 + ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId()); // 创建 chat 需要的 Prompt Prompt prompt = new Prompt(req.getPrompt()); req.setTopK(req.getTopK()); req.setTopP(req.getTopP()); req.setTemperature(req.getTemperature()); // 保存 chat message - saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); + saveChatMessage(req, conversationRes.getId(), loginUserId); Flux streamResponse = aiClient.stream(prompt, clientNameEnum.getName()); StringBuffer contentBuffer = new StringBuffer(); @@ -212,7 +211,7 @@ public class ChatServiceImpl implements ChatService { log.info("发送完成!"); sseEmitter.complete(); // 保存 chat message - saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString()); + saveSystemChatMessage(req, conversationRes.getId(), loginUserId, contentBuffer.toString()); } ); }