【代码优化】AI:将 ChatClient 替换成 ChatModel,和 Spring AI 对齐

This commit is contained in:
YunaiV 2024-07-06 12:54:23 +08:00
parent 6c094aaffc
commit e0f08a0f02
8 changed files with 53 additions and 50 deletions

View File

@ -70,7 +70,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 2. 插入 user 发送消息 // 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -82,7 +82,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 3.2 创建 chat 需要的 Prompt // 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
ChatResponse chatResponse = chatClient.call(prompt); ChatResponse chatResponse = chatModel.call(prompt);
// 3.3 段式返回 // 3.3 段式返回
String newContent = chatResponse.getResult().getOutput().getContent(); String newContent = chatResponse.getResult().getOutput().getContent();
@ -101,7 +101,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 2. 插入 user 发送消息 // 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -113,7 +113,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 3.2 创建 chat 需要的 Prompt // 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.3 流式返回 // 3.3 流式返回
// TODO 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 // TODO 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题

View File

@ -98,8 +98,8 @@ public class AiImageServiceImpl implements AiImageService {
// 1.1 构建请求 // 1.1 构建请求
ImageOptions request = buildImageOptions(req); ImageOptions request = buildImageOptions(req);
// 1.2 执行请求 // 1.2 执行请求
ImageModel imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform())); ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform()));
ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request)); ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
// 2. 上传到文件服务 // 2. 上传到文件服务
byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json()); byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json());

View File

@ -81,17 +81,17 @@ public interface AiApiKeyService {
* @param id 编号 * @param id 编号
* @return ChatModel 对象 * @return ChatModel 对象
*/ */
ChatModel getChatClient(Long id); ChatModel getChatModel(Long id);
/** /**
* 获得 ImageClient 对象 * 获得 ImageModel 对象
* *
* TODO 可优化点目前默认获取 platform 对应的第一个开启的配置用于绘画后续可以支持配置选择 * TODO 可优化点目前默认获取 platform 对应的第一个开启的配置用于绘画后续可以支持配置选择
* *
* @param platform 平台 * @param platform 平台
* @return ImageClient 对象 * @return ImageModel 对象
*/ */
ImageModel getImageClient(AiPlatformEnum platform); ImageModel getImageModel(AiPlatformEnum platform);
/** /**
* 获得 MidjourneyApi 对象 * 获得 MidjourneyApi 对象

View File

@ -1,7 +1,7 @@
package cn.iocoder.yudao.module.ai.service.model; package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@ -35,7 +35,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
private AiApiKeyMapper apiKeyMapper; private AiApiKeyMapper apiKeyMapper;
@Resource @Resource
private AiClientFactory clientFactory; private AiModelFactory modelFactory;
@Override @Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) { public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
@ -98,19 +98,19 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
// ========== spring-ai 集成 ========== // ========== spring-ai 集成 ==========
@Override @Override
public ChatModel getChatClient(Long id) { public ChatModel getChatModel(Long id) {
AiApiKeyDO apiKey = validateApiKey(id); AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return clientFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
} }
@Override @Override
public ImageModel getImageClient(AiPlatformEnum platform) { public ImageModel getImageModel(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus()); AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) { if (apiKey == null) {
throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName()); throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName());
} }
return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
} }
@Override @Override
@ -120,7 +120,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
if (apiKey == null) { if (apiKey == null) {
throw exception(API_KEY_MIDJOURNEY_NOT_FOUND); throw exception(API_KEY_MIDJOURNEY_NOT_FOUND);
} }
return clientFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
} }
@Override @Override
@ -130,7 +130,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
if (apiKey == null) { if (apiKey == null) {
throw exception(API_KEY_SUNO_NOT_FOUND); throw exception(API_KEY_SUNO_NOT_FOUND);
} }
return clientFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
} }
} }

View File

@ -54,7 +54,7 @@ public class AiWriteServiceImpl implements AiWriteService {
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok那可以有限拿 chatRole 的角色如果没有则获取默认的 // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok那可以有限拿 chatRole 的角色如果没有则获取默认的
AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
// 1.2 插入写作信息 // 1.2 插入写作信息
@ -65,7 +65,7 @@ public class AiWriteServiceImpl implements AiWriteService {
// 2.1 构建提示词 // 2.1 构建提示词
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 2.2 流式返回 // 2.2 流式返回
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();

View File

@ -1,7 +1,7 @@
package cn.iocoder.yudao.framework.ai.config; package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatClient; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatClient;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
@ -28,8 +28,8 @@ import org.springframework.context.annotation.Import;
public class YudaoAiAutoConfiguration { public class YudaoAiAutoConfiguration {
@Bean @Bean
public AiClientFactory aiClientFactory() { public AiModelFactory aiModelFactory() {
return new AiClientFactoryImpl(); return new AiModelFactoryImpl();
} }
// ========== 各种 AI Client 创建 ========== // ========== 各种 AI Client 创建 ==========

View File

@ -7,11 +7,11 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
/** /**
* AI 客户端工厂的接口类 * AI Model 模型工厂的接口类
* *
* @author fansili * @author fansili
*/ */
public interface AiClientFactory { public interface AiModelFactory {
/** /**
* 基于指定配置获得 ChatModel 对象 * 基于指定配置获得 ChatModel 对象
@ -33,29 +33,29 @@ public interface AiClientFactory {
* @param platform 平台 * @param platform 平台
* @return ChatModel 对象 * @return ChatModel 对象
*/ */
ChatModel getDefaultChatClient(AiPlatformEnum platform); ChatModel getDefaultChatModel(AiPlatformEnum platform);
/** /**
* 基于默认配置获得 ImageClient 对象 * 基于默认配置获得 ImageModel 对象
* *
* 默认配置指的是在 application.yaml 配置文件中的 spring.ai 相关的配置 * 默认配置指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
* *
* @param platform 平台 * @param platform 平台
* @return ImageClient 对象 * @return ImageModel 对象
*/ */
ImageModel getDefaultImageClient(AiPlatformEnum platform); ImageModel getDefaultImageModel(AiPlatformEnum platform);
/** /**
* 基于指定配置获得 ImageClient 对象 * 基于指定配置获得 ImageModel 对象
* *
* 如果不存在则进行创建 * 如果不存在则进行创建
* *
* @param platform 平台 * @param platform 平台
* @param apiKey API KEY * @param apiKey API KEY
* @param url API URL * @param url API URL
* @return ImageClient 对象 * @return ImageModel 对象
*/ */
ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url); ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url);
/** /**
* 基于指定配置获得 MidjourneyApi 对象 * 基于指定配置获得 MidjourneyApi 对象

View File

@ -43,11 +43,11 @@ import org.springframework.web.client.RestClient;
import java.util.List; import java.util.List;
/** /**
* AI 客户端工厂的实现类 * AI Model 模型工厂的实现类
* *
* @author 芋道源码 * @author 芋道源码
*/ */
public class AiClientFactoryImpl implements AiClientFactory { public class AiModelFactoryImpl implements AiModelFactory {
@Override @Override
public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) { public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) {
@ -55,8 +55,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> { return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI:
return buildOpenAiChatClient(apiKey, url);
case OLLAMA: case OLLAMA:
return buildOllamaChatClient(url); return buildOllamaChatClient(url);
case YI_YAN: case YI_YAN:
@ -67,6 +65,8 @@ public class AiClientFactoryImpl implements AiClientFactory {
return buildQianWenChatClient(apiKey); return buildQianWenChatClient(apiKey);
case DEEP_SEEK: case DEEP_SEEK:
return buildDeepSeekChatClient(apiKey); return buildDeepSeekChatClient(apiKey);
case OPENAI:
return buildOpenAiChatModel(apiKey, url);
default: default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
} }
@ -74,11 +74,9 @@ public class AiClientFactoryImpl implements AiClientFactory {
} }
@Override @Override
public ChatModel getDefaultChatClient(AiPlatformEnum platform) { public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class);
case OLLAMA: case OLLAMA:
return SpringUtil.getBean(OllamaChatModel.class); return SpringUtil.getBean(OllamaChatModel.class);
case YI_YAN: case YI_YAN:
@ -87,13 +85,15 @@ public class AiClientFactoryImpl implements AiClientFactory {
return SpringUtil.getBean(XingHuoChatClient.class); return SpringUtil.getBean(XingHuoChatClient.class);
case QIAN_WEN: case QIAN_WEN:
return SpringUtil.getBean(TongYiChatModel.class); return SpringUtil.getBean(TongYiChatModel.class);
case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class);
default: default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
} }
} }
@Override @Override
public ImageModel getDefaultImageClient(AiPlatformEnum platform) { public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI: case OPENAI:
@ -106,11 +106,11 @@ public class AiClientFactoryImpl implements AiClientFactory {
} }
@Override @Override
public ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) { public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI: case OPENAI:
return buildOpenAiImageClient(apiKey, url); return buildOpenAiImageModel(apiKey, url);
case STABLE_DIFFUSION: case STABLE_DIFFUSION:
return buildStabilityAiImageClient(apiKey, url); return buildStabilityAiImageClient(apiKey, url);
default: default:
@ -145,12 +145,21 @@ public class AiClientFactoryImpl implements AiClientFactory {
/** /**
* 可参考 {@link OpenAiAutoConfiguration} * 可参考 {@link OpenAiAutoConfiguration}
*/ */
private static OpenAiChatModel buildOpenAiChatClient(String openAiToken, String url) { private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken); OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
return new OpenAiChatModel(openAiApi); return new OpenAiChatModel(openAiApi);
} }
/**
* 可参考 {@link OpenAiAutoConfiguration}
*/
private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
return new OpenAiImageModel(openAiApi);
}
/** /**
* 可参考 {@link OllamaAutoConfiguration} * 可参考 {@link OllamaAutoConfiguration}
*/ */
@ -200,12 +209,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
} }
private OpenAiImageModel buildOpenAiImageClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
return new OpenAiImageModel(openAiApi);
}
private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) { private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url); StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);