diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java index 9cd27dd3a..49c7ee54b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java @@ -19,10 +19,10 @@ public class AiChatMessageRespVO { private String type; // 参见 MessageType 枚举类 @Schema(description = "用户编号", example = "4096") - private Long userId; // 仅当 user 发送时非空 + private Long userId; @Schema(description = "角色编号", example = "888") - private Long roleId; // 仅当 assistant 回复时非空 + private Long roleId; @Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo") private String model; // 参见 AiOpenAiModelEnum 枚举类 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java index 994947724..c66537673 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java @@ -47,16 +47,12 @@ public class AiChatMessageDO extends BaseDO { /** * 用户编号 * - * 仅当 user 发送时非空 - * * 关联 AdminUserDO 的 userId 字段 */ private Long userId; /** * 角色编号 * - * 仅当 assistant 回复时非空 - * * 关联 {@link AiChatRoleDO#getId()} 字段 */ private Long roleId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java index 1e9ff242a..335c20789 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java @@ -1,23 +1,22 @@ package cn.iocoder.yudao.module.ai.service.impl; -import cn.hutool.core.exceptions.ExceptionUtil; import cn.hutool.core.util.ObjUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; -import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; @@ -30,10 +29,7 @@ import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; import java.time.LocalDateTime; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.*; import java.util.stream.Collectors; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; @@ -53,64 +49,49 @@ public class AiChatServiceImpl implements AiChatService { private final AiChatClientFactory chatClientFactory; - private final AiChatMessageMapper aiChatMessageMapper; + private final AiChatMessageMapper chatMessageMapper; + private final AiChatConversationService chatConversationService; private final AiChatModelService chatModalService; private final AiChatRoleService chatRoleService; @Transactional(rollbackFor = Exception.class) public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) { - Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); - // 查询对话 - AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); - // 获取对话模型 - AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId()); - // 获取角色信息 - AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; - // 获取 client 类型 - AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); - // 保存 chat message - insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), - chatModel.getModel(), chatModel.getId(), req.getContent()); - String content = null; - int tokens = 0; - try { - // 创建 chat 需要的 Prompt - Prompt prompt = new Prompt(req.getContent()); - // TODO @芋艿 @范 看要不要支持这些 -// req.setTopK(req.getTopK()); -// req.setTopP(req.getTopP()); -// req.setTemperature(req.getTemperature()); - // 发送 call 调用 - ChatClient chatClient = chatClientFactory.getChatClient(platformEnum); - ChatResponse call = chatClient.call(prompt); - content = call.getResult().getOutput().getContent(); - tokens = call.getResults().size(); - // 更新 conversation - } catch (Exception e) { - content = ExceptionUtil.getMessage(e); - } finally { - // 保存 chat message - insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), - chatModel.getModel(), chatModel.getId(), content); - } - return new AiChatMessageRespVO().setContent(content); - } - - private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId, - String model, Long modelId, String content) { - AiChatMessageDO insertChatMessageDO = new AiChatMessageDO() - .setConversationId(conversationId) - .setType(messageType.getValue()) - .setUserId(loginUserId) - .setRoleId(roleId) - .setModel(model) - .setModelId(modelId) - .setContent(content); - insertChatMessageDO.setCreateTime(LocalDateTime.now()); - // 增加 chat message 记录 - aiChatMessageMapper.insert(insertChatMessageDO); - return insertChatMessageDO; + return null; // TODO 芋艿:一起改 +// Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); +// // 查询对话 +// AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); +// // 获取对话模型 +// AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId()); +// // 获取角色信息 +// AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; +// // 获取 client 类型 +// AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); +// // 保存 chat message +// createChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), +// chatModel.getModel(), chatModel.getId(), req.getContent()); +// String content = null; +// int tokens = 0; +// try { +// // 创建 chat 需要的 Prompt +// Prompt prompt = new Prompt(req.getContent()); +// // TODO @芋艿 @范 看要不要支持这些 +//// req.setTopK(req.getTopK()); +//// req.setTopP(req.getTopP()); +//// req.setTemperature(req.getTemperature()); +// // 发送 call 调用 +// ChatClient chatClient = chatClientFactory.getChatClient(platformEnum); +// ChatResponse call = chatClient.call(prompt); +// content = call.getResult().getOutput().getContent(); +// // 更新 conversation +// } catch (Exception e) { +// content = ExceptionUtil.getMessage(e); +// } finally { +// // 保存 chat message +// createChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), +// chatModel.getModel(), chatModel.getId(), content); +// } +// return new AiChatMessageRespVO().setContent(content); } @Override @@ -120,55 +101,78 @@ public class AiChatServiceImpl implements AiChatService { if (ObjUtil.notEqual(conversation.getUserId(), userId)) { throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿:异常情况的对接; } + List historyMessages = chatMessageMapper.selectByConversationId(conversation.getId()); // 1.2 校验模型 AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform); // 2. 插入 user 发送消息 - AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(), - conversation.getModel(), conversation.getId(), sendReqVO.getContent()); + AiChatMessageDO userMessage = createChatMessage(conversation.getId(), model, + userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent()); + + // 3.1 插入 assistant 接收消息 + AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), model, + userId, conversation.getRoleId(), MessageType.ASSISTANT, ""); - // 3.1 插入 system 接收消息 - AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(), - conversation.getModel(), conversation.getId(), conversation.getSystemMessage()); // 3.2 创建 chat 需要的 Prompt // TODO 消息上下文 - Prompt prompt = new Prompt(sendReqVO.getContent()); -// ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build() + Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO); Flux streamResponse = chatClient.stream(prompt); - // 3.3 转换 flex AiChatMessageRespVO + + // 3.3 流式返回 StringBuffer contentBuffer = new StringBuffer(); - return streamResponse.map(res -> { - contentBuffer.append(res.getResult().getOutput().getContent()); - AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId()) - .setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime()) - .setContent(sendReqVO.getContent()); - AiChatMessageSendRespVO.Message receive = new AiChatMessageSendRespVO.Message().setId(systemMessage.getId()) - .setType(MessageType.SYSTEM.getValue()).setCreateTime(systemMessage.getCreateTime()) - .setContent(res.getResult().getOutput().getContent()); - return new AiChatMessageSendRespVO().setSend(send).setReceive(receive); + return streamResponse.map(response -> { + String newContent = response.getResult().getOutput().getContent(); + contentBuffer.append(newContent); + // 响应结果 + return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) + .setReceive(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)); }).doOnComplete(() -> { - log.info("发送完成!"); - // 保存 chat message - aiChatMessageMapper.updateById(new AiChatMessageDO() - .setId(systemMessage.getId()) - .setContent(contentBuffer.toString())); + chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())); }).doOnError(throwable -> { - log.error("发送错误 {}!", throwable.getMessage()); - // 更新错误信息 TODO 貌似不应该更新异常 - aiChatMessageMapper.updateById(new AiChatMessageDO() - .setId(systemMessage.getId()) - .setContent(throwable.getMessage())); + log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable); + chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())); }); } + private Prompt buildPrompt(AiChatConversationDO conversation, List messages, AiChatMessageSendReqVO sendReqVO) { + // TODO 芋艿:1)保留 n 个上下文;2)每一轮 token 数量 +// if (conversation.getMaxContexts() != null && messages.size() > conversation.getMaxContexts()) { +// +// } + // 1. 构建 Prompt Message 列表 + List chatMessages = new ArrayList<>(); + // 1.1 system context 角色设定 + chatMessages.add(new SystemMessage(conversation.getSystemMessage())); + // 1.2 history message 历史消息 + messages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent()))); + // 1.3 user message 新发送消息 + chatMessages.add(new UserMessage(sendReqVO.getContent())); + + // 2. 构建 ChatOptions 对象 + ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build(); + return new Prompt(chatMessages, chatOptions); + } + + private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model, + Long userId, Long roleId, + MessageType messageType, String content) { + AiChatMessageDO message = new AiChatMessageDO() + .setConversationId(conversationId).setModel(model.getModel()).setModelId(model.getId()) + .setUserId(userId).setRoleId(roleId) + .setType(messageType.getValue()).setContent(content); + message.setCreateTime(LocalDateTime.now()); + chatMessageMapper.insert(message); + return message; + } + @Override public List getMessageListByConversationId(Long conversationId) { // 校验对话是否存在 chatConversationService.validateExists(conversationId); // 获取对话所有 message - List aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId); + List aiChatMessageDOList = chatMessageMapper.selectByConversationId(conversationId); // 获取模型信息 Set modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet()); List modalList = chatModalService.getModalByIds(modalIds); @@ -187,7 +191,7 @@ public class AiChatServiceImpl implements AiChatService { @Override public Boolean deleteMessage(Long id) { - return aiChatMessageMapper.deleteById(id) > 0; + return chatMessageMapper.deleteById(id) > 0; } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java index e9e1f418e..66228ff22 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java @@ -15,13 +15,13 @@ import lombok.Getter; public enum AiPlatformEnum { OPENAI("OpenAI", "OpenAI"), - OLLAMA("dall", "dall"), + OLLAMA("Ollama", "Ollama"), YI_YAN("yiyan", "一言"), QIAN_WEN("qianwen", "千问"), XING_HUO("xinghuo", "星火"), OPEN_AI_DALL("dall", "dall"), - MIDJOURNEY("Ollama", "Ollama"), + MIDJOURNEY("midjourney", "midjourney"), ;