mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-23 07:41:53 +08:00
【代码优化】AI:新增 AIUtils,用于对接 spring ai 各种对象的构建
This commit is contained in:
parent
41071fc689
commit
471968eaf2
@ -26,9 +26,10 @@ public class AiWriteController {
|
|||||||
private AiWriteService writeService;
|
private AiWriteService writeService;
|
||||||
|
|
||||||
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
@PermitAll
|
|
||||||
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
|
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
|
||||||
|
@PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
|
||||||
public Flux<CommonResult<String>> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
|
public Flux<CommonResult<String>> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
|
||||||
return writeService.generateWriteContent(generateReqVO, getLoginUserId());
|
return writeService.generateWriteContent(generateReqVO, getLoginUserId());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -4,8 +4,7 @@ import cn.hutool.core.collection.CollUtil;
|
|||||||
import cn.hutool.core.util.ObjUtil;
|
import cn.hutool.core.util.ObjUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
@ -19,7 +18,6 @@ import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
|||||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.chat.messages.*;
|
import org.springframework.ai.chat.messages.*;
|
||||||
@ -28,9 +26,6 @@ import org.springframework.ai.chat.model.ChatResponse;
|
|||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
|
||||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
|
||||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
@ -148,46 +143,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
}
|
}
|
||||||
// 1.2 history message 历史消息
|
// 1.2 history message 历史消息
|
||||||
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
||||||
contextMessages.forEach(message -> {
|
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
|
||||||
// TODO @芋艿:看看有没优化空间
|
|
||||||
if (MessageType.USER.getValue().equals(message.getType())) {
|
|
||||||
chatMessages.add(new UserMessage(message.getContent()));
|
|
||||||
} else {
|
|
||||||
chatMessages.add(new AssistantMessage(message.getContent()));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
// 1.3 user message 新发送消息
|
// 1.3 user message 新发送消息
|
||||||
chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
||||||
|
|
||||||
// 2. 构建 ChatOptions 对象
|
// 2. 构建 ChatOptions 对象
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(),
|
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
|
||||||
conversation.getTemperature(), conversation.getMaxTokens());
|
conversation.getTemperature(), conversation.getMaxTokens());
|
||||||
return new Prompt(chatMessages, chatOptions);
|
return new Prompt(chatMessages, chatOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
|
||||||
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
|
||||||
//noinspection EnhancedSwitchMigration
|
|
||||||
switch (platform) {
|
|
||||||
case OPENAI:
|
|
||||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
||||||
case OLLAMA:
|
|
||||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
|
||||||
case YI_YAN:
|
|
||||||
// TODO 芋艿:貌似 model 只要一设置,就报错
|
|
||||||
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
||||||
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
||||||
case XING_HUO:
|
|
||||||
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
|
||||||
.setMaxTokens(maxTokens);
|
|
||||||
case QIAN_WEN:
|
|
||||||
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
|
||||||
default:
|
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
||||||
*
|
*
|
||||||
|
@ -2,8 +2,7 @@ package cn.iocoder.yudao.module.ai.service.write;
|
|||||||
|
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
||||||
@ -16,16 +15,12 @@ import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
|||||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||||
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
|
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
|
||||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.chat.model.ChatResponse;
|
import org.springframework.ai.chat.model.ChatResponse;
|
||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
|
||||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
|
||||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
@ -56,19 +51,21 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
||||||
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?
|
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的;
|
||||||
AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
|
AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
|
||||||
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
|
||||||
|
|
||||||
// 1.2 插入写作信息
|
// 1.2 插入写作信息
|
||||||
|
// TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性
|
||||||
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
|
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
|
||||||
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||||
|
|
||||||
// 2.1 构建提示词
|
// 2.1 构建提示词
|
||||||
|
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||||
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
|
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
|
||||||
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
||||||
|
|
||||||
// 2.2 流式返回
|
// 2.2 流式返回
|
||||||
StringBuffer contentBuffer = new StringBuffer();
|
StringBuffer contentBuffer = new StringBuffer();
|
||||||
return streamResponse.map(chunk -> {
|
return streamResponse.map(chunk -> {
|
||||||
@ -92,7 +89,9 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||||||
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
|
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
|
||||||
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
|
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
|
||||||
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
|
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
|
||||||
|
// TODO @xin:建议改成 if return 哈;更简洁;
|
||||||
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
|
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
|
||||||
|
// TODO @xin:写成静态枚举哈
|
||||||
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
||||||
return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
|
return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
|
||||||
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
|
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
|
||||||
@ -103,27 +102,4 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO 芋艿:复用
|
|
||||||
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
|
||||||
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
|
||||||
//noinspection EnhancedSwitchMigration
|
|
||||||
switch (platform) {
|
|
||||||
case OPENAI:
|
|
||||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
||||||
case OLLAMA:
|
|
||||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
|
||||||
case YI_YAN:
|
|
||||||
// TODO 芋艿:貌似 model 只要一设置,就报错
|
|
||||||
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
||||||
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
||||||
case XING_HUO:
|
|
||||||
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
|
||||||
.setMaxTokens(maxTokens);
|
|
||||||
case QIAN_WEN:
|
|
||||||
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
|
||||||
default:
|
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,59 @@
|
|||||||
|
package cn.iocoder.yudao.framework.ai.core.util;
|
||||||
|
|
||||||
|
import cn.hutool.core.util.StrUtil;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
||||||
|
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||||
|
import org.springframework.ai.chat.messages.*;
|
||||||
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
|
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||||
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
|
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Spring AI 工具类
|
||||||
|
*
|
||||||
|
* @author 芋道源码
|
||||||
|
*/
|
||||||
|
public class AiUtils {
|
||||||
|
|
||||||
|
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||||
|
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
||||||
|
//noinspection EnhancedSwitchMigration
|
||||||
|
switch (platform) {
|
||||||
|
case OPENAI:
|
||||||
|
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||||
|
case OLLAMA:
|
||||||
|
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
||||||
|
case YI_YAN:
|
||||||
|
// TODO @xin:貌似 model 只要一设置,就报错;可以排查下
|
||||||
|
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||||
|
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||||
|
case XING_HUO:
|
||||||
|
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
||||||
|
.setMaxTokens(maxTokens);
|
||||||
|
case QIAN_WEN:
|
||||||
|
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Message buildMessage(String type, String content) {
|
||||||
|
if (MessageType.USER.getValue().equals(type)) {
|
||||||
|
return new UserMessage(content);
|
||||||
|
}
|
||||||
|
if (MessageType.ASSISTANT.getValue().equals(type)) {
|
||||||
|
return new AssistantMessage(content);
|
||||||
|
}
|
||||||
|
if (MessageType.SYSTEM.getValue().equals(type)) {
|
||||||
|
return new SystemMessage(content);
|
||||||
|
}
|
||||||
|
if (MessageType.FUNCTION.getValue().equals(type)) {
|
||||||
|
return new FunctionMessage(content);
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user