diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 6c8cdeaca..72fa06a79 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); - // 3.2 创建 chat 需要的 Prompt + // 3.2 构建 Prompt,并进行调用 Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); Flux streamResponse = chatModel.stream(prompt); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java index 7b49ee807..72be20c54 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java @@ -32,13 +32,12 @@ import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; -import java.util.Objects; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; /** - * AI 写作 Service 实现类 + * AI 思维导图 Service 实现类 * * @author xiaoxin */ @@ -58,30 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService { @Override public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { - // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型 - AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); + // 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型 + AiChatRoleDO role = CollUtil.getFirst( + chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); // 1.1 获取脑图执行模型 - AiChatModelDO model = getModel(mindMapRole); + AiChatModelDO model = getModel(role); // 1.2 获取角色设定消息 - String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage()) - ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); + String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage()) + ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); - // 2 插入思维导图信息 + // 2. 插入思维导图信息 AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); mindMapMapper.insert(mindMapDO); - ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - // 3.1 角色设定 - List chatMessages = buildMessages(generateReqVO, systemMessage); - // 3.3 构建提示词 - Prompt prompt = new Prompt(chatMessages, chatOptions); - + // 3.1 构建 Prompt,并进行调用 + Prompt prompt = buildPrompt(generateReqVO, model, systemMessage); Flux streamResponse = chatModel.stream(prompt); - // 3.4 流式返回 + + // 3.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; @@ -102,24 +99,32 @@ public class AiMindMapServiceImpl implements AiMindMapService { } + private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { + // 1. 构建 message 列表 + List chatMessages = buildMessages(generateReqVO, systemMessage); + // 2. 构建 options 对象 + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); + return new Prompt(chatMessages, options); + } + private static List buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) { List chatMessages = new ArrayList<>(); + // 1. 角色设定 if (StrUtil.isNotBlank(systemMessage)) { - // 1.1 角色设定 chatMessages.add(new SystemMessage(systemMessage)); } - // 1.2 用户输入 + // 2. 用户输入 chatMessages.add(new UserMessage(generateReqVO.getPrompt())); return chatMessages; } - // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去 - private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) { + private AiChatModelDO getModel(AiChatRoleDO role) { AiChatModelDO model = null; - if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) { - model = chatModalService.getChatModel(chatRoleDO.getModelId()); + if (role != null && role.getModelId() != null) { + model = chatModalService.getChatModel(role.getModelId()); } - if (Objects.isNull(model)) { + if (model != null) { model = chatModalService.getRequiredDefaultChatModel(); } Assert.notNull(model, "[AI] 获取不到模型"); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index 4b583e3c1..2fae31d59 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -68,8 +68,9 @@ public class AiWriteServiceImpl implements AiWriteService { @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { - // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型 - AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); + // 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型 + AiChatRoleDO writeRole = CollUtil.getFirst( + chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); // 1.1 获取写作执行模型 AiChatModelDO model = getModel(writeRole); // 1.2 获取角色设定消息 @@ -84,16 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService { write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel())); writeMapper.insert(writeDO); - // 3. 调用大模型,写作生成 - ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - // 3.1 构建消息列表 - List chatMessages = buildMessages(generateReqVO, systemMessage); - // 3.2 构建提示词 - Prompt prompt = new Prompt(chatMessages, chatOptions); - // 3.3 流式调用 + // 3.1 构建 Prompt,并进行调用 + Prompt prompt = buildPrompt(generateReqVO, model, systemMessage); Flux streamResponse = chatModel.stream(prompt); - // 4. 流式返回 + // 3.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; @@ -125,6 +121,15 @@ public class AiWriteServiceImpl implements AiWriteService { return model; } + private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { + // 1. 构建 message 列表 + List chatMessages = buildMessages(generateReqVO, systemMessage); + // 2. 构建 options 对象 + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); + return new Prompt(chatMessages, options); + } + private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) { List chatMessages = new ArrayList<>(); if (StrUtil.isNotBlank(systemMessage)) { @@ -132,11 +137,11 @@ public class AiWriteServiceImpl implements AiWriteService { chatMessages.add(new SystemMessage(systemMessage)); } // 1.2 用户输入 - chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO))); + chatMessages.add(new UserMessage(buildUserMessage(generateReqVO))); return chatMessages; } - private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { + private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) { String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());