【代码评审】AI:写作部分的建议

This commit is contained in:
YunaiV 2024-07-10 12:59:21 +08:00
parent f3a6ba8349
commit 4c21ae32fe
3 changed files with 17 additions and 13 deletions

View File

@ -32,7 +32,7 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
/**
* 模版
*/
private final String template;
private final String prompt;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray();

View File

@ -63,26 +63,26 @@ public class AiWriteServiceImpl implements AiWriteService {
// 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型
AiChatRoleDO writeRole = selectOneWriteRole();
AiChatModelDO model;
// TODO @xinwriteRole.getModelId 可能为空所以最好是先通过 chatRole 如果它没拿到通过 getRequiredDefaultChatModel 再拿
if (Objects.nonNull(writeRole)) {
model = chatModalService.getChatModel(writeRole.getModelId());
} else {
model = chatModalService.getRequiredDefaultChatModel();
}
// 1.2 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 1.2 插入写作信息
// 2. 插入写作信息
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
writeMapper.insert(writeDO);
// 2.1 构建提示词
// 3.1 构建提示词
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 2.2 流式返回
// 3.2 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -102,10 +102,13 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
}
// TODO @xinchatRoleService 增加一个 getChatRoleListByName
private AiChatRoleDO selectOneWriteRole() {
AiChatRoleDO chatRoleDO = null;
// TODO @xin"写作助手" 枚举下
PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手"));
List<AiChatRoleDO> list = writeRolePage.getList();
// TODO @xinCollUtil.getFirst 简化下
if (CollUtil.isNotEmpty(list)) {
chatRoleDO = list.get(0);
}
@ -113,19 +116,19 @@ public class AiWriteServiceImpl implements AiWriteService {
}
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
// 校验写作类型是否合法
Integer type = generateReqVO.getType();
// TODO @xin这里可以搞到 validator 的校验InEnum
AiWriteTypeEnum.validateType(type);
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
String prompt = generateReqVO.getPrompt();
// 校验写作类型是否合法
AiWriteTypeEnum.validateType(type);
if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) {
return StrUtil.format(AiWriteTypeEnum.WRITING.getTemplate(), prompt, format, tone, language, length);
return StrUtil.format(AiWriteTypeEnum.WRITING.getPrompt(), prompt, format, tone, language, length);
} else {
return StrUtil.format(AiWriteTypeEnum.REPLY.getTemplate(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
return StrUtil.format(AiWriteTypeEnum.REPLY.getPrompt(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
}
}

View File

@ -10,15 +10,16 @@ import org.junit.jupiter.api.Test;
import java.util.Map;
// TODO @fan改成 TongYiImagesModel
/**
* 通义万象 - 测试
* 通义万象
*/
public class TongYiImagesModelTests {
@Test
public void imageCallTest() throws NoApiKeyException {
// 设置 api key
Constants.apiKey="sk-Zsd81gZYg7";
Constants.apiKey = "sk-Zsd81gZYg7";
ImageSynthesisParam param =
ImageSynthesisParam.builder()
.model(ImageSynthesis.Models.WANX_V1)