【优化】AI 写作:做角色设定,提高准确率

This commit is contained in:
xiaoxin 2024-07-10 14:28:31 +08:00
parent 7cd16ffbf3
commit bcdb23b89d
9 changed files with 111 additions and 74 deletions

View File

@ -0,0 +1,63 @@
package cn.iocoder.yudao.module.ai.enums;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 写作类型的枚举
*
* @author xiaoxin
*/
@AllArgsConstructor
@Getter
public enum AiChatRoleEnum implements IntArrayValuable {
AI_WRITE_ROLE(1, "写作助手", """
你是一位出色的写作助手能够帮助用户生成创意和灵感并在用户提供场景和提示词时生成对应的回复你的任务包括
1. 撰写建议根据用户提供的主题或问题提供详细的写作建议情节发展方向角色设定以及背景描写确保内容结构清晰有逻辑
2. 回复生成根据用户提供的场景和提示词生成合适的对话或文字回复确保语气和风格符合场景需求
除此之外不需要除了正文内容外的其他回复如标题开头任何解释性语句或道歉
"""),
AI_MIND_MAP_ROLE(2, "脑图助手", """
你是一位非常优秀的思维导图助手你会把用户的所有提问都总结成思维导图然后以 Markdown 格式输出markdown 只需要输出一级标题二级标题三级标题四级标题最多输出四级除此之外不要输出任何其他 markdown 标记下面是一个合格的例子
# Geek-AI 助手
## 完整的开源系统
### 前端开源
### 后端开源
## 支持各种大模型
### OpenAI
### Azure
### 文心一言
### 通义千问
## 集成多种收费方式
### 支付宝
### 微信
除此之外不要任何解释性语句
""");
/**
* 角色
*/
private final Integer role;
/**
* 角色名
*/
private final String name;
/**
* 角色设定
*/
private final String prompt;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiChatRoleEnum::getRole).toArray();
@Override
public int[] array() {
return ARRAYS;
}
}

View File

@ -1,7 +1,5 @@
package cn.iocoder.yudao.module.ai.enums.write; package cn.iocoder.yudao.module.ai.enums.write;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable; import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
@ -41,9 +39,4 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
return ARRAYS; return ARRAYS;
} }
public static void validateType(Integer type) {
if (ArrayUtil.contains(ARRAYS, type)) return;
throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type));
}
} }

View File

@ -26,7 +26,7 @@ public class AiMindMapController {
private AiMindMapService mindMapService; private AiMindMapService mindMapService;
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @Operation(summary = "脑图生成(流式)", description = "流式返回,响应较快")
@PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题 @PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题
public Flux<CommonResult<String>> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) { public Flux<CommonResult<String>> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) {
return mindMapService.generateMindMap(generateReqVO, getLoginUserId()); return mindMapService.generateMindMap(generateReqVO, getLoginUserId());

View File

@ -11,7 +11,7 @@ import lombok.Data;
public class AiWriteGenerateReqVO { public class AiWriteGenerateReqVO {
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@InEnum(AiWriteTypeEnum.class) @InEnum(value = AiWriteTypeEnum.class, message = "写作类型必须是 {value}")
private Integer type; private Integer type;
@Schema(description = "写作内容提示", example = "1.撰写田忌赛马2.回复:不批") @Schema(description = "写作内容提示", example = "1.撰写田忌赛马2.回复:不批")

View File

@ -4,9 +4,7 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@ -47,4 +45,10 @@ public interface AiChatRoleMapper extends BaseMapperX<AiChatRoleDO> {
.groupBy(AiChatRoleDO::getCategory)); .groupBy(AiChatRoleDO::getCategory));
} }
default List<AiChatRoleDO> selectListByName(String name) {
return selectList(new LambdaQueryWrapperX<AiChatRoleDO>()
.likeIfPresent(AiChatRoleDO::getName, name)
.orderByAsc(AiChatRoleDO::getSort));
}
} }

View File

@ -5,15 +5,14 @@ import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper; import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
@ -56,61 +55,40 @@ public class AiMindMapServiceImpl implements AiMindMapService {
@Resource @Resource
private AiMindMapMapper mindMapMapper; private AiMindMapMapper mindMapMapper;
private static final String DEFAULT_SYSTEM_MESSAGE = """
你是一位非常优秀的思维导图助手你会把用户的所有提问都总结成思维导图然后以 Markdown 格式输出markdown 只需要输出一级标题二级标题三级标题四级标题最多输出四级除此之外不要输出任何其他 markdown 标记下面是一个合格的例子
# Geek-AI 助手
## 完整的开源系统
### 前端开源
### 后端开源
## 支持各种大模型
### OpenAI
### Azure
### 文心一言
### 通义千问
## 集成多种收费方式
### 支付宝
### 微信
另外除此之外不要任何解释性语句
""";
@Override @Override
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取脑图模型 尝试获取思维导图助手角色如果没有则使用默认模型 // 1.1 获取脑图模型 尝试获取思维导图助手角色如果没有则使用默认模型
AiChatRoleDO mindMapRole = selectOneMindMapRole(); AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
AiChatModelDO model; AiChatModelDO model;
String systemMessage; String systemMessage;
if (Objects.nonNull(mindMapRole)) { if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
model = chatModalService.getChatModel(mindMapRole.getModelId()); model = chatModalService.getChatModel(mindMapRole.getModelId());
systemMessage = mindMapRole.getSystemMessage(); systemMessage = mindMapRole.getSystemMessage();
} else { } else {
model = chatModalService.getRequiredDefaultChatModel(); model = chatModalService.getRequiredDefaultChatModel();
systemMessage = DEFAULT_SYSTEM_MESSAGE; systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getPrompt();
} }
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 1.2 插入思维导图信息 // 2 插入思维导图信息
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
mindMapMapper.insert(mindMapDO); mindMapMapper.insert(mindMapDO);
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 2.1 角色设定 // 3.1 角色设定
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) { if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage)); chatMessages.add(new SystemMessage(systemMessage));
} }
// 2.2 用户输入 // 3.2 用户输入
chatMessages.add(new UserMessage(generateReqVO.getPrompt())); chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
// 2.3 构建提示词 // 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions); Prompt prompt = new Prompt(chatMessages, chatOptions);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 2.4 流式返回 // 3.4 流式返回
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -131,13 +109,4 @@ public class AiMindMapServiceImpl implements AiMindMapService {
} }
private AiChatRoleDO selectOneMindMapRole() {
AiChatRoleDO chatRoleDO = null;
PageResult<AiChatRoleDO> mindMapRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("思维导图助手"));
List<AiChatRoleDO> list = mindMapRolePage.getList();
if (CollUtil.isNotEmpty(list)) {
chatRoleDO = list.get(0);
}
return chatRoleDO;
}
} }

View File

@ -118,4 +118,11 @@ public interface AiChatRoleService {
*/ */
List<String> getChatRoleCategoryList(); List<String> getChatRoleCategoryList();
/**
* 根据名字获得聊天角色
* @param name 名字
* @return 聊天角色列表
*/
List<AiChatRoleDO> getChatRoleListByName(String name);
} }

View File

@ -137,5 +137,10 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
return convertList(list, AiChatRoleDO::getCategory, role -> role != null && StrUtil.isNotBlank(role.getCategory())); return convertList(list, AiChatRoleDO::getCategory, role -> role != null && StrUtil.isNotBlank(role.getCategory()));
} }
@Override
public List<AiChatRoleDO> getChatRoleListByName(String name) {
return chatRoleMapper.selectListByName(name);
}
} }

View File

@ -5,15 +5,14 @@ import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
@ -23,6 +22,9 @@ import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.system.api.dict.DictDataApi; import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
@ -30,6 +32,7 @@ import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@ -61,13 +64,15 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override @Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型 // 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型
AiChatRoleDO writeRole = selectOneWriteRole(); AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
AiChatModelDO model; AiChatModelDO model;
// TODO @xinwriteRole.getModelId 可能为空所以最好是先通过 chatRole 如果它没拿到通过 getRequiredDefaultChatModel 再拿 String systemMessage;
if (Objects.nonNull(writeRole)) { if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
model = chatModalService.getChatModel(writeRole.getModelId()); model = chatModalService.getChatModel(writeRole.getModelId());
systemMessage = writeRole.getSystemMessage();
} else { } else {
model = chatModalService.getRequiredDefaultChatModel(); model = chatModalService.getRequiredDefaultChatModel();
systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getPrompt();
} }
// 1.2 校验平台 // 1.2 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
@ -77,9 +82,16 @@ public class AiWriteServiceImpl implements AiWriteService {
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
writeMapper.insert(writeDO); writeMapper.insert(writeDO);
// 3.1 构建提示词
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); // 3.1 角色设定
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 3.2 用户输入
chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
// 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.2 流式返回 // 3.2 流式返回
@ -102,24 +114,8 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
} }
// TODO @xinchatRoleService 增加一个 getChatRoleListByName
private AiChatRoleDO selectOneWriteRole() {
AiChatRoleDO chatRoleDO = null;
// TODO @xin"写作助手" 枚举下
PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手"));
List<AiChatRoleDO> list = writeRolePage.getList();
// TODO @xinCollUtil.getFirst 简化下
if (CollUtil.isNotEmpty(list)) {
chatRoleDO = list.get(0);
}
return chatRoleDO;
}
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
// 校验写作类型是否合法
Integer type = generateReqVO.getType(); Integer type = generateReqVO.getType();
// TODO @xin这里可以搞到 validator 的校验InEnum
AiWriteTypeEnum.validateType(type);
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());