【代码优化】AI:新增 AIUtils,用于对接 spring ai 各种对象的构建

This commit is contained in:
YunaiV 2024-07-04 23:53:22 +08:00
parent 41071fc689
commit 471968eaf2
4 changed files with 71 additions and 69 deletions

View File

@ -26,9 +26,10 @@ public class AiWriteController {
private AiWriteService writeService;
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@PermitAll
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
@PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题
public Flux<CommonResult<String>> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
return writeService.generateWriteContent(generateReqVO, getLoginUserId());
}
}

View File

@ -4,8 +4,7 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -19,7 +18,6 @@ import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.*;
@ -28,9 +26,6 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
@ -148,46 +143,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
// 1.2 history message 历史消息
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> {
// TODO @芋艿看看有没优化空间
if (MessageType.USER.getValue().equals(message.getType())) {
chatMessages.add(new UserMessage(message.getContent()));
} else {
chatMessages.add(new AssistantMessage(message.getContent()));
}
});
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
// 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(),
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
conversation.getTemperature(), conversation.getMaxTokens());
return new Prompt(chatMessages, chatOptions);
}
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO 芋艿貌似 model 只要一设置就报错
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
/**
* 从历史消息中获得倒序的 n 组消息作为消息上下文
*

View File

@ -2,8 +2,7 @@ package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
@ -16,16 +15,12 @@ import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
@ -56,19 +51,21 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok那可以有限拿 chatRole 的角色如果没有则获取默认的
AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 1.2 插入写作信息
// TODO @xin建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())写在 toBean consumer 原因是让这个 set 保持完整性
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
// 2.1 构建提示词
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 2.2 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
@ -92,7 +89,9 @@ public class AiWriteServiceImpl implements AiWriteService {
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
// TODO @xin建议改成 if return 更简洁
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
// TODO @xin写成静态枚举哈
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
@ -103,27 +102,4 @@ public class AiWriteServiceImpl implements AiWriteService {
}
}
// TODO 芋艿复用
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO 芋艿貌似 model 只要一设置就报错
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
}

View File

@ -0,0 +1,59 @@
package cn.iocoder.yudao.framework.ai.core.util;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
/**
* Spring AI 工具类
*
* @author 芋道源码
*/
public class AiUtils {
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO @xin貌似 model 只要一设置就报错可以排查下
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
public static Message buildMessage(String type, String content) {
if (MessageType.USER.getValue().equals(type)) {
return new UserMessage(content);
}
if (MessageType.ASSISTANT.getValue().equals(type)) {
return new AssistantMessage(content);
}
if (MessageType.SYSTEM.getValue().equals(type)) {
return new SystemMessage(content);
}
if (MessageType.FUNCTION.getValue().equals(type)) {
return new FunctionMessage(content);
}
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
}
}