From 29e421432d828bad49a147136bec8a7116dcf04d Mon Sep 17 00:00:00 2001 From: xiaoxin <718949661@qq.com> Date: Thu, 11 Jul 2024 10:14:59 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E8=A7=A3=E5=86=B3todo=E3=80=91AI=20?= =?UTF-8?q?=E5=86=99=E4=BD=9C=E3=80=81=E8=84=91=E5=9B=BE=EF=BC=9Amodel?= =?UTF-8?q?=E3=80=81systemMessage=E8=8E=B7=E5=8F=96=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dal/dataobject/mindmap/AiMindMapDO.java | 5 +- .../service/mindmap/AiMindMapServiceImpl.java | 53 ++++++++++++------ .../ai/service/write/AiWriteServiceImpl.java | 56 ++++++++++++------- 3 files changed, 72 insertions(+), 42 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java index 92222b590..0442a52d7 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java @@ -12,8 +12,7 @@ import lombok.Data; * * @author xiaoxin */ -// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈 -@TableName(value = "ai_mind_map", autoResultMap = true) +@TableName(value = "ai_mind_map") @Data public class AiMindMapDO extends BaseDO { @@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO { /** * 用户编号 - * + *

* 关联 AdminUserDO 的 userId 字段 */ private Long userId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java index 7d96c70d2..7b49ee807 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.ai.service.mindmap; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; @@ -57,33 +58,25 @@ public class AiMindMapServiceImpl implements AiMindMapService { @Override public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { - // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型 + // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型 AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); - AiChatModelDO model; - String systemMessage; - if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) { - model = chatModalService.getChatModel(mindMapRole.getModelId()); - systemMessage = mindMapRole.getSystemMessage(); - } else { - model = chatModalService.getRequiredDefaultChatModel(); - systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); - } - + // 1.1 获取脑图执行模型 + AiChatModelDO model = getModel(mindMapRole); + // 1.2 获取角色设定消息 + String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage()) + ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); + // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); // 2 插入思维导图信息 - AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); + AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, + mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); mindMapMapper.insert(mindMapDO); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); // 3.1 角色设定 - List chatMessages = new ArrayList<>(); - if (StrUtil.isNotBlank(systemMessage)) { - chatMessages.add(new SystemMessage(systemMessage)); - } - // 3.2 用户输入 - chatMessages.add(new UserMessage(generateReqVO.getPrompt())); + List chatMessages = buildMessages(generateReqVO, systemMessage); // 3.3 构建提示词 Prompt prompt = new Prompt(chatMessages, chatOptions); @@ -109,4 +102,28 @@ public class AiMindMapServiceImpl implements AiMindMapService { } + private static List buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) { + List chatMessages = new ArrayList<>(); + if (StrUtil.isNotBlank(systemMessage)) { + // 1.1 角色设定 + chatMessages.add(new SystemMessage(systemMessage)); + } + // 1.2 用户输入 + chatMessages.add(new UserMessage(generateReqVO.getPrompt())); + return chatMessages; + } + + // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去 + private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) { + AiChatModelDO model = null; + if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) { + model = chatModalService.getChatModel(chatRoleDO.getModelId()); + } + if (Objects.isNull(model)) { + model = chatModalService.getRequiredDefaultChatModel(); + } + Assert.notNull(model, "[AI] 获取不到模型"); + return model; + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index b03a90ab7..4b583e3c1 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.ai.service.write; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; @@ -67,19 +68,14 @@ public class AiWriteServiceImpl implements AiWriteService { @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { - // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型 + // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型 AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); - // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。 - AiChatModelDO model; - String systemMessage; - if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { - model = chatModalService.getChatModel(writeRole.getModelId()); - systemMessage = writeRole.getSystemMessage(); - } else { - model = chatModalService.getRequiredDefaultChatModel(); - systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); - } - // 1.2 校验平台 + // 1.1 获取写作执行模型 + AiChatModelDO model = getModel(writeRole); + // 1.2 获取角色设定消息 + String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage()) + ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); + // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); @@ -90,16 +86,11 @@ public class AiWriteServiceImpl implements AiWriteService { // 3. 调用大模型,写作生成 ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - // 3.1 角色设定 - // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确 - List chatMessages = new ArrayList<>(); - if (StrUtil.isNotBlank(systemMessage)) { - chatMessages.add(new SystemMessage(systemMessage)); - } - // 3.2 用户输入 - chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO))); - // 3.3 构建提示词 + // 3.1 构建消息列表 + List chatMessages = buildMessages(generateReqVO, systemMessage); + // 3.2 构建提示词 Prompt prompt = new Prompt(chatMessages, chatOptions); + // 3.3 流式调用 Flux streamResponse = chatModel.stream(prompt); // 4. 流式返回 @@ -122,6 +113,29 @@ public class AiWriteServiceImpl implements AiWriteService { }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); } + private AiChatModelDO getModel(AiChatRoleDO writeRole) { + AiChatModelDO model = null; + if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { + model = chatModalService.getChatModel(writeRole.getModelId()); + } + if (Objects.isNull(model)) { + model = chatModalService.getRequiredDefaultChatModel(); + } + Assert.notNull(model, "[AI] 获取不到模型"); + return model; + } + + private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) { + List chatMessages = new ArrayList<>(); + if (StrUtil.isNotBlank(systemMessage)) { + // 1.1 角色设定 + chatMessages.add(new SystemMessage(systemMessage)); + } + // 1.2 用户输入 + chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO))); + return chatMessages; + } + private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());