优化自动注入,可以创建多个 client

This commit is contained in:
cherishsince 2024-04-13 18:20:28 +08:00
parent 02ac6f30cf
commit ac0de5d485
7 changed files with 320 additions and 91 deletions

View File

@ -1,77 +1,34 @@
//package cn.iocoder.yudao.module.ai.controller.admin; package cn.iocoder.yudao.module.ai.controller.admin;
//
//import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
//import cn.iocoder.yudao.framework.common.pojo.CommonResult; import io.swagger.v3.oas.annotations.tags.Tag;
//import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import lombok.AllArgsConstructor;
//import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO; import lombok.extern.slf4j.Slf4j;
//import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum; import org.springframework.beans.factory.annotation.Autowired;
//import io.swagger.v3.oas.annotations.Operation; import org.springframework.web.bind.annotation.GetMapping;
//import io.swagger.v3.oas.annotations.tags.Tag; import org.springframework.web.bind.annotation.RequestMapping;
//import jakarta.servlet.http.HttpServletResponse; import org.springframework.web.bind.annotation.RequestParam;
//import lombok.extern.slf4j.Slf4j; import org.springframework.web.bind.annotation.RestController;
//import org.springframework.ai.chat.ChatClient;
//import org.springframework.ai.chat.ChatResponse; /**
//import org.springframework.ai.chat.prompt.Prompt; * @author fansili
//import org.springframework.ai.openai.OpenAiChatClient; * @since 1.0
//import org.springframework.beans.factory.annotation.Autowired; * @time 2024/4/13 17:44
//import org.springframework.context.ApplicationContext; */
//import org.springframework.validation.annotation.Validated; @Tag(name = "AI模块")
//import org.springframework.web.bind.annotation.PostMapping; @RestController
//import org.springframework.web.bind.annotation.RequestBody; @RequestMapping("/ai-api")
//import org.springframework.web.bind.annotation.RequestMapping; @Slf4j
//import org.springframework.web.bind.annotation.RestController; @AllArgsConstructor
//import reactor.core.publisher.Flux; public class ChatController {
//
//import java.util.function.Consumer;
//
//// TODO done @fansili有了 swagger 注释就不用类注释了
//@Tag(name = "AI模块")
//@RestController
//@RequestMapping("/ai-api")
//@Slf4j
//public class ChatController {
// //
// @Autowired // @Autowired
// private ApplicationContext applicationContext; // private QianWenChatClient qianWenChatClient;
// //
// @PostMapping("/chat") // @GetMapping("/chat")
// @Operation(summary = "对话聊天", description = "简单的ai聊天") // public String chat(@RequestParam("prompt") String prompt) {
// public CommonResult chat(@RequestBody @Validated AiChatReqVO reqVO) { // return qianWenChatClient.call(prompt);
// ChatClient chatClient = getChatClient(reqVO.getAiModel());
// String res;
// try {
// res = chatClient.call(reqVO.getPrompt());
// } catch (Exception e) {
// res = e.getMessage();
// }
// return CommonResult.success(res);
// }
//
// @PostMapping("/chatStream")
// @Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
// public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
// OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
// Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getPrompt()));
// chatResponse.subscribe(new Consumer<ChatResponse>() {
// @Override
// public void accept(ChatResponse chatResponse) {
// System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
// }
// });
// return CommonResult.success(null);
// }
//
// /**
// * 根据 ai模型 获取对于的 模型实现类
// *
// * @param aiModelEnum
// * @return
// */
// private ChatClient getChatClient(OpenAiModelEnum aiModelEnum) {
// if (OpenAiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
// return applicationContext.getBean(OpenAiChatClient.class);
// }
// // AI模型暂不支持
// throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODULE_NOT_SUPPORTED);
// }
// } // }
}

View File

@ -0,0 +1,77 @@
//package cn.iocoder.yudao.module.ai.controller.admin;
//
//import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
//import cn.iocoder.yudao.framework.common.pojo.CommonResult;
//import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
//import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO;
//import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
//import io.swagger.v3.oas.annotations.Operation;
//import io.swagger.v3.oas.annotations.tags.Tag;
//import jakarta.servlet.http.HttpServletResponse;
//import lombok.extern.slf4j.Slf4j;
//import org.springframework.ai.chat.ChatClient;
//import org.springframework.ai.chat.ChatResponse;
//import org.springframework.ai.chat.prompt.Prompt;
//import org.springframework.ai.openai.OpenAiChatClient;
//import org.springframework.beans.factory.annotation.Autowired;
//import org.springframework.context.ApplicationContext;
//import org.springframework.validation.annotation.Validated;
//import org.springframework.web.bind.annotation.PostMapping;
//import org.springframework.web.bind.annotation.RequestBody;
//import org.springframework.web.bind.annotation.RequestMapping;
//import org.springframework.web.bind.annotation.RestController;
//import reactor.core.publisher.Flux;
//
//import java.util.function.Consumer;
//
//// TODO done @fansili有了 swagger 注释就不用类注释了
//@Tag(name = "AI模块")
//@RestController
//@RequestMapping("/ai-api")
//@Slf4j
//public class ChatController {
//
// @Autowired
// private ApplicationContext applicationContext;
//
// @PostMapping("/chat")
// @Operation(summary = "对话聊天", description = "简单的ai聊天")
// public CommonResult chat(@RequestBody @Validated AiChatReqVO reqVO) {
// ChatClient chatClient = getChatClient(reqVO.getAiModel());
// String res;
// try {
// res = chatClient.call(reqVO.getPrompt());
// } catch (Exception e) {
// res = e.getMessage();
// }
// return CommonResult.success(res);
// }
//
// @PostMapping("/chatStream")
// @Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
// public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
// OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
// Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getPrompt()));
// chatResponse.subscribe(new Consumer<ChatResponse>() {
// @Override
// public void accept(ChatResponse chatResponse) {
// System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
// }
// });
// return CommonResult.success(null);
// }
//
// /**
// * 根据 ai模型 获取对于的 模型实现类
// *
// * @param aiModelEnum
// * @return
// */
// private ChatClient getChatClient(OpenAiModelEnum aiModelEnum) {
// if (OpenAiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
// return applicationContext.getBean(OpenAiChatClient.class);
// }
// // AI模型暂不支持
// throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODULE_NOT_SUPPORTED);
// }
//}

View File

@ -0,0 +1,38 @@
package cn.iocoder.yudao.framework.ai;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 讯飞星火 模型
*
* 文档地址https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
*
* 1tokens 约等于1.5个中文汉字 或者 0.8个英文单词
* 星火V1.5支持[搜索]内置插件星火V2.0V3.0和V3.5支持[搜索][天气][日期][诗词][字词][股票]六个内置插件
* 星火V3.5 现已支持systemFunction Call 功能
*
* author: fansili
* time: 2024/3/11 10:12
*/
@Getter
@AllArgsConstructor
public enum AiPlatformEnum {
YI_YAN("yiyan"),
QIAN_WEN("qianwen"),
XING_HUO("xinghuo"),
;
public static final Map<String, AiPlatformEnum> mapValues
= Arrays.stream(values()).collect(Collectors.toMap(AiPlatformEnum::name, o -> o));
private String value;
}

View File

@ -1,11 +1,29 @@
package cn.iocoder.yudao.framework.ai.config; package cn.iocoder.yudao.framework.ai.config;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatClient; import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoOptions; import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoApi; import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanOptions;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.exception.AiException;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.support.GenericApplicationContext;
import java.util.Map;
/** /**
* ai 自动配置 * ai 自动配置
@ -18,15 +36,118 @@ import org.springframework.context.annotation.Bean;
@EnableConfigurationProperties(YudaoAiProperties.class) @EnableConfigurationProperties(YudaoAiProperties.class)
public class YudaoAiAutoConfiguration { public class YudaoAiAutoConfiguration {
// TODO @芋艿我看sharding jdbc 差不多这么玩的
@Bean @Bean
public XingHuoChatClient xingHuoChatClient(YudaoAiProperties yudaoAiProperties) { @ConditionalOnMissingBean(value = InitChatClient.class)
return new XingHuoChatClient( public InitChatClient initChatClient(YudaoAiProperties yudaoAiProperties) {
return new InitChatClient(yudaoAiProperties);
}
public static class InitChatClient implements InitializingBean, ApplicationContextAware {
private GenericApplicationContext applicationContext;
private YudaoAiProperties yudaoAiProperties;
public InitChatClient(YudaoAiProperties yudaoAiProperties) {
this.yudaoAiProperties = yudaoAiProperties;
}
@Override
public void afterPropertiesSet() throws Exception {
for (Map.Entry<String, Map<String, Object>> properties : yudaoAiProperties.entrySet()) {
String beanName = properties.getKey();
Map<String, Object> aiPlatformMap = properties.getValue();
// 检查平台类型是否正确
String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform"));
if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) {
throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!");
}
// 获取平台类型
AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform);
// 获取 chat properties
YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap);
// 创建客户端
registerChatClient(applicationContext, chatProperties, beanName);
// applicationContext.refresh();
}
System.err.println(applicationContext.getBean("qianWen"));
System.err.println(applicationContext.getBean("yiYan"));
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = (GenericApplicationContext) applicationContext;
}
}
private static void registerChatClient(GenericApplicationContext applicationContext, YudaoAiProperties.ChatProperties chatProperties, String beanName) {
ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory();
Object wrapperBean = null;
if (AiPlatformEnum.XING_HUO == chatProperties.getAiPlatform()) {
YudaoAiProperties.XingHuoProperties xingHuoProperties = (YudaoAiProperties.XingHuoProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(
new XingHuoChatClient(
new XingHuoApi( new XingHuoApi(
yudaoAiProperties.getXingHuo().getAppId(), xingHuoProperties.getAppId(),
yudaoAiProperties.getXingHuo().getAppKey(), xingHuoProperties.getAppKey(),
yudaoAiProperties.getXingHuo().getSecretKey() xingHuoProperties.getSecretKey()
), ),
new XingHuoOptions().setChatModel(yudaoAiProperties.getXingHuo().getChatModel()) new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
),
beanName
);
} else if (AiPlatformEnum.QIAN_WEN == chatProperties.getAiPlatform()) {
YudaoAiProperties.QianWenProperties qianWenProperties = (YudaoAiProperties.QianWenProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(new QianWenChatClient(
new QianWenApi(
qianWenProperties.getAccessKeyId(),
qianWenProperties.getAccessKeySecret(),
qianWenProperties.getAgentKey(),
qianWenProperties.getEndpoint()
),
new QianWenOptions()
.setAppId(qianWenProperties.getAppId())
),
beanName
);
} else if (AiPlatformEnum.YI_YAN == chatProperties.getAiPlatform()) {
YudaoAiProperties.YiYanProperties yiYanProperties = (YudaoAiProperties.YiYanProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(new YiYanChatClient(
new YiYanApi(
yiYanProperties.getAppKey(),
yiYanProperties.getSecretKey(),
yiYanProperties.getChatModel(),
yiYanProperties.getRefreshTokenSecondTime()
),
new YiYanOptions().setMax_output_tokens(2048)),
beanName
); );
} }
if (wrapperBean != null) {
beanFactory.registerSingleton(beanName, wrapperBean);
}
}
private static YudaoAiProperties.ChatProperties getChatProperties(AiPlatformEnum aiPlatformEnum, Map<String, Object> aiPlatformMap) {
if (AiPlatformEnum.XING_HUO == aiPlatformEnum) {
YudaoAiProperties.XingHuoProperties xingHuoProperties = new YudaoAiProperties.XingHuoProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, xingHuoProperties, true);
return xingHuoProperties;
} else if (AiPlatformEnum.YI_YAN == aiPlatformEnum) {
YudaoAiProperties.YiYanProperties yiYanProperties = new YudaoAiProperties.YiYanProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, yiYanProperties, true);
return yiYanProperties;
} else if (AiPlatformEnum.QIAN_WEN == aiPlatformEnum) {
YudaoAiProperties.QianWenProperties qianWenProperties = new YudaoAiProperties.QianWenProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, qianWenProperties, true);
return qianWenProperties;
}
throw new AiException("不支持的Ai类型!");
}
} }

View File

@ -1,11 +1,15 @@
package cn.iocoder.yudao.framework.ai.config; package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel; import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.ConfigurationProperties;
import java.util.LinkedHashMap;
import java.util.Map;
/** /**
* ai 自动配置 * ai 自动配置
* *
@ -15,16 +19,18 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
*/ */
@Data @Data
@ConfigurationProperties(prefix = "yudao.ai") @ConfigurationProperties(prefix = "yudao.ai")
public class YudaoAiProperties { public class YudaoAiProperties extends LinkedHashMap<String, Map<String, Object>> {
private QianWenProperties qianWen; // private QianWenProperties qianWen;
private XingHuoProperties xingHuo; // private XingHuoProperties xingHuo;
private YiYanProperties yiYan; // private YiYanProperties yiYan;
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public static class ChatProperties { public static class ChatProperties {
private AiPlatformEnum aiPlatform;
private Float temperature; private Float temperature;
private Float topP; private Float topP;
@ -48,9 +54,14 @@ public class YudaoAiProperties {
*/ */
private String accessKeySecret; private String accessKeySecret;
/** /**
* 阿里云agentKey(相当于应用id) * 阿里云agentKey
*/ */
private String agentKey; private String agentKey;
/**
* 阿里云agentKey(相当于应用id)
*/
private String appId;
} }
@Data @Data

View File

@ -0,0 +1,15 @@
package cn.iocoder.yudao.framework.ai.exception;
/**
* ai 异常
*
* @author fansili
* @time 2024/4/13 17:05
* @since 1.0
*/
public class AiException extends RuntimeException {
public AiException(String message) {
super(message);
}
}

View File

@ -224,20 +224,30 @@ wx:
# 芋道配置项,设置当前项目所有自定义的配置 # 芋道配置项,设置当前项目所有自定义的配置
yudao: yudao:
ai: ai:
qianWen:
aiPlatform: QIAN_WEN
temperature: 1 temperature: 1
topP: 1 topP: 1
topK: 1 topK: 1
qianWen:
endpoint: bailian.cn-beijing.aliyuncs.com endpoint: bailian.cn-beijing.aliyuncs.com
accessKeyId: LTAI5tNTVhXW4fLKUjMrr98z accessKeyId: LTAI5tNTVhXW4fLKUjMrr98z
accessKeySecret: ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP accessKeySecret: ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP
agentKey: f0c1088824594f589c8f10567ccd929f_p_efm agentKey: f0c1088824594f589c8f10567ccd929f_p_efm
appId: 5f14955f201a44eb8dbe0c57250a32ce
xingHuo: xingHuo:
aiPlatform: XING_HUO
temperature: 1
topP: 1
topK: 1
appId: 13c8cca6 appId: 13c8cca6
appKey: cb6415c19d6162cda07b47316fcb0416 appKey: cb6415c19d6162cda07b47316fcb0416
secretKey: Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh secretKey: Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh
chatModel: XING_HUO_3_5 chatModel: XING_HUO_3_5
yiYan: yiYan:
aiPlatform: YI_YAN
temperature: 1
topP: 1
topK: 1
appKey: x0cuLZ7XsaTCU08vuJWO87Lg appKey: x0cuLZ7XsaTCU08vuJWO87Lg
secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
refreshTokenSecondTime: 86400 refreshTokenSecondTime: 86400