diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
index 92222b590..0442a52d7 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
@@ -12,8 +12,7 @@ import lombok.Data;
*
* @author xiaoxin
*/
-// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈
-@TableName(value = "ai_mind_map", autoResultMap = true)
+@TableName(value = "ai_mind_map")
@Data
public class AiMindMapDO extends BaseDO {
@@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO {
/**
* 用户编号
- *
+ *
* 关联 AdminUserDO 的 userId 字段
*/
private Long userId;
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
index 7d96c70d2..7b49ee807 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
@@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@@ -57,33 +58,25 @@ public class AiMindMapServiceImpl implements AiMindMapService {
@Override
public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
- // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
+ // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
- AiChatModelDO model;
- String systemMessage;
- if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
- model = chatModalService.getChatModel(mindMapRole.getModelId());
- systemMessage = mindMapRole.getSystemMessage();
- } else {
- model = chatModalService.getRequiredDefaultChatModel();
- systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
- }
-
+ // 1.1 获取脑图执行模型
+ AiChatModelDO model = getModel(mindMapRole);
+ // 1.2 获取角色设定消息
+ String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage())
+ ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
+ // 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 2 插入思维导图信息
- AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
+ AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
+ mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
mindMapMapper.insert(mindMapDO);
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 3.1 角色设定
- List chatMessages = new ArrayList<>();
- if (StrUtil.isNotBlank(systemMessage)) {
- chatMessages.add(new SystemMessage(systemMessage));
- }
- // 3.2 用户输入
- chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
+ List chatMessages = buildMessages(generateReqVO, systemMessage);
// 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
@@ -109,4 +102,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
}
+ private static List buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
+ List chatMessages = new ArrayList<>();
+ if (StrUtil.isNotBlank(systemMessage)) {
+ // 1.1 角色设定
+ chatMessages.add(new SystemMessage(systemMessage));
+ }
+ // 1.2 用户输入
+ chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
+ return chatMessages;
+ }
+
+ // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去
+ private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) {
+ AiChatModelDO model = null;
+ if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) {
+ model = chatModalService.getChatModel(chatRoleDO.getModelId());
+ }
+ if (Objects.isNull(model)) {
+ model = chatModalService.getRequiredDefaultChatModel();
+ }
+ Assert.notNull(model, "[AI] 获取不到模型");
+ return model;
+ }
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
index b03a90ab7..4b583e3c1 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
@@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@@ -67,19 +68,14 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override
public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
- // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
+ // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型
AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
- // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。
- AiChatModelDO model;
- String systemMessage;
- if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
- model = chatModalService.getChatModel(writeRole.getModelId());
- systemMessage = writeRole.getSystemMessage();
- } else {
- model = chatModalService.getRequiredDefaultChatModel();
- systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
- }
- // 1.2 校验平台
+ // 1.1 获取写作执行模型
+ AiChatModelDO model = getModel(writeRole);
+ // 1.2 获取角色设定消息
+ String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
+ ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
+ // 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
@@ -90,16 +86,11 @@ public class AiWriteServiceImpl implements AiWriteService {
// 3. 调用大模型,写作生成
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
- // 3.1 角色设定
- // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确
- List chatMessages = new ArrayList<>();
- if (StrUtil.isNotBlank(systemMessage)) {
- chatMessages.add(new SystemMessage(systemMessage));
- }
- // 3.2 用户输入
- chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
- // 3.3 构建提示词
+ // 3.1 构建消息列表
+ List chatMessages = buildMessages(generateReqVO, systemMessage);
+ // 3.2 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
+ // 3.3 流式调用
Flux streamResponse = chatModel.stream(prompt);
// 4. 流式返回
@@ -122,6 +113,29 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
}
+ private AiChatModelDO getModel(AiChatRoleDO writeRole) {
+ AiChatModelDO model = null;
+ if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
+ model = chatModalService.getChatModel(writeRole.getModelId());
+ }
+ if (Objects.isNull(model)) {
+ model = chatModalService.getRequiredDefaultChatModel();
+ }
+ Assert.notNull(model, "[AI] 获取不到模型");
+ return model;
+ }
+
+ private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
+ List chatMessages = new ArrayList<>();
+ if (StrUtil.isNotBlank(systemMessage)) {
+ // 1.1 角色设定
+ chatMessages.add(new SystemMessage(systemMessage));
+ }
+ // 1.2 用户输入
+ chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
+ return chatMessages;
+ }
+
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());