From 2fefcf8834ede9c73fbd64a16d04fe2ffd25a425 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Wed, 22 May 2024 12:37:21 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E6=96=B0=E5=A2=9E=E3=80=91AI=EF=BC=9A?= =?UTF-8?q?=E9=80=9A=E8=BF=87=20AiClientFactory=20=E6=8F=90=E4=BE=9B=20cha?= =?UTF-8?q?tclient?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../module/ai/config/AiChatClientFactory.java | 57 ------ .../ai/service/impl/AiChatServiceImpl.java | 55 +++--- .../ai/service/model/AiApiKeyService.java | 11 ++ .../ai/service/model/AiApiKeyServiceImpl.java | 15 ++ .../ai/config/YudaoAiAutoConfiguration.java | 30 ++-- .../ai/core/factory/AiClientFactory.java | 47 +++++ .../ai/core/factory/AiClientFactoryImpl.java | 167 ++++++++++++++++++ .../ai/core/model/tongyi/QianWenOptions.java | 2 + .../ai/core/model/xinghuo/XingHuoOptions.java | 2 +- .../ai/core/model/yiyan/YiYanChatOptions.java | 1 + .../ai/core/model/yiyan/api/YiYanApi.java | 3 +- .../src/main/resources/application.yaml | 7 - 12 files changed, 289 insertions(+), 108 deletions(-) delete mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/config/AiChatClientFactory.java create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/config/AiChatClientFactory.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/config/AiChatClientFactory.java deleted file mode 100644 index 5efd264a9..000000000 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/config/AiChatClientFactory.java +++ /dev/null @@ -1,57 +0,0 @@ -package cn.iocoder.yudao.module.ai.config; - -import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import org.springframework.ai.chat.ChatClient; -import org.springframework.ai.chat.StreamingChatClient; -import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient; -import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient; -import org.springframework.ai.ollama.OllamaChatClient; -import org.springframework.ai.openai.OpenAiChatClient; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.context.ApplicationContext; -import org.springframework.stereotype.Component; - -/** - * factory - * - * @author fansili - * @time 2024/4/25 17:36 - * @since 1.0 - */ -@Component -public class AiChatClientFactory { - - @Autowired - private ApplicationContext applicationContext; - - public ChatClient getChatClient(AiPlatformEnum platformEnum) { - if (AiPlatformEnum.QIAN_WEN == platformEnum) { - return applicationContext.getBean(QianWenChatClient.class); - } else if (AiPlatformEnum.YI_YAN == platformEnum) { - return applicationContext.getBean(YiYanChatClient.class); - } else if (AiPlatformEnum.XING_HUO == platformEnum) { - return applicationContext.getBean(XingHuoChatClient.class); - } - throw new IllegalArgumentException("不支持的 chat client!"); - } - - // TODO yunai 要不再加一个接口,让他们拥有 ChatClient、StreamingChatClient 功能 - public StreamingChatClient getStreamingChatClient(AiPlatformEnum platformEnum) { -// if (true) { -// return applicationContext.getBean(OllamaChatClient.class); -// } - if (AiPlatformEnum.QIAN_WEN == platformEnum) { - return applicationContext.getBean(QianWenChatClient.class); - } else if (AiPlatformEnum.YI_YAN == platformEnum) { - return applicationContext.getBean(YiYanChatClient.class); - } else if (AiPlatformEnum.XING_HUO == platformEnum) { - return applicationContext.getBean(XingHuoChatClient.class); - } else if (AiPlatformEnum.OLLAMA == platformEnum) { - return applicationContext.getBean(OllamaChatClient.class); - } else if (AiPlatformEnum.OPENAI == platformEnum) { - return applicationContext.getBean(OpenAiChatClient.class); - } - throw new IllegalArgumentException("不支持的 chat client!"); - } -} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java index c603f43e8..32511141c 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java @@ -4,13 +4,20 @@ import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; +import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; +import jakarta.annotation.Resource; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert; -import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; @@ -19,12 +26,7 @@ import cn.iocoder.yudao.module.ai.service.AiChatService; import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; -import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.chat.messages.*; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; @@ -46,16 +48,22 @@ import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NO */ @Slf4j @Service -@AllArgsConstructor public class AiChatServiceImpl implements AiChatService { - private final AiChatClientFactory chatClientFactory; + @Resource + private AiChatMessageMapper chatMessageMapper; - private final AiChatMessageMapper chatMessageMapper; + @Resource + private AiClientFactory clientFactory; - private final AiChatConversationService chatConversationService; - private final AiChatModelService chatModalService; - private final AiChatRoleService chatRoleService; + @Resource + private AiChatConversationService chatConversationService; + @Resource + private AiChatModelService chatModalService; + @Resource + private AiChatRoleService chatRoleService; + @Resource + private AiApiKeyService apiKeyService; @Transactional(rollbackFor = Exception.class) public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) { @@ -106,8 +114,7 @@ public class AiChatServiceImpl implements AiChatService { List historyMessages = chatMessageMapper.selectByConversationId(conversation.getId()); // 1.2 校验模型 AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); - AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); - StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform); + StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); // 2. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, @@ -118,13 +125,13 @@ public class AiChatServiceImpl implements AiChatService { userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); // 3.2 创建 chat 需要的 Prompt - Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO); + Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); Flux streamResponse = chatClient.stream(prompt); // 3.3 流式返回 // 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 StringBuffer contentBuffer = new StringBuffer(); - return streamResponse.publishOn(Schedulers.immediate()).map(chunk -> { + return streamResponse.publishOn(Schedulers.single()).map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况 contentBuffer.append(newContent); @@ -144,7 +151,8 @@ public class AiChatServiceImpl implements AiChatService { return chatMessageMapper.deleteByConversationId(conversationId) > 0; } - private Prompt buildPrompt(AiChatConversationDO conversation, List messages, AiChatMessageSendReqVO sendReqVO) { + private Prompt buildPrompt(AiChatConversationDO conversation, List messages, + AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { // 1. 构建 Prompt Message 列表 List chatMessages = new ArrayList<>(); // 1.1 system context 角色设定 @@ -156,10 +164,11 @@ public class AiChatServiceImpl implements AiChatService { chatMessages.add(new UserMessage(sendReqVO.getContent())); // 2. 构建 ChatOptions 对象 TODO 芋艿:临时注释掉;等文心一言兼容了; - // TODO 每一轮 token 数量 -// ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build(); -// return new Prompt(chatMessages, null); - return new Prompt(chatMessages); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatOptions chatOptions = clientFactory.buildChatOptions(platform, model.getModel(), + conversation.getTemperature(), conversation.getMaxTokens()); + return new Prompt(chatMessages, chatOptions); +// return new Prompt(chatMessages); } /** diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java index 331dd62e2..8056eab78 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java @@ -5,6 +5,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import jakarta.validation.Valid; +import org.springframework.ai.chat.StreamingChatClient; import java.util.List; @@ -68,4 +69,14 @@ public interface AiApiKeyService { */ List getApiKeyList(); + // ========== 与 spring-ai 集成 ========== + + /** + * 获得 StreamingChatClient 对象 + * + * @param id 编号 + * @return StreamingChatClient 对象 + */ + StreamingChatClient getStreamingChatClient(Long id); + } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java index f8a83ce57..e4db8125b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java @@ -1,5 +1,7 @@ package cn.iocoder.yudao.module.ai.service.model; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; @@ -8,6 +10,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import jakarta.annotation.Resource; +import org.springframework.ai.chat.StreamingChatClient; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; @@ -28,6 +31,9 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { @Resource private AiApiKeyMapper apiKeyMapper; + @Resource + private AiClientFactory clientFactory; + @Override public Long createApiKey(AiApiKeySaveReqVO createReqVO) { // 插入 @@ -86,4 +92,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { return apiKeyMapper.selectList(); } + // ========== 与 spring-ai 集成 ========== + + @Override + public StreamingChatClient getStreamingChatClient(Long id) { + AiApiKeyDO apiKey = validateApiKey(id); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); + return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); + } + } \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index f942b12bd..55705c40f 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -1,6 +1,8 @@ package cn.iocoder.yudao.framework.ai.config; import cn.hutool.core.io.IoUtil; +import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; +import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; @@ -36,17 +38,22 @@ import java.util.HashMap; import java.util.Map; /** - * ai 自动配置 + * 芋道 AI 自动配置 * * @author fansili - * @time 2024/4/12 16:29 - * @since 1.0 */ -@Slf4j @AutoConfiguration @EnableConfigurationProperties(YudaoAiProperties.class) +@Slf4j public class YudaoAiAutoConfiguration { + @Bean + public AiClientFactory aiClientFactory() { + return new AiClientFactoryImpl(); + } + + // ========== 各种 AI Client 创建 ========== + @Bean @ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true") public XingHuoChatClient xingHuoChatClient(YudaoAiProperties yudaoAiProperties) { @@ -107,21 +114,6 @@ public class YudaoAiAutoConfiguration { ); } - @Bean - @ConditionalOnProperty(value = "yudao.ai.openAiImage.enable", havingValue = "true") - public OpenAiImageClient openAiImageClient(YudaoAiProperties yudaoAiProperties) { - YudaoAiProperties.OpenAiImageProperties openAiImageProperties = yudaoAiProperties.getOpenAiImage(); - OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions(); - openAiImageOptions.setModel(openAiImageProperties.getModel().getModel()); - openAiImageOptions.setStyle(openAiImageProperties.getStyle().getStyle()); - openAiImageOptions.setResponseFormat("url"); // TODO 芋艿:OpenAiImageOptions.ResponseFormatEnum.URL.getValue() - // 创建 client - return new OpenAiImageClient( - new OpenAiImageApi(openAiImageProperties.getApiKey()), - openAiImageOptions, - RetryUtils.DEFAULT_RETRY_TEMPLATE); - } - @Bean @ConditionalOnMissingBean(value = MidjourneyMessageHandler.class) public MidjourneyMessageHandler defaultMidjourneyMessageHandler() { diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java new file mode 100644 index 000000000..98707fdc8 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java @@ -0,0 +1,47 @@ +package cn.iocoder.yudao.framework.ai.core.factory; + +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.prompt.ChatOptions; + +/** + * AI 客户端工厂的接口类 + * + * @author fansili + */ +public interface AiClientFactory { + + /** + * 基于指定配置,获得 StreamingChatClient 对象 + * + * 如果不存在,则进行创建 + * + * @param platform 平台 + * @param apiKey API KEY + * @param url API URL + * @return StreamingChatClient 对象 + */ + StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url); + + /** + * 基于默认配置,获得 StreamingChatClient 对象 + * + * 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置 + * + * @param platform 平台 + * @return StreamingChatClient 对象 + */ + StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform); + + /** + * 创建 Chat 参数 + * + * @param platform 平台 + * @param model 模型 + * @param temperature 温度 + * @param maxTokens 生成的最大 Token + * @return Chat 参数 + */ + ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens); + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java new file mode 100644 index 000000000..b005d7c3e --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java @@ -0,0 +1,167 @@ +package cn.iocoder.yudao.framework.ai.core.factory; + +import cn.hutool.core.lang.Assert; +import cn.hutool.core.lang.Singleton; +import cn.hutool.core.lang.func.Func0; +import cn.hutool.core.util.ArrayUtil; +import cn.hutool.core.util.StrUtil; +import cn.hutool.extra.spring.SpringUtil; +import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration; +import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; +import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; +import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; +import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi; +import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient; +import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions; +import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi; +import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; +import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; +import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.ApiUtils; +import org.springframework.ai.openai.api.OpenAiApi; + +import java.util.List; + +/** + * AI 客户端工厂的实现类 + * + * @author 芋道源码 + */ +public class AiClientFactoryImpl implements AiClientFactory { + + @Override + public StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) { + String cacheKey = buildClientCacheKey(StreamingChatClient.class, platform, apiKey, url); + return Singleton.get(cacheKey, (Func0) () -> { + //noinspection EnhancedSwitchMigration + switch (platform) { + case OPENAI: + return buildOpenAiChatClient(apiKey, url); + case OLLAMA: + return buildOllamaChatClient(url); + case YI_YAN: + return buildYiYanChatClient(apiKey); + case XING_HUO: + return buildXingHuoChatClient(apiKey); + case QIAN_WEN: + return buildQianWenChatClient(apiKey); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + }); + } + + @Override + public StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform) { + //noinspection EnhancedSwitchMigration + switch (platform) { + case OPENAI: + return SpringUtil.getBean(OpenAiChatClient.class); + case OLLAMA: + return SpringUtil.getBean(OllamaChatClient.class); + case YI_YAN: + return SpringUtil.getBean(YiYanChatClient.class); + case XING_HUO: + return SpringUtil.getBean(XingHuoChatClient.class); + case QIAN_WEN: + return SpringUtil.getBean(QianWenChatClient.class); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + } + + private static String buildClientCacheKey(Class clazz, Object... params) { + if (ArrayUtil.isEmpty(params)) { + return clazz.getName(); + } + return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_")); + } + + @Override + public ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { + Float temperatureF = temperature != null ? temperature.floatValue() : null; + //noinspection EnhancedSwitchMigration + switch (platform) { + case OPENAI: + return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + case OLLAMA: + return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); + case YI_YAN: + // TODO @fan:增加一个 model + return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens); + case XING_HUO: + return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) + .setMaxTokens(maxTokens); + case QIAN_WEN: + // TODO @fan:增加 model、temperature 参数 + return new QianWenOptions().setMaxTokens(maxTokens); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + } + + // ========== 各种创建 spring-ai 客户端的方法 ========== + + /** + * 可参考 {@link OpenAiAutoConfiguration} + */ + private static OpenAiChatClient buildOpenAiChatClient(String openAiToken, String url) { + url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); + OpenAiApi openAiApi = new OpenAiApi(url, openAiToken); + return new OpenAiChatClient(openAiApi); + } + + /** + * 可参考 {@link OllamaAutoConfiguration} + */ + private static OllamaChatClient buildOllamaChatClient(String url) { + OllamaApi ollamaApi = new OllamaApi(url); + return new OllamaChatClient(ollamaApi); + } + + /** + * 可参考 {@link YudaoAiAutoConfiguration#yiYanChatClient(YudaoAiProperties)} + */ + private static YiYanChatClient buildYiYanChatClient(String key) { + List keys = StrUtil.split(key, '|'); + Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式"); + String appKey = keys.get(0); + String secretKey = keys.get(1); + YiYanApi yiYanApi = new YiYanApi(appKey, secretKey, YiYanApi.DEFAULT_CHAT_MODEL, 0); + return new YiYanChatClient(yiYanApi); + } + + /** + * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)} + */ + private static XingHuoChatClient buildXingHuoChatClient(String key) { + List keys = StrUtil.split(key, '|'); + Assert.equals(keys.size(), 2, "XingHuoChatClient 的密钥需要 (appKey|secretKey) 格式"); + String appId = keys.get(0); + String appKey = keys.get(1); + String secretKey = keys.get(2); + XingHuoApi xingHuoApi = new XingHuoApi(appId, appKey, secretKey); + return new XingHuoChatClient(xingHuoApi); + } + + /** + * 可参考 {@link YudaoAiAutoConfiguration#qianWenChatClient(YudaoAiProperties)} + */ + private static QianWenChatClient buildQianWenChatClient(String key) { + QianWenApi qianWenApi = new QianWenApi(key, QianWenChatModal.QWEN_72B_CHAT); + return new QianWenChatClient(qianWenApi); + } + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/tongyi/QianWenOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/tongyi/QianWenOptions.java index b6dba53c9..4f7632c97 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/tongyi/QianWenOptions.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/tongyi/QianWenOptions.java @@ -6,6 +6,8 @@ import lombok.experimental.Accessors; import java.util.List; +// TODO @fan:增加一个 model 参数 +// TODO @fan:增加一个 Temperature 参数 /** * 阿里云 千问 属性 * diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoOptions.java index cb4753833..ccec8598d 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoOptions.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoOptions.java @@ -14,6 +14,7 @@ import lombok.experimental.Accessors; @Accessors(chain = true) public class XingHuoOptions implements ChatOptions { + // TODO @fan:这里 model 参数,然后使用 string /** * https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E *

@@ -43,7 +44,6 @@ public class XingHuoOptions implements ChatOptions { */ private String chatId; - @Override public Float getTemperature() { return this.temperature; diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/yiyan/YiYanChatOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/yiyan/YiYanChatOptions.java index 14146a322..a84b0ec98 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/yiyan/YiYanChatOptions.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/yiyan/YiYanChatOptions.java @@ -6,6 +6,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import java.util.List; +// TODO @fan:增加一个 model // TODO @fan:字段命名,penalty_score 类似的,建议改成驼峰原则 // TODO @fan:字段的注释,可以都删除掉,让用户 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 即可 /** diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/yiyan/api/YiYanApi.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/yiyan/api/YiYanApi.java index 535901f25..93cd13fe1 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/yiyan/api/YiYanApi.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/yiyan/api/YiYanApi.java @@ -18,7 +18,7 @@ public class YiYanApi { private static final String AUTH_2_TOKEN_URI = "/oauth/2.0/token"; - public static final String DEFAULT_CHAT_MODEL = YiYanChatModel.ERNIE4_0.getModel(); + public static final YiYanChatModel DEFAULT_CHAT_MODEL = YiYanChatModel.ERNIE4_0; private final String appKey; private final String secretKey; @@ -39,6 +39,7 @@ public class YiYanApi { */ private final YiYanChatModel useChatModel; + // TODO fan:看看是不是去掉 refreshTokenSecondTime 字段 public YiYanApi(String appKey, String secretKey, YiYanChatModel useChatModel, int refreshTokenSecondTime) { this.appKey = appKey; this.secretKey = secretKey; diff --git a/yudao-server/src/main/resources/application.yaml b/yudao-server/src/main/resources/application.yaml index 6bb087d81..cc5ed76a3 100644 --- a/yudao-server/src/main/resources/application.yaml +++ b/yudao-server/src/main/resources/application.yaml @@ -150,15 +150,8 @@ spring.ai: chat: model: llama3 openai: -# api-key: sk-QmgIIPc5xiYd8lPb076b1b7774Ea49Af9eD2Ef172c8f7e43 -# base-url: https://openkey.cloud -# api-key: sk-gkgfYxhX9FxyZJznwxRZSJwKeGQYNPDVWjhby2PRRf17GHeT -# base-url: https://api.chatanywhere.tech api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z base-url: https://api.gptsapi.net -# chat: -# options: -# model: gpt-4-0125-preview yudao.ai: yiyan: