diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java index a2d8e54e2..913944dc0 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java @@ -18,7 +18,9 @@ import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum; import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper; +import cn.iocoder.yudao.module.ai.service.ChatConversationService; import cn.iocoder.yudao.module.ai.service.ChatService; +import cn.iocoder.yudao.module.ai.vo.ChatConversationRes; import cn.iocoder.yudao.module.ai.vo.ChatReq; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -46,6 +48,7 @@ public class ChatServiceImpl implements ChatService { private final AiChatRoleMapper aiChatRoleMapper; private final AiChatMessageMapper aiChatMessageMapper; private final AiChatConversationMapper aiChatConversationMapper; + private final ChatConversationService chatConversationService; /** @@ -59,13 +62,10 @@ public class ChatServiceImpl implements ChatService { Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); // 获取 client 类型 AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); - // 获取 对话类型(新建还是继续) - ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); - AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId); - + // 获取对话信息 + ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId()); // 保存 chat message - saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); - + saveChatMessage(req, conversationRes.getId(), loginUserId); String content = null; try { // 创建 chat 需要的 Prompt @@ -80,7 +80,7 @@ public class ChatServiceImpl implements ChatService { content = ExceptionUtil.getMessage(e); } finally { // 保存 chat message - saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content); + saveSystemChatMessage(req, conversationRes.getId(), loginUserId, content); } return content; } @@ -176,16 +176,15 @@ public class ChatServiceImpl implements ChatService { Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); // 获取 client 类型 AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); - // 获取 对话类型(新建还是继续) - ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); - AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId); + // 获取对话信息 + ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId()); // 创建 chat 需要的 Prompt Prompt prompt = new Prompt(req.getPrompt()); req.setTopK(req.getTopK()); req.setTopP(req.getTopP()); req.setTemperature(req.getTemperature()); // 保存 chat message - saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); + saveChatMessage(req, conversationRes.getId(), loginUserId); Flux streamResponse = aiClient.stream(prompt, clientNameEnum.getName()); StringBuffer contentBuffer = new StringBuffer(); @@ -212,7 +211,7 @@ public class ChatServiceImpl implements ChatService { log.info("发送完成!"); sseEmitter.complete(); // 保存 chat message - saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString()); + saveSystemChatMessage(req, conversationRes.getId(), loginUserId, contentBuffer.toString()); } ); }