增加yudao ai client

This commit is contained in:
cherishsince 2024-04-14 13:39:09 +08:00
parent 652a8f9633
commit 97df2755f9
4 changed files with 117 additions and 60 deletions

View File

@ -36,6 +36,8 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
private QianWenOptions qianWenOptions;
public QianWenChatClient() {}
public QianWenChatClient(QianWenApi qianWenApi) {
this.qianWenApi = qianWenApi;
}

View File

@ -0,0 +1,19 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
/**
* ai client传入
*
* @author fansili
* @time 2024/4/14 10:27
* @since 1.0
*/
public interface AiClient {
ChatResponse call(Prompt prompt, String clientName);
Flux<ChatResponse> stream(Prompt prompt, String clientName);
}

View File

@ -14,7 +14,6 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.exception.AiException;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
@ -23,6 +22,7 @@ import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.support.GenericApplicationContext;
import java.util.HashMap;
import java.util.Map;
/**
@ -36,11 +36,33 @@ import java.util.Map;
@EnableConfigurationProperties(YudaoAiProperties.class)
public class YudaoAiAutoConfiguration {
// TODO @芋艿我看sharding jdbc 差不多这么玩的
@Bean
@ConditionalOnMissingBean(value = InitChatClient.class)
public InitChatClient initChatClient(YudaoAiProperties yudaoAiProperties) {
return new InitChatClient(yudaoAiProperties);
@ConditionalOnMissingBean(value = AiClient.class)
public AiClient aiClient(YudaoAiProperties yudaoAiProperties) {
Map<String, Object> chatClientMap = buildChatClientMap(yudaoAiProperties);
return new YudaoAiClient(chatClientMap);
}
public Map<String, Object> buildChatClientMap(YudaoAiProperties yudaoAiProperties) {
Map<String, Object> chatMap = new HashMap<>();
for (Map.Entry<String, Map<String, Object>> properties : yudaoAiProperties.entrySet()) {
String beanName = properties.getKey();
Map<String, Object> aiPlatformMap = properties.getValue();
// 检查平台类型是否正确
String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform"));
if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) {
throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!");
}
// 获取平台类型
AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform);
// 获取 chat properties
YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap);
// 创建客户端
Object chatClient = createChatClient(chatProperties);
chatMap.put(beanName, chatClient);
}
return chatMap;
}
public static class InitChatClient implements InitializingBean, ApplicationContextAware {
@ -53,26 +75,8 @@ public class YudaoAiAutoConfiguration {
}
@Override
public void afterPropertiesSet() throws Exception {
for (Map.Entry<String, Map<String, Object>> properties : yudaoAiProperties.entrySet()) {
String beanName = properties.getKey();
Map<String, Object> aiPlatformMap = properties.getValue();
public void afterPropertiesSet() {
// 检查平台类型是否正确
String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform"));
if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) {
throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!");
}
// 获取平台类型
AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform);
// 获取 chat properties
YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap);
// 创建客户端
registerChatClient(applicationContext, chatProperties, beanName);
// applicationContext.refresh();
}
System.err.println(applicationContext.getBean("qianWen"));
System.err.println(applicationContext.getBean("yiYan"));
@ -84,53 +88,41 @@ public class YudaoAiAutoConfiguration {
}
}
private static void registerChatClient(GenericApplicationContext applicationContext, YudaoAiProperties.ChatProperties chatProperties, String beanName) {
ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory();
Object wrapperBean = null;
private static Object createChatClient(YudaoAiProperties.ChatProperties chatProperties) {
if (AiPlatformEnum.XING_HUO == chatProperties.getAiPlatform()) {
YudaoAiProperties.XingHuoProperties xingHuoProperties = (YudaoAiProperties.XingHuoProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(
new XingHuoChatClient(
new XingHuoApi(
xingHuoProperties.getAppId(),
xingHuoProperties.getAppKey(),
xingHuoProperties.getSecretKey()
),
new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
return new XingHuoChatClient(
new XingHuoApi(
xingHuoProperties.getAppId(),
xingHuoProperties.getAppKey(),
xingHuoProperties.getSecretKey()
),
beanName
new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
);
} else if (AiPlatformEnum.QIAN_WEN == chatProperties.getAiPlatform()) {
YudaoAiProperties.QianWenProperties qianWenProperties = (YudaoAiProperties.QianWenProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(new QianWenChatClient(
new QianWenApi(
qianWenProperties.getAccessKeyId(),
qianWenProperties.getAccessKeySecret(),
qianWenProperties.getAgentKey(),
qianWenProperties.getEndpoint()
),
new QianWenOptions()
.setAppId(qianWenProperties.getAppId())
return new QianWenChatClient(
new QianWenApi(
qianWenProperties.getAccessKeyId(),
qianWenProperties.getAccessKeySecret(),
qianWenProperties.getAgentKey(),
qianWenProperties.getEndpoint()
),
beanName
new QianWenOptions()
.setAppId(qianWenProperties.getAppId())
);
} else if (AiPlatformEnum.YI_YAN == chatProperties.getAiPlatform()) {
YudaoAiProperties.YiYanProperties yiYanProperties = (YudaoAiProperties.YiYanProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(new YiYanChatClient(
new YiYanApi(
yiYanProperties.getAppKey(),
yiYanProperties.getSecretKey(),
yiYanProperties.getChatModel(),
yiYanProperties.getRefreshTokenSecondTime()
),
new YiYanOptions().setMax_output_tokens(2048)),
beanName
);
}
if (wrapperBean != null) {
beanFactory.registerSingleton(beanName, wrapperBean);
return new YiYanChatClient(
new YiYanApi(
yiYanProperties.getAppKey(),
yiYanProperties.getSecretKey(),
yiYanProperties.getChatModel(),
yiYanProperties.getRefreshTokenSecondTime()
),
new YiYanOptions().setMax_output_tokens(2048));
}
throw new AiException("不支持的Ai类型!");
}

View File

@ -0,0 +1,44 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.chat.ChatClient;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.exception.AiException;
import reactor.core.publisher.Flux;
import java.util.Map;
/**
* yudao ai client
*
* @author fansili
* @time 2024/4/14 10:27
* @since 1.0
*/
public class YudaoAiClient implements AiClient{
protected Map<String, Object> chatClientMap;
public YudaoAiClient(Map<String, Object> chatClientMap) {
this.chatClientMap = chatClientMap;
}
@Override
public ChatResponse call(Prompt prompt, String clientName) {
if (!chatClientMap.containsKey(clientName)) {
throw new AiException("clientName不存在!");
}
ChatClient chatClient = (ChatClient) chatClientMap.get(clientName);
return chatClient.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt, String clientName) {
if (!chatClientMap.containsKey(clientName)) {
throw new AiException("clientName不存在!");
}
StreamingChatClient streamingChatClient = (StreamingChatClient) chatClientMap.get(clientName);
return streamingChatClient.stream(prompt);
}
}