【代码评审】AI:写作部分的建议

This commit is contained in:
YunaiV 2024-07-10 18:57:12 +08:00
parent b8a443fbe0
commit b4014bf2df
7 changed files with 39 additions and 25 deletions

View File

@ -7,7 +7,7 @@ import lombok.Getter;
import java.util.Arrays; import java.util.Arrays;
/** /**
* AI 写作类型的枚举 * AI 内置聊天角色的枚举
* *
* @author xiaoxin * @author xiaoxin
*/ */
@ -21,6 +21,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
2. 回复生成根据用户提供的场景和提示词生成合适的对话或文字回复确保语气和风格符合场景需求 2. 回复生成根据用户提供的场景和提示词生成合适的对话或文字回复确保语气和风格符合场景需求
除此之外不需要除了正文内容外的其他回复如标题开头任何解释性语句或道歉 除此之外不需要除了正文内容外的其他回复如标题开头任何解释性语句或道歉
"""), """),
AI_MIND_MAP_ROLE(2, "脑图助手", """ AI_MIND_MAP_ROLE(2, "脑图助手", """
你是一位非常优秀的思维导图助手你会把用户的所有提问都总结成思维导图然后以 Markdown 格式输出markdown 只需要输出一级标题二级标题三级标题四级标题最多输出四级除此之外不要输出任何其他 markdown 标记下面是一个合格的例子 你是一位非常优秀的思维导图助手你会把用户的所有提问都总结成思维导图然后以 Markdown 格式输出markdown 只需要输出一级标题二级标题三级标题四级标题最多输出四级除此之外不要输出任何其他 markdown 标记下面是一个合格的例子
# Geek-AI 助手 # Geek-AI 助手
@ -38,7 +39,6 @@ public enum AiChatRoleEnum implements IntArrayValuable {
除此之外不要任何解释性语句 除此之外不要任何解释性语句
"""); """);
/** /**
* 角色 * 角色
*/ */
@ -51,7 +51,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
/** /**
* 角色设定 * 角色设定
*/ */
private final String prompt; private final String systemMessage;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiChatRoleEnum::getRole).toArray(); public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiChatRoleEnum::getRole).toArray();

View File

@ -7,7 +7,9 @@ import lombok.Data;
@Schema(description = "管理后台 - AI 思维导图生成 Request VO") @Schema(description = "管理后台 - AI 思维导图生成 Request VO")
@Data @Data
public class AiMindMapGenerateReqVO { public class AiMindMapGenerateReqVO {
@Schema(description = "思维导图内容提示", example = "Java 学习路线") @Schema(description = "思维导图内容提示", example = "Java 学习路线")
@NotBlank(message = "思维导图内容提示不能为空") @NotBlank(message = "思维导图内容提示不能为空")
private String prompt; private String prompt;
} }

View File

@ -12,6 +12,7 @@ import lombok.Data;
* *
* @author xiaoxin * @author xiaoxin
*/ */
// TODO @xin如果没 typehandler 的需求autoResultMap 可以去掉哈
@TableName(value = "ai_mind_map", autoResultMap = true) @TableName(value = "ai_mind_map", autoResultMap = true)
@Data @Data
public class AiMindMapDO extends BaseDO { public class AiMindMapDO extends BaseDO {
@ -24,20 +25,21 @@ public class AiMindMapDO extends BaseDO {
/** /**
* 用户编号 * 用户编号
*
* 关联 AdminUserDO userId 字段
*/ */
private Long userId; private Long userId;
/**
* 模型
*/
private String model;
/** /**
* 平台 * 平台
* <p> * <p>
* 枚举 {@link AiPlatformEnum} * 枚举 {@link AiPlatformEnum}
*/ */
private String platform; private String platform;
/**
* 模型
*/
private String model;
/** /**
* 生成内容提示 * 生成内容提示

View File

@ -2,18 +2,18 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.write;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data; import lombok.Data;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
/** /**
* AI 写作 DO * AI 写作 DO
* *
* @author xiaoxin * @author xiaoxin
*/ */
@TableName(value = "ai_write", autoResultMap = true) @TableName("ai_write")
@Data @Data
public class AiWriteDO extends BaseDO { public class AiWriteDO extends BaseDO {
@ -25,6 +25,8 @@ public class AiWriteDO extends BaseDO {
/** /**
* 用户编号 * 用户编号
*
* 关联 AdminUserDO userId 字段
*/ */
private Long userId; private Long userId;
@ -35,17 +37,16 @@ public class AiWriteDO extends BaseDO {
*/ */
private Integer type; private Integer type;
/**
* 模型
*/
private String model;
/** /**
* 平台 * 平台
* *
* 枚举 {@link AiPlatformEnum} * 枚举 {@link AiPlatformEnum}
*/ */
private String platform; private String platform;
/**
* 模型
*/
private String model;
/** /**
* 生成内容提示 * 生成内容提示
@ -56,7 +57,6 @@ public class AiWriteDO extends BaseDO {
* 生成的内容 * 生成的内容
*/ */
private String generatedContent; private String generatedContent;
/** /**
* 原文 * 原文
*/ */
@ -64,21 +64,26 @@ public class AiWriteDO extends BaseDO {
/** /**
* 长度提示词 * 长度提示词
*
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LENGTH}
*/ */
private Integer length; private Integer length;
/** /**
* 格式提示词 * 格式提示词
*
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_FORMAT}
*/ */
private Integer format; private Integer format;
/** /**
* 语气提示词 * 语气提示词
*
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_TONE}
*/ */
private Integer tone; private Integer tone;
/** /**
* 语言提示词 * 语言提示词
*
* 字典{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LANGUAGE}
*/ */
private Integer language; private Integer language;

View File

@ -66,7 +66,7 @@ public class AiMindMapServiceImpl implements AiMindMapService {
systemMessage = mindMapRole.getSystemMessage(); systemMessage = mindMapRole.getSystemMessage();
} else { } else {
model = chatModalService.getRequiredDefaultChatModel(); model = chatModalService.getRequiredDefaultChatModel();
systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getPrompt(); systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
} }
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());

View File

@ -120,6 +120,7 @@ public interface AiChatRoleService {
/** /**
* 根据名字获得聊天角色 * 根据名字获得聊天角色
*
* @param name 名字 * @param name 名字
* @return 聊天角色列表 * @return 聊天角色列表
*/ */

View File

@ -65,6 +65,7 @@ public class AiWriteServiceImpl implements AiWriteService {
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型 // 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型
AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
// TODO @xin如果有 writeRole但是没 modeId是不是也可以用 systemMessage 建议的写法是先通过 modelId 获取 model如果 model == null chatModalService.getRequiredDefaultChatModel()如果还是 null则抛出异常然后systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理
AiChatModelDO model; AiChatModelDO model;
String systemMessage; String systemMessage;
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
@ -72,18 +73,21 @@ public class AiWriteServiceImpl implements AiWriteService {
systemMessage = writeRole.getSystemMessage(); systemMessage = writeRole.getSystemMessage();
} else { } else {
model = chatModalService.getRequiredDefaultChatModel(); model = chatModalService.getRequiredDefaultChatModel();
systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getPrompt(); systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
} }
// 1.2 校验平台 // 1.2 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 2. 插入写作信息 // 2. 插入写作信息
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class,
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
writeMapper.insert(writeDO); writeMapper.insert(writeDO);
// 3. 调用大模型写作生成
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 3.1 角色设定 // 3.1 角色设定
// TODO @xin要不把 90 97 这部分合并到一个方法里目的是让这个方法的主干更明确
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) { if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage)); chatMessages.add(new SystemMessage(systemMessage));
@ -94,7 +98,7 @@ public class AiWriteServiceImpl implements AiWriteService {
Prompt prompt = new Prompt(chatMessages, chatOptions); Prompt prompt = new Prompt(chatMessages, chatOptions);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.2 流式返回 // 4. 流式返回
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -115,13 +119,13 @@ public class AiWriteServiceImpl implements AiWriteService {
} }
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
Integer type = generateReqVO.getType();
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength()); String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
// 格式化 prompt
String prompt = generateReqVO.getPrompt(); String prompt = generateReqVO.getPrompt();
if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) { if (Objects.equals(generateReqVO.getType(), AiWriteTypeEnum.WRITING.getType())) {
return StrUtil.format(AiWriteTypeEnum.WRITING.getPrompt(), prompt, format, tone, language, length); return StrUtil.format(AiWriteTypeEnum.WRITING.getPrompt(), prompt, format, tone, language, length);
} else { } else {
return StrUtil.format(AiWriteTypeEnum.REPLY.getPrompt(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length); return StrUtil.format(AiWriteTypeEnum.REPLY.getPrompt(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);