diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java index c954a1667..cf5d6e76e 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java @@ -30,6 +30,7 @@ public interface ErrorCodeConstants { // role ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "chatRole 不存在!"); + ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!"); // modal diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java index 473f1c856..63e947e59 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java @@ -32,7 +32,7 @@ public class AiChatMessageController { @PostMapping("/send") public CommonResult sendMessage(@Validated @ModelAttribute AiChatMessageSendReqVO sendReqVO) { // TODO @fan:使用 static import;这样就 success 就行了; - return success(null); + return success(chatService.chat(sendReqVO)); } // TODO @芋艿:调用这个方法异常,Unable to handle the Spring Security Exception because the response is already committed.;可以再试试 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java index 720736f21..70ff21fc5 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java @@ -21,6 +21,9 @@ public class AiChatConversationRespVO { @Schema(description = "是否置顶", requiredMode = Schema.RequiredMode.REQUIRED, example = "true") private Boolean pinned; + @Schema(description = "角色编号", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1") + private Long roleId; + @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") private Long modelId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModalService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModalService.java index 60511c8a6..5740328be 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModalService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModalService.java @@ -52,4 +52,11 @@ public interface AiChatModalService { * @return */ AiChatModalRes getChatModalOfValidate(Long modalId); + + /** + * 校验 - 校验是否可用 + * + * @param chatModal + */ + void validateAvailable(AiChatModalRes chatModal); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatRoleService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatRoleService.java index 96da840f8..5f11e4ffc 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatRoleService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatRoleService.java @@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.role.*; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; /** * chat 角色 @@ -58,4 +59,19 @@ public interface AiChatRoleService { * @return */ AiChatRoleRes getChatRole(Long roleId); + + /** + * 校验 - 角色是否存在 + * + * @param id + * @return + */ + AiChatRoleDO validateExists(Long id); + + /** + * 校验 - 角色是否公开 + * + * @param aiChatRoleDO + */ + void validateIsPublic(AiChatRoleDO aiChatRoleDO); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java index c054e41fa..e25062e78 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.ai.service; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; /** @@ -15,10 +16,10 @@ public interface AiChatService { /** * chat * - * @param req + * @param sendReqVO * @return */ - String chat(AiChatMessageSendReqVO req); + AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO); /** * chat stream diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatConversationServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatConversationServiceImpl.java index 4f6a2cc52..6fbafc659 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatConversationServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatConversationServiceImpl.java @@ -15,7 +15,6 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper; -import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum; import cn.iocoder.yudao.module.ai.service.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.AiChatModalService; import cn.iocoder.yudao.module.ai.service.AiChatRoleService; @@ -91,9 +90,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService // 获取模型信息并验证 AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(updateReqVO.getModelId()); // 校验modal是否可用 - if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) { - throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED); - } + aiChatModalService.validateAvailable(chatModal); // 更新对话信息 AiChatConversationDO updateAiChatConversationDO = AiChatConversationConvert.INSTANCE.convertAiChatConversationDO(updateReqVO); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java index 67013bdda..89c85f8ee 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java @@ -116,6 +116,14 @@ public class AiChatModalServiceImpl implements AiChatModalService { return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO); } + @Override + public void validateAvailable(AiChatModalRes chatModal) { + // 对话模型是否可用 + if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) { + throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED); + } + } + private AiChatModalDO validateChatModalExists(Long id) { AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id); if (aiChatModalDO == null) { diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatRoleServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatRoleServiceImpl.java index fc882e020..95ed67861 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatRoleServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatRoleServiceImpl.java @@ -70,7 +70,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService { AiChatRoleClassifyEnum.valueOfClassify(req.getClassify()); AiChatRoleEnableEnum.valueOfType(req.getEnable()); // 检查角色是否存在 - validateChatRoleExists(id); + validateExists(id); // 转换do AiChatRoleDO updateChatRole = AiChatRoleConvert.INSTANCE.convertAiChatRoleDO(req); updateChatRole.setId(id); @@ -83,7 +83,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService { // 转换enum,并校验enum AiChatRoleEnableEnum.valueOfType(req.getEnable()); // 检查角色是否存在 - validateChatRoleExists(id); + validateExists(id); // 更新 aiChatRoleMapper.updateById(new AiChatRoleDO() .setId(id) @@ -94,7 +94,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService { @Override public void delete(Long chatRoleId) { // 检查角色是否存在 - validateChatRoleExists(chatRoleId); + validateExists(chatRoleId); // 删除 aiChatRoleMapper.deleteById(chatRoleId); } @@ -102,15 +102,25 @@ public class AiChatRoleServiceImpl implements AiChatRoleService { @Override public AiChatRoleRes getChatRole(Long roleId) { // 检查角色是否存在 - AiChatRoleDO aiChatRoleDO = validateChatRoleExists(roleId); + AiChatRoleDO aiChatRoleDO = validateExists(roleId); return AiChatRoleConvert.INSTANCE.convertAiChatRoleRes(aiChatRoleDO); } - private AiChatRoleDO validateChatRoleExists(Long id) { + public AiChatRoleDO validateExists(Long id) { AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(id); if (aiChatRoleDO == null) { throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXIST); } return aiChatRoleDO; } + + public void validateIsPublic(AiChatRoleDO aiChatRoleDO) { + if (aiChatRoleDO == null) { + return; + } + if (!aiChatRoleDO.getPublicStatus()) { + throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_PUBLIC); + } + } } + diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java index f833f60d9..b5e2634de 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java @@ -10,14 +10,19 @@ import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper; import cn.iocoder.yudao.module.ai.service.AiChatConversationService; +import cn.iocoder.yudao.module.ai.service.AiChatModalService; +import cn.iocoder.yudao.module.ai.service.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.AiChatService; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; @@ -45,29 +50,39 @@ public class AiChatServiceImpl implements AiChatService { private final AiChatMessageMapper aiChatMessageMapper; private final AiChatConversationMapper aiChatConversationMapper; private final AiChatConversationService chatConversationService; + private final AiChatModalService aiChatModalService; + private final AiChatRoleService aiChatRoleService; - /** - * chat - * - * @param req - * @return - */ @Transactional(rollbackFor = Exception.class) - public String chat(AiChatMessageSendReqVO req) { + public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) { Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); + // 查询对话 + AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId()); + // 获取对话模型 + AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId()); + // 对话模型是否可用 + aiChatModalService.validateAvailable(chatModal); + // 获取角色信息 + AiChatRoleDO aiChatRoleDO = null; + if (conversation.getRoleId() != null) { + aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId()); + } + // 校验角色是否公开 + aiChatRoleService.validateIsPublic(aiChatRoleDO); // 获取 client 类型 - AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal()); - // 获取对话信息 - AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId()); + AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal()); // 保存 chat message - saveChatMessage(req, conversationRes, loginUserId); + insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), + chatModal.getModal(), chatModal.getId(), req.getContent(), + null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); String content = null; try { // 创建 chat 需要的 Prompt - Prompt prompt = new Prompt(req.getPrompt()); - req.setTopK(req.getTopK()); - req.setTopP(req.getTopP()); - req.setTemperature(req.getTemperature()); + Prompt prompt = new Prompt(req.getContent()); + // TODO @芋艿 @范 看要不要支持这些 +// req.setTopK(req.getTopK()); +// req.setTopP(req.getTopP()); +// req.setTemperature(req.getTemperature()); // 发送 call 调用 ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum); ChatResponse call = chatClient.call(prompt); @@ -78,69 +93,66 @@ public class AiChatServiceImpl implements AiChatService { content = ExceptionUtil.getMessage(e); } finally { // 保存 chat message - saveSystemChatMessage(req, conversationRes, loginUserId, content); + insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), + chatModal.getModal(), chatModal.getId(), req.getContent(), + null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); } - return content; + return new AiChatMessageRespVO().setContent(content); } - private void saveChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId) { - Long chatConversationId = conversationRes.getId(); + private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId, + String model, Long modelId, String content, Integer tokens, Double temperature, + Integer maxTokens, Integer maxContexts) { + AiChatMessageDO insertChatMessageDO = new AiChatMessageDO() + .setId(null) + .setConversationId(conversationId) + .setType(messageType.getValue()) + .setUserId(loginUserId) + .setRoleId(roleId) + .setModel(model) + .setModelId(modelId) + .setContent(content) + .setTokens(tokens) + + .setTemperature(temperature) + .setMaxTokens(maxTokens) + .setMaxContexts(maxContexts); // 增加 chat message 记录 - aiChatMessageMapper.insert( - new AiChatMessageDO() - .setId(null) - .setConversationId(chatConversationId) - .setUserId(loginUserId) - .setMessage(req.getPrompt()) - .setMessageType(MessageType.USER.getValue()) - .setTopK(req.getTopK()) - .setTopP(req.getTopP()) - .setTemperature(req.getTemperature()) - ); + aiChatMessageMapper.insert(insertChatMessageDO); // chat count 先+1 - aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); + aiChatConversationMapper.updateIncrChatCount(conversationId); + return insertChatMessageDO; } - public void saveSystemChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId, String systemPrompts) { - Long chatConversationId = conversationRes.getId(); - // 增加 chat message 记录 - aiChatMessageMapper.insert( - new AiChatMessageDO() - .setId(null) - .setConversationId(chatConversationId) - .setUserId(loginUserId) - .setMessage(systemPrompts) - .setMessageType(MessageType.SYSTEM.getValue()) - .setTopK(req.getTopK()) - .setTopP(req.getTopP()) - .setTemperature(req.getTemperature()) - ); - - // chat count 先+1 - aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); - } - - /** - * chat stream - * - * @param req - * @param sseEmitter - * @return - */ @Override public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) { Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); - // 获取 client 类型 - AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal()); - // 获取对话信息 - AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId()); + // 查询对话 + AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId()); + // 获取对话模型 + AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId()); + // 对话模型是否可用 + aiChatModalService.validateAvailable(chatModal); + // 获取角色信息 + AiChatRoleDO aiChatRoleDO = null; + if (conversation.getRoleId() != null) { + aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId()); + } + // 校验角色是否公开 + aiChatRoleService.validateIsPublic(aiChatRoleDO); // 创建 chat 需要的 Prompt - Prompt prompt = new Prompt(req.getPrompt()); - req.setTopK(req.getTopK()); - req.setTopP(req.getTopP()); - req.setTemperature(req.getTemperature()); + Prompt prompt = new Prompt(req.getContent()); +// req.setTopK(req.getTopK()); +// req.setTopP(req.getTopP()); +// req.setTemperature(req.getTemperature()); // 保存 chat message - saveChatMessage(req, conversationRes, loginUserId); + // 保存 chat message + insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), + chatModal.getModal(), chatModal.getId(), req.getContent(), + null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); + + // 获取 client 类型 + AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal()); StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); Flux streamResponse = streamingChatClient.stream(prompt); @@ -168,7 +180,10 @@ public class AiChatServiceImpl implements AiChatService { log.info("发送完成!"); sseEmitter.complete(); // 保存 chat message - saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString()); + insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), + chatModal.getModal(), chatModal.getId(), req.getContent(), + null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); + } ); }