From f84d25d3b710f0788c0427f5f0d59739a291033d Mon Sep 17 00:00:00 2001 From: cherishsince Date: Thu, 18 Apr 2024 16:09:45 +0800 Subject: [PATCH] =?UTF-8?q?stream=20=E4=BF=9D=E5=AD=98=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../module/ai/controller/ChatController.java | 27 +---- .../yudao/module/ai/service/ChatService.java | 7 +- .../ai/service/impl/ChatServiceImpl.java | 105 +++++++++++++++--- 3 files changed, 92 insertions(+), 47 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/ChatController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/ChatController.java index 3a71c5434..d9f2d2b97 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/ChatController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/ChatController.java @@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.controller; import cn.hutool.core.exceptions.ExceptionUtil; import cn.iocoder.yudao.framework.ai.chat.ChatResponse; -import cn.iocoder.yudao.framework.ai.config.AiClient; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.module.ai.service.ChatService; import cn.iocoder.yudao.module.ai.vo.ChatReq; @@ -38,7 +37,6 @@ import java.util.function.Consumer; public class ChatController { @Autowired - private AiClient aiClient; private final ChatService chatService; @Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!") @@ -52,30 +50,7 @@ public class ChatController { @GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) { Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); - Flux streamResponse = chatService.chatStream(req); - streamResponse.subscribe( - new Consumer() { - @Override - public void accept(ChatResponse chatResponse) { - String content = chatResponse.getResults().get(0).getOutput().getContent(); - try { - sseEmitter.send(content, MediaType.APPLICATION_JSON); - } catch (IOException e) { - log.error("发送异常{}", ExceptionUtil.getMessage(e)); - // 如果不是因为关闭而抛出异常,则重新连接 - sseEmitter.completeWithError(e); - } - } - }, - error -> { - // - log.error("subscribe错误 {}", ExceptionUtil.getMessage(error)); - }, - () -> { - log.info("发送完成!"); - sseEmitter.complete(); - } - ); + chatService.chatStream(req, sseEmitter); return sseEmitter; } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/ChatService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/ChatService.java index f72c52493..fd19d822a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/ChatService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/ChatService.java @@ -1,9 +1,7 @@ package cn.iocoder.yudao.module.ai.service; -import cn.iocoder.yudao.framework.ai.chat.ChatResponse; -import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum; +import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.vo.ChatReq; -import reactor.core.publisher.Flux; /** * 聊天 chat @@ -26,7 +24,8 @@ public interface ChatService { * chat stream * * @param req + * @param sseEmitter * @return */ - Flux chatStream(ChatReq req); + void chatStream(ChatReq req, Utf8SseEmitter sseEmitter); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java index 6fe226181..a2d8e54e2 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java @@ -8,6 +8,7 @@ import cn.iocoder.yudao.framework.ai.config.AiClient; import cn.iocoder.yudao.framework.common.exception.ServerException; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.module.ai.ErrorCodeConstants; +import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO; @@ -21,10 +22,14 @@ import cn.iocoder.yudao.module.ai.service.ChatService; import cn.iocoder.yudao.module.ai.vo.ChatReq; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; +import java.io.IOException; +import java.util.function.Consumer; + /** * 聊天 service * @@ -51,25 +56,17 @@ public class ChatServiceImpl implements ChatService { */ @Transactional(rollbackFor = Exception.class) public String chat(ChatReq req) { + Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); // 获取 client 类型 AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); // 获取 对话类型(新建还是继续) ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); + AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId); - AiChatConversationDO aiChatConversationDO; - Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); - if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) { - // 创建一个新的对话 - aiChatConversationDO = createNewChatConversation(req, loginUserId); - } else { - // 继续对话 - if (req.getConversationId() == null) { - throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL); - } - aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId()); - } + // 保存 chat message + saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); - String content; + String content = null; try { // 创建 chat 需要的 Prompt Prompt prompt = new Prompt(req.getPrompt()); @@ -81,13 +78,19 @@ public class ChatServiceImpl implements ChatService { content = call.getResult().getOutput().getContent(); } catch (Exception e) { content = ExceptionUtil.getMessage(e); + } finally { + // 保存 chat message + saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content); } + return content; + } + private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) { // 增加 chat message 记录 aiChatMessageMapper.insert( new AiChatMessageDO() .setId(null) - .setChatConversationId(aiChatConversationDO.getId()) + .setChatConversationId(chatConversationId) .setUserId(loginUserId) .setMessage(req.getPrompt()) .setMessageType(MessageType.USER.getValue()) @@ -98,7 +101,39 @@ public class ChatServiceImpl implements ChatService { // chat count 先+1 aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); - return content; + } + + public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) { + // 增加 chat message 记录 + aiChatMessageMapper.insert( + new AiChatMessageDO() + .setId(null) + .setChatConversationId(chatConversationId) + .setUserId(loginUserId) + .setMessage(systemPrompts) + .setMessageType(MessageType.SYSTEM.getValue()) + .setTopK(req.getTopK()) + .setTopP(req.getTopP()) + .setTemperature(req.getTemperature()) + ); + + // chat count 先+1 + aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); + } + + private AiChatConversationDO getChatConversationNoExistToCreate(ChatReq req, ChatConversationTypeEnum chatConversationTypeEnum, Long loginUserId) { + AiChatConversationDO aiChatConversationDO; + if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) { + // 创建一个新的对话 + aiChatConversationDO = createNewChatConversation(req, loginUserId); + } else { + // 继续对话 + if (req.getConversationId() == null) { + throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL); + } + aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId()); + } + return aiChatConversationDO; } private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) { @@ -133,16 +168,52 @@ public class ChatServiceImpl implements ChatService { * chat stream * * @param req + * @param sseEmitter * @return */ @Override - public Flux chatStream(ChatReq req) { + public void chatStream(ChatReq req, Utf8SseEmitter sseEmitter) { + Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); + // 获取 client 类型 AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); + // 获取 对话类型(新建还是继续) + ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); + AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId); // 创建 chat 需要的 Prompt Prompt prompt = new Prompt(req.getPrompt()); req.setTopK(req.getTopK()); req.setTopP(req.getTopP()); req.setTemperature(req.getTemperature()); - return aiClient.stream(prompt, clientNameEnum.getName()); + // 保存 chat message + saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); + Flux streamResponse = aiClient.stream(prompt, clientNameEnum.getName()); + + StringBuffer contentBuffer = new StringBuffer(); + streamResponse.subscribe( + new Consumer() { + @Override + public void accept(ChatResponse chatResponse) { + String content = chatResponse.getResults().get(0).getOutput().getContent(); + try { + contentBuffer.append(content); + sseEmitter.send(content, MediaType.APPLICATION_JSON); + } catch (IOException e) { + log.error("发送异常{}", ExceptionUtil.getMessage(e)); + // 如果不是因为关闭而抛出异常,则重新连接 + sseEmitter.completeWithError(e); + } + } + }, + error -> { + // + log.error("subscribe错误 {}", ExceptionUtil.getMessage(error)); + }, + () -> { + log.info("发送完成!"); + sseEmitter.complete(); + // 保存 chat message + saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString()); + } + ); } }