From 97df2755f985fa7d9d92613a3236e4933d0480ad Mon Sep 17 00:00:00 2001 From: cherishsince Date: Sun, 14 Apr 2024 13:39:09 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0yudao=20ai=20client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/chatqianwen/QianWenChatClient.java | 2 + .../yudao/framework/ai/config/AiClient.java | 19 +++ .../ai/config/YudaoAiAutoConfiguration.java | 112 ++++++++---------- .../framework/ai/config/YudaoAiClient.java | 44 +++++++ 4 files changed, 117 insertions(+), 60 deletions(-) create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/AiClient.java create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiClient.java diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java index a8fb9105f..7bab37586 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java @@ -36,6 +36,8 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient { private QianWenOptions qianWenOptions; + + public QianWenChatClient() {} public QianWenChatClient(QianWenApi qianWenApi) { this.qianWenApi = qianWenApi; } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/AiClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/AiClient.java new file mode 100644 index 000000000..f976dcfbf --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/AiClient.java @@ -0,0 +1,19 @@ +package cn.iocoder.yudao.framework.ai.config; + +import cn.iocoder.yudao.framework.ai.chat.ChatResponse; +import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; +import reactor.core.publisher.Flux; + +/** + * ai client传入 + * + * @author fansili + * @time 2024/4/14 10:27 + * @since 1.0 + */ +public interface AiClient { + + ChatResponse call(Prompt prompt, String clientName); + + Flux stream(Prompt prompt, String clientName); +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index 751b9df06..7e734cd2a 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -14,7 +14,6 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi; import cn.iocoder.yudao.framework.ai.exception.AiException; import org.springframework.beans.BeansException; import org.springframework.beans.factory.InitializingBean; -import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; @@ -23,6 +22,7 @@ import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.context.support.GenericApplicationContext; +import java.util.HashMap; import java.util.Map; /** @@ -36,11 +36,33 @@ import java.util.Map; @EnableConfigurationProperties(YudaoAiProperties.class) public class YudaoAiAutoConfiguration { - // TODO @芋艿:我看sharding jdbc 差不多这么玩的 @Bean - @ConditionalOnMissingBean(value = InitChatClient.class) - public InitChatClient initChatClient(YudaoAiProperties yudaoAiProperties) { - return new InitChatClient(yudaoAiProperties); + @ConditionalOnMissingBean(value = AiClient.class) + public AiClient aiClient(YudaoAiProperties yudaoAiProperties) { + Map chatClientMap = buildChatClientMap(yudaoAiProperties); + return new YudaoAiClient(chatClientMap); + } + + public Map buildChatClientMap(YudaoAiProperties yudaoAiProperties) { + Map chatMap = new HashMap<>(); + for (Map.Entry> properties : yudaoAiProperties.entrySet()) { + String beanName = properties.getKey(); + Map aiPlatformMap = properties.getValue(); + + // 检查平台类型是否正确 + String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform")); + if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) { + throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!"); + } + // 获取平台类型 + AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform); + // 获取 chat properties + YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap); + // 创建客户端 + Object chatClient = createChatClient(chatProperties); + chatMap.put(beanName, chatClient); + } + return chatMap; } public static class InitChatClient implements InitializingBean, ApplicationContextAware { @@ -53,26 +75,8 @@ public class YudaoAiAutoConfiguration { } @Override - public void afterPropertiesSet() throws Exception { - for (Map.Entry> properties : yudaoAiProperties.entrySet()) { - String beanName = properties.getKey(); - Map aiPlatformMap = properties.getValue(); + public void afterPropertiesSet() { - // 检查平台类型是否正确 - String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform")); - if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) { - throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!"); - } - // 获取平台类型 - AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform); - // 获取 chat properties - YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap); - // 创建客户端 - registerChatClient(applicationContext, chatProperties, beanName); -// applicationContext.refresh(); - - - } System.err.println(applicationContext.getBean("qianWen")); System.err.println(applicationContext.getBean("yiYan")); @@ -84,53 +88,41 @@ public class YudaoAiAutoConfiguration { } } - private static void registerChatClient(GenericApplicationContext applicationContext, YudaoAiProperties.ChatProperties chatProperties, String beanName) { - ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory(); - Object wrapperBean = null; + private static Object createChatClient(YudaoAiProperties.ChatProperties chatProperties) { if (AiPlatformEnum.XING_HUO == chatProperties.getAiPlatform()) { YudaoAiProperties.XingHuoProperties xingHuoProperties = (YudaoAiProperties.XingHuoProperties) chatProperties; - wrapperBean = beanFactory.initializeBean( - new XingHuoChatClient( - new XingHuoApi( - xingHuoProperties.getAppId(), - xingHuoProperties.getAppKey(), - xingHuoProperties.getSecretKey() - ), - new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel()) + return new XingHuoChatClient( + new XingHuoApi( + xingHuoProperties.getAppId(), + xingHuoProperties.getAppKey(), + xingHuoProperties.getSecretKey() ), - beanName + new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel()) ); } else if (AiPlatformEnum.QIAN_WEN == chatProperties.getAiPlatform()) { YudaoAiProperties.QianWenProperties qianWenProperties = (YudaoAiProperties.QianWenProperties) chatProperties; - wrapperBean = beanFactory.initializeBean(new QianWenChatClient( - new QianWenApi( - qianWenProperties.getAccessKeyId(), - qianWenProperties.getAccessKeySecret(), - qianWenProperties.getAgentKey(), - qianWenProperties.getEndpoint() - ), - new QianWenOptions() - .setAppId(qianWenProperties.getAppId()) + return new QianWenChatClient( + new QianWenApi( + qianWenProperties.getAccessKeyId(), + qianWenProperties.getAccessKeySecret(), + qianWenProperties.getAgentKey(), + qianWenProperties.getEndpoint() ), - beanName + new QianWenOptions() + .setAppId(qianWenProperties.getAppId()) ); } else if (AiPlatformEnum.YI_YAN == chatProperties.getAiPlatform()) { YudaoAiProperties.YiYanProperties yiYanProperties = (YudaoAiProperties.YiYanProperties) chatProperties; - - wrapperBean = beanFactory.initializeBean(new YiYanChatClient( - new YiYanApi( - yiYanProperties.getAppKey(), - yiYanProperties.getSecretKey(), - yiYanProperties.getChatModel(), - yiYanProperties.getRefreshTokenSecondTime() - ), - new YiYanOptions().setMax_output_tokens(2048)), - beanName - ); - } - if (wrapperBean != null) { - beanFactory.registerSingleton(beanName, wrapperBean); + return new YiYanChatClient( + new YiYanApi( + yiYanProperties.getAppKey(), + yiYanProperties.getSecretKey(), + yiYanProperties.getChatModel(), + yiYanProperties.getRefreshTokenSecondTime() + ), + new YiYanOptions().setMax_output_tokens(2048)); } + throw new AiException("不支持的Ai类型!"); } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiClient.java new file mode 100644 index 000000000..2f584c40a --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiClient.java @@ -0,0 +1,44 @@ +package cn.iocoder.yudao.framework.ai.config; + +import cn.iocoder.yudao.framework.ai.chat.ChatClient; +import cn.iocoder.yudao.framework.ai.chat.ChatResponse; +import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient; +import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; +import cn.iocoder.yudao.framework.ai.exception.AiException; +import reactor.core.publisher.Flux; + +import java.util.Map; + +/** + * yudao ai client + * + * @author fansili + * @time 2024/4/14 10:27 + * @since 1.0 + */ +public class YudaoAiClient implements AiClient{ + + protected Map chatClientMap; + + public YudaoAiClient(Map chatClientMap) { + this.chatClientMap = chatClientMap; + } + + @Override + public ChatResponse call(Prompt prompt, String clientName) { + if (!chatClientMap.containsKey(clientName)) { + throw new AiException("clientName不存在!"); + } + ChatClient chatClient = (ChatClient) chatClientMap.get(clientName); + return chatClient.call(prompt); + } + + @Override + public Flux stream(Prompt prompt, String clientName) { + if (!chatClientMap.containsKey(clientName)) { + throw new AiException("clientName不存在!"); + } + StreamingChatClient streamingChatClient = (StreamingChatClient) chatClientMap.get(clientName); + return streamingChatClient.stream(prompt); + } +}