From e0f08a0f02b5bb53b7f3e6201469eb3a5767bf40 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sat, 6 Jul 2024 12:54:23 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E5=B0=86=20ChatClient=20=E6=9B=BF=E6=8D=A2?= =?UTF-8?q?=E6=88=90=20ChatModel=EF=BC=8C=E5=92=8C=20Spring=20AI=20?= =?UTF-8?q?=E5=AF=B9=E9=BD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat/AiChatMessageServiceImpl.java | 8 ++-- .../ai/service/image/AiImageServiceImpl.java | 4 +- .../ai/service/model/AiApiKeyService.java | 8 ++-- .../ai/service/model/AiApiKeyServiceImpl.java | 16 ++++---- .../ai/service/write/AiWriteServiceImpl.java | 4 +- .../ai/config/YudaoAiAutoConfiguration.java | 8 ++-- ...ClientFactory.java => AiModelFactory.java} | 18 ++++----- ...ctoryImpl.java => AiModelFactoryImpl.java} | 37 ++++++++++--------- 8 files changed, 53 insertions(+), 50 deletions(-) rename yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/{AiClientFactory.java => AiModelFactory.java} (79%) rename yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/{AiClientFactoryImpl.java => AiModelFactoryImpl.java} (93%) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 44b48a66a..6c8cdeaca 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -70,7 +70,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); // 1.2 校验模型 AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); - ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); + ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); // 2. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, @@ -82,7 +82,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // 3.2 创建 chat 需要的 Prompt Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); - ChatResponse chatResponse = chatClient.call(prompt); + ChatResponse chatResponse = chatModel.call(prompt); // 3.3 段式返回 String newContent = chatResponse.getResult().getOutput().getContent(); @@ -101,7 +101,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); // 1.2 校验模型 AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); - StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); + StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); // 2. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, @@ -113,7 +113,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // 3.2 创建 chat 需要的 Prompt Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); - Flux streamResponse = chatClient.stream(prompt); + Flux streamResponse = chatModel.stream(prompt); // 3.3 流式返回 // TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index d3054d80d..27c0978f1 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -98,8 +98,8 @@ public class AiImageServiceImpl implements AiImageService { // 1.1 构建请求 ImageOptions request = buildImageOptions(req); // 1.2 执行请求 - ImageModel imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform())); - ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request)); + ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform())); + ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request)); // 2. 上传到文件服务 byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json()); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java index a5ba60867..fe8fdd194 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java @@ -81,17 +81,17 @@ public interface AiApiKeyService { * @param id 编号 * @return ChatModel 对象 */ - ChatModel getChatClient(Long id); + ChatModel getChatModel(Long id); /** - * 获得 ImageClient 对象 + * 获得 ImageModel 对象 * * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择 * * @param platform 平台 - * @return ImageClient 对象 + * @return ImageModel 对象 */ - ImageModel getImageClient(AiPlatformEnum platform); + ImageModel getImageModel(AiPlatformEnum platform); /** * 获得 MidjourneyApi 对象 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java index 8db777f3f..6f8f06076 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java @@ -1,7 +1,7 @@ package cn.iocoder.yudao.module.ai.service.model; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; +import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; @@ -35,7 +35,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { private AiApiKeyMapper apiKeyMapper; @Resource - private AiClientFactory clientFactory; + private AiModelFactory modelFactory; @Override public Long createApiKey(AiApiKeySaveReqVO createReqVO) { @@ -98,19 +98,19 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { // ========== 与 spring-ai 集成 ========== @Override - public ChatModel getChatClient(Long id) { + public ChatModel getChatModel(Long id) { AiApiKeyDO apiKey = validateApiKey(id); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); - return clientFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); + return modelFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); } @Override - public ImageModel getImageClient(AiPlatformEnum platform) { + public ImageModel getImageModel(AiPlatformEnum platform) { AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus()); if (apiKey == null) { throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName()); } - return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl()); + return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl()); } @Override @@ -120,7 +120,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { if (apiKey == null) { throw exception(API_KEY_MIDJOURNEY_NOT_FOUND); } - return clientFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); + return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); } @Override @@ -130,7 +130,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { if (apiKey == null) { throw exception(API_KEY_SUNO_NOT_FOUND); } - return clientFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); + return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); } } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index 0051e4a8e..d43c11d3a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -54,7 +54,7 @@ public class AiWriteServiceImpl implements AiWriteService { public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的; AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); - StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); + StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); // 1.2 插入写作信息 @@ -65,7 +65,7 @@ public class AiWriteServiceImpl implements AiWriteService { // 2.1 构建提示词 ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); - Flux streamResponse = chatClient.stream(prompt); + Flux streamResponse = chatModel.stream(prompt); // 2.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index b8419c87d..128d9d99a 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -1,7 +1,7 @@ package cn.iocoder.yudao.framework.ai.config; -import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; -import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl; +import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; +import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatClient; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; @@ -28,8 +28,8 @@ import org.springframework.context.annotation.Import; public class YudaoAiAutoConfiguration { @Bean - public AiClientFactory aiClientFactory() { - return new AiClientFactoryImpl(); + public AiModelFactory aiModelFactory() { + return new AiModelFactoryImpl(); } // ========== 各种 AI Client 创建 ========== diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java similarity index 79% rename from yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java rename to yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java index e37afc41d..e1d3aba63 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java @@ -7,11 +7,11 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; /** - * AI 客户端工厂的接口类 + * AI Model 模型工厂的接口类 * * @author fansili */ -public interface AiClientFactory { +public interface AiModelFactory { /** * 基于指定配置,获得 ChatModel 对象 @@ -33,29 +33,29 @@ public interface AiClientFactory { * @param platform 平台 * @return ChatModel 对象 */ - ChatModel getDefaultChatClient(AiPlatformEnum platform); + ChatModel getDefaultChatModel(AiPlatformEnum platform); /** - * 基于默认配置,获得 ImageClient 对象 + * 基于默认配置,获得 ImageModel 对象 * * 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置 * * @param platform 平台 - * @return ImageClient 对象 + * @return ImageModel 对象 */ - ImageModel getDefaultImageClient(AiPlatformEnum platform); + ImageModel getDefaultImageModel(AiPlatformEnum platform); /** - * 基于指定配置,获得 ImageClient 对象 + * 基于指定配置,获得 ImageModel 对象 * * 如果不存在,则进行创建 * * @param platform 平台 * @param apiKey API KEY * @param url API URL - * @return ImageClient 对象 + * @return ImageModel 对象 */ - ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url); + ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url); /** * 基于指定配置,获得 MidjourneyApi 对象 diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java similarity index 93% rename from yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java rename to yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index 7eae7a8db..ca01a6611 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -43,11 +43,11 @@ import org.springframework.web.client.RestClient; import java.util.List; /** - * AI 客户端工厂的实现类 + * AI Model 模型工厂的实现类 * * @author 芋道源码 */ -public class AiClientFactoryImpl implements AiClientFactory { +public class AiModelFactoryImpl implements AiModelFactory { @Override public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) { @@ -55,8 +55,6 @@ public class AiClientFactoryImpl implements AiClientFactory { return Singleton.get(cacheKey, (Func0) () -> { //noinspection EnhancedSwitchMigration switch (platform) { - case OPENAI: - return buildOpenAiChatClient(apiKey, url); case OLLAMA: return buildOllamaChatClient(url); case YI_YAN: @@ -67,6 +65,8 @@ public class AiClientFactoryImpl implements AiClientFactory { return buildQianWenChatClient(apiKey); case DEEP_SEEK: return buildDeepSeekChatClient(apiKey); + case OPENAI: + return buildOpenAiChatModel(apiKey, url); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } @@ -74,11 +74,9 @@ public class AiClientFactoryImpl implements AiClientFactory { } @Override - public ChatModel getDefaultChatClient(AiPlatformEnum platform) { + public ChatModel getDefaultChatModel(AiPlatformEnum platform) { //noinspection EnhancedSwitchMigration switch (platform) { - case OPENAI: - return SpringUtil.getBean(OpenAiChatModel.class); case OLLAMA: return SpringUtil.getBean(OllamaChatModel.class); case YI_YAN: @@ -87,13 +85,15 @@ public class AiClientFactoryImpl implements AiClientFactory { return SpringUtil.getBean(XingHuoChatClient.class); case QIAN_WEN: return SpringUtil.getBean(TongYiChatModel.class); + case OPENAI: + return SpringUtil.getBean(OpenAiChatModel.class); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } } @Override - public ImageModel getDefaultImageClient(AiPlatformEnum platform) { + public ImageModel getDefaultImageModel(AiPlatformEnum platform) { //noinspection EnhancedSwitchMigration switch (platform) { case OPENAI: @@ -106,11 +106,11 @@ public class AiClientFactoryImpl implements AiClientFactory { } @Override - public ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) { + public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) { //noinspection EnhancedSwitchMigration switch (platform) { case OPENAI: - return buildOpenAiImageClient(apiKey, url); + return buildOpenAiImageModel(apiKey, url); case STABLE_DIFFUSION: return buildStabilityAiImageClient(apiKey, url); default: @@ -145,12 +145,21 @@ public class AiClientFactoryImpl implements AiClientFactory { /** * 可参考 {@link OpenAiAutoConfiguration} */ - private static OpenAiChatModel buildOpenAiChatClient(String openAiToken, String url) { + private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) { url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); OpenAiApi openAiApi = new OpenAiApi(url, openAiToken); return new OpenAiChatModel(openAiApi); } + /** + * 可参考 {@link OpenAiAutoConfiguration} + */ + private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) { + url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); + OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder()); + return new OpenAiImageModel(openAiApi); + } + /** * 可参考 {@link OllamaAutoConfiguration} */ @@ -200,12 +209,6 @@ public class AiClientFactoryImpl implements AiClientFactory { return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); } - private OpenAiImageModel buildOpenAiImageClient(String openAiToken, String url) { - url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); - OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder()); - return new OpenAiImageModel(openAiApi); - } - private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) { url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);