diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/resources/http/chat-message.http b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http similarity index 82% rename from yudao-module-ai/yudao-module-ai-biz/src/main/resources/http/chat-message.http rename to yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http index b357ab66e..2d417a55f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/resources/http/chat-message.http +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http @@ -10,13 +10,13 @@ Authorization: {{token}} } -### chat call -POST {{baseUrl}}/admin-api/ai/chat/message/send-stream +### 发送消息(流式) +POST {{baseUrl}}/ai/chat/message/send-stream Content-Type: application/json Authorization: {{token}} { - "conversationId": "1781604279872581649", + "conversationId": "1781604279872581651", "content": "苹果是什么颜色?" } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java index a89381ff4..47f35f4ea 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java @@ -17,6 +17,7 @@ import reactor.core.publisher.Flux; import java.util.List; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; +import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; @Tag(name = "管理后台 - 聊天消息") @RestController @@ -36,14 +37,8 @@ public class AiChatMessageController { @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 - public Flux sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) { - return chatService.chatStream(sendReqVO); - } - - @Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染") - @PostMapping(value = "/add") - public CommonResult add(@Validated @RequestBody AiChatMessageAddReqVO req) { - return success(chatService.add(req)); + public Flux sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { + return chatService.sendChatMessageStream(sendReqVO, getLoginUserId()); } @Operation(summary = "获得指定会话的消息列表") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java index f117c67c6..c4863c735 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java @@ -44,4 +44,5 @@ public class AiChatMessageRespVO { @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") private LocalDateTime createTime; + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java new file mode 100644 index 000000000..9ea7900cb --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java @@ -0,0 +1,36 @@ +package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.time.LocalDateTime; + +@Schema(description = "管理后台 - AI 聊天消息发送 Response VO") +@Data +public class AiChatMessageSendRespVO { + + @Schema(description = "发送消息", requiredMode = Schema.RequiredMode.REQUIRED) + private Message send; + + @Schema(description = "接收消息", requiredMode = Schema.RequiredMode.REQUIRED) + private Message receive; + + @Schema(description = "消息") + @Data + public static class Message { + + @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") + private Long id; + + @Schema(description = "消息类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "role") + private String type; // 参见 MessageType 枚举类 + + @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊") + private String content; + + @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") + private LocalDateTime createTime; + + } + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java deleted file mode 100644 index cfd67ccba..000000000 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java +++ /dev/null @@ -1,16 +0,0 @@ -package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message; - -import io.swagger.v3.oas.annotations.media.Schema; -import jakarta.validation.constraints.NotEmpty; -import jakarta.validation.constraints.NotNull; -import lombok.Data; - -@Schema(description = "管理后台 - AI 聊天消息发送 Request VO") -@Data -public class AiChatMessageSendStreamReqVO { - - @Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") - @NotNull(message = "提问的 messageId 不能为空") - private Long id; - -} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java index 05f7b83b6..eda556358 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java @@ -27,12 +27,4 @@ public interface AiChatMessageConvert { */ List convertAiChatMessageRespVOList(List aiChatMessageDOList); - /** - * 转换 - aiChatMessageDO - * - * @param aiChatMessageDO - * @return - */ - AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO); - } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java index 7be2b8afc..57d848eea 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java @@ -22,22 +22,6 @@ public interface AiChatService { */ AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO); - /** - * chat stream - * - * @param sendReqVO - * @return - */ - Flux chatStream(AiChatMessageSendStreamReqVO sendReqVO); - - /** - * 添加 - message - * - * @param sendReqVO - * @return - */ - AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO); - /** * 获取 - 获取对话 message list * @@ -54,4 +38,13 @@ public interface AiChatService { */ Boolean deleteMessage(Long id); + /** + * 发送消息 + * + * @param sendReqVO + * @param userId + * @return + */ + Flux sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId); + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java index 4642c648d..1ab160bfc 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java @@ -1,27 +1,24 @@ package cn.iocoder.yudao.module.ai.service.impl; import cn.hutool.core.exceptions.ExceptionUtil; +import cn.hutool.core.util.ObjUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; -import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO; import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; -import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; @@ -33,13 +30,16 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; +import java.time.LocalDateTime; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; import java.util.stream.Collectors; +import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; +import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS; + /** * 聊天 service * @@ -52,11 +52,11 @@ import java.util.stream.Collectors; @AllArgsConstructor public class AiChatServiceImpl implements AiChatService { - private final AiChatClientFactory aiChatClientFactory; + private final AiChatClientFactory chatClientFactory; private final AiChatMessageMapper aiChatMessageMapper; private final AiChatConversationService chatConversationService; - private final AiChatModelService aiChatModalService; + private final AiChatModelService chatModalService; private final AiChatRoleService chatRoleService; @Transactional(rollbackFor = Exception.class) @@ -65,7 +65,7 @@ public class AiChatServiceImpl implements AiChatService { // 查询对话 AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); // 获取对话模型 - AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); + AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId()); // 获取角色信息 AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; // 获取 client 类型 @@ -84,7 +84,7 @@ public class AiChatServiceImpl implements AiChatService { // req.setTopP(req.getTopP()); // req.setTemperature(req.getTemperature()); // 发送 call 调用 - ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum); + ChatClient chatClient = chatClientFactory.getChatClient(platformEnum); ChatResponse call = chatClient.call(prompt); content = call.getResult().getOutput().getContent(); tokens = call.getResults().size(); @@ -113,88 +113,72 @@ public class AiChatServiceImpl implements AiChatService { .setModelId(modelId) .setContent(content) .setTokens(tokens) - .setTemperature(temperature) .setMaxTokens(maxTokens) .setMaxContexts(maxContexts); + insertChatMessageDO.setCreateTime(LocalDateTime.now()); // 增加 chat message 记录 aiChatMessageMapper.insert(insertChatMessageDO); return insertChatMessageDO; } - public Flux chatStream(AiChatMessageSendStreamReqVO req) { - Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); - // 查询提问的 message - AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId()); - if (aiChatMessageDO == null) { - throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST); - } - // 查询对话 - AiChatConversationDO conversation = chatConversationService.validateExists(aiChatMessageDO.getConversationId()); - // 获取对话模型 - AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); - // 获取角色信息 - AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; - // 创建 chat 需要的 Prompt - Prompt prompt = new Prompt(aiChatMessageDO.getContent()); - // 提前创建一个 system message - AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), - chatModel.getModel(), chatModel.getId(), "", - 0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); -// req.setTopK(req.getTopK()); -// req.setTopP(req.getTopP()); -// req.setTemperature(req.getTemperature()); - // 获取 client 类型 - AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); - StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); - Flux streamResponse = streamingChatClient.stream(prompt); - // 转换 flex AiChatMessageRespVO - StringBuffer contentBuffer = new StringBuffer(); - AtomicInteger tokens = new AtomicInteger(0); - return streamResponse.map(res -> { - AiChatMessageRespVO aiChatMessageRespVO = - AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage); - aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent()); - contentBuffer.append(res.getResult().getOutput().getContent()); - tokens.incrementAndGet(); - return aiChatMessageRespVO; - } - ).doOnComplete(new Runnable() { - @Override - public void run() { - log.info("发送完成!"); - // 保存 chat message - aiChatMessageMapper.updateById(new AiChatMessageDO() - .setId(systemMessage.getId()) - .setContent(contentBuffer.toString()) - .setTokens(tokens.get()) - ); - } - }).doOnError(new Consumer() { - @Override - public void accept(Throwable throwable) { - log.error("发送错误 {}!", throwable.getMessage()); - // 更新错误信息 - aiChatMessageMapper.updateById(new AiChatMessageDO() - .setId(systemMessage.getId()) - .setContent(throwable.getMessage()) - .setTokens(tokens.get()) - ); - } - }); - } - @Override - public AiChatMessageRespVO add(AiChatMessageAddReqVO req) { - Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); - // 查询对话 - AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); - // 获取对话模型 - AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); - AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), - chatModel.getModel(), chatModel.getId(), req.getContent(), + public Flux sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) { + // 1.1 校验对话存在 + AiChatConversationDO conversation = chatConversationService.validateExists(sendReqVO.getConversationId()); + if (ObjUtil.notEqual(conversation.getUserId(), userId)) { + throw exception(CHAT_CONVERSATION_NOT_EXISTS); + } + // 1.2 校验模型 + AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform); + + // 2. 插入 user 发送消息 TODO tokens 计算 + AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(), + conversation.getModel(), conversation.getId(), sendReqVO.getContent(), null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); - return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage); + + // 3.1 插入 system 接收消息 + AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(), + conversation.getModel(), conversation.getId(), conversation.getSystemMessage(), + 0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); + // 3.2 创建 chat 需要的 Prompt + // TODO 消息上下文 + Prompt prompt = new Prompt(sendReqVO.getContent()); +// ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build() + Flux streamResponse = chatClient.stream(prompt); + // 3.3 转换 flex AiChatMessageRespVO + StringBuffer contentBuffer = new StringBuffer(); + AtomicInteger tokens = new AtomicInteger(0); // TODO token 计算不对; + return streamResponse.map(res -> { + contentBuffer.append(res.getResult().getOutput().getContent()); + tokens.incrementAndGet(); + + AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId()) + .setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime()) + .setContent(sendReqVO.getContent()); + AiChatMessageSendRespVO.Message receive = new AiChatMessageSendRespVO.Message().setId(systemMessage.getId()) + .setType(MessageType.SYSTEM.getValue()).setCreateTime(systemMessage.getCreateTime()) + .setContent(res.getResult().getOutput().getContent()); + return new AiChatMessageSendRespVO().setSend(send).setReceive(receive); + }).doOnComplete(() -> { + log.info("发送完成!"); + // 保存 chat message + aiChatMessageMapper.updateById(new AiChatMessageDO() + .setId(systemMessage.getId()) + .setContent(contentBuffer.toString()) + .setTokens(tokens.get()) + ); + }).doOnError(throwable -> { + log.error("发送错误 {}!", throwable.getMessage()); + // 更新错误信息 TODO 貌似不应该更新异常 + aiChatMessageMapper.updateById(new AiChatMessageDO() + .setId(systemMessage.getId()) + .setContent(throwable.getMessage()) + .setTokens(tokens.get()) + ); + }); } @Override @@ -205,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService { List aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId); // 获取模型信息 Set modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet()); - List modalList = aiChatModalService.getModalByIds(modalIds); + List modalList = chatModalService.getModalByIds(modalIds); Map modalIdMap = modalList.stream().collect(Collectors.toMap(AiChatModelDO::getId, o -> o)); // 转换 AiChatMessageRespVO List aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/models/yiyan/YiYanChatClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/models/yiyan/YiYanChatClient.java index 819976629..5d1dd4f76 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/models/yiyan/YiYanChatClient.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/models/yiyan/YiYanChatClient.java @@ -94,7 +94,10 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient { String a = ";"; } }); - return response.map(res -> new ChatResponse(List.of(new Generation(res.getResult())))); + return response.map(res -> { + // TODO @fan:这里缺少了 usage 的封装 + return new ChatResponse(List.of(new Generation(res.getResult()))); + }); } private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {