!14 【解决todo】AI 写作、脑图:model、systemMessage获取逻辑调整

Merge pull request !14 from 小新/master-jdk21-ai
This commit is contained in:
芋道源码 2024-07-12 00:48:48 +00:00 committed by Gitee
commit ecb50c6511
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 72 additions and 42 deletions

View File

@ -12,8 +12,7 @@ import lombok.Data;
* *
* @author xiaoxin * @author xiaoxin
*/ */
// TODO @xin如果没 typehandler 的需求autoResultMap 可以去掉哈 @TableName(value = "ai_mind_map")
@TableName(value = "ai_mind_map", autoResultMap = true)
@Data @Data
public class AiMindMapDO extends BaseDO { public class AiMindMapDO extends BaseDO {
@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO {
/** /**
* 用户编号 * 用户编号
* * <p>
* 关联 AdminUserDO userId 字段 * 关联 AdminUserDO userId 字段
*/ */
private Long userId; private Long userId;

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.mindmap; package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@ -57,33 +58,25 @@ public class AiMindMapServiceImpl implements AiMindMapService {
@Override @Override
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取脑图模型 尝试获取思维导图助手角色如果没有则使用默认模型 // 1 获取脑图模型 尝试获取思维导图助手角色如果没有则使用默认模型
AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
AiChatModelDO model; // 1.1 获取脑图执行模型
String systemMessage; AiChatModelDO model = getModel(mindMapRole);
if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) { // 1.2 获取角色设定消息
model = chatModalService.getChatModel(mindMapRole.getModelId()); String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage())
systemMessage = mindMapRole.getSystemMessage(); ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
} else { // 1.3 校验平台
model = chatModalService.getRequiredDefaultChatModel();
systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
}
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 2 插入思维导图信息 // 2 插入思维导图信息
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
mindMapMapper.insert(mindMapDO); mindMapMapper.insert(mindMapDO);
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 3.1 角色设定 // 3.1 角色设定
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 3.2 用户输入
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
// 3.3 构建提示词 // 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions); Prompt prompt = new Prompt(chatMessages, chatOptions);
@ -109,4 +102,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
} }
private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
// 1.1 角色设定
chatMessages.add(new SystemMessage(systemMessage));
}
// 1.2 用户输入
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
return chatMessages;
}
// TODO 芋艿这里脑图写作都用到了是不是可以抽哪里去
private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) {
AiChatModelDO model = null;
if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) {
model = chatModalService.getChatModel(chatRoleDO.getModelId());
}
if (Objects.isNull(model)) {
model = chatModalService.getRequiredDefaultChatModel();
}
Assert.notNull(model, "[AI] 获取不到模型");
return model;
}
} }

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.write; package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@ -67,19 +68,14 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override @Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型 // 1 获取写作模型 尝试获取写作助手角色没有则使用默认模型
AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
// TODO @xin如果有 writeRole但是没 modeId是不是也可以用 systemMessage 建议的写法是先通过 modelId 获取 model如果 model == null chatModalService.getRequiredDefaultChatModel()如果还是 null则抛出异常然后systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理 // 1.1 获取写作执行模型
AiChatModelDO model; AiChatModelDO model = getModel(writeRole);
String systemMessage; // 1.2 获取角色设定消息
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
model = chatModalService.getChatModel(writeRole.getModelId()); ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
systemMessage = writeRole.getSystemMessage(); // 1.3 校验平台
} else {
model = chatModalService.getRequiredDefaultChatModel();
systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
}
// 1.2 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
@ -90,16 +86,11 @@ public class AiWriteServiceImpl implements AiWriteService {
// 3. 调用大模型写作生成 // 3. 调用大模型写作生成
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 3.1 角色设定 // 3.1 构建消息列表
// TODO @xin要不把 90 97 这部分合并到一个方法里目的是让这个方法的主干更明确 List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
List<Message> chatMessages = new ArrayList<>(); // 3.2 构建提示词
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 3.2 用户输入
chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
// 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions); Prompt prompt = new Prompt(chatMessages, chatOptions);
// 3.3 流式调用
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 4. 流式返回 // 4. 流式返回
@ -122,6 +113,29 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
} }
private AiChatModelDO getModel(AiChatRoleDO writeRole) {
AiChatModelDO model = null;
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
model = chatModalService.getChatModel(writeRole.getModelId());
}
if (Objects.isNull(model)) {
model = chatModalService.getRequiredDefaultChatModel();
}
Assert.notNull(model, "[AI] 获取不到模型");
return model;
}
private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
// 1.1 角色设定
chatMessages.add(new SystemMessage(systemMessage));
}
// 1.2 用户输入
chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
return chatMessages;
}
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());