优化自动注入,可以创建多个 client

This commit is contained in:
cherishsince 2024-04-13 18:20:28 +08:00
parent 02ac6f30cf
commit ac0de5d485
7 changed files with 320 additions and 91 deletions

View File

@ -1,77 +1,34 @@
//package cn.iocoder.yudao.module.ai.controller.admin;
//
//import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
//import cn.iocoder.yudao.framework.common.pojo.CommonResult;
//import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
//import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO;
//import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
//import io.swagger.v3.oas.annotations.Operation;
//import io.swagger.v3.oas.annotations.tags.Tag;
//import jakarta.servlet.http.HttpServletResponse;
//import lombok.extern.slf4j.Slf4j;
//import org.springframework.ai.chat.ChatClient;
//import org.springframework.ai.chat.ChatResponse;
//import org.springframework.ai.chat.prompt.Prompt;
//import org.springframework.ai.openai.OpenAiChatClient;
//import org.springframework.beans.factory.annotation.Autowired;
//import org.springframework.context.ApplicationContext;
//import org.springframework.validation.annotation.Validated;
//import org.springframework.web.bind.annotation.PostMapping;
//import org.springframework.web.bind.annotation.RequestBody;
//import org.springframework.web.bind.annotation.RequestMapping;
//import org.springframework.web.bind.annotation.RestController;
//import reactor.core.publisher.Flux;
//
//import java.util.function.Consumer;
//
//// TODO done @fansili有了 swagger 注释就不用类注释了
//@Tag(name = "AI模块")
//@RestController
//@RequestMapping("/ai-api")
//@Slf4j
//public class ChatController {
package cn.iocoder.yudao.module.ai.controller.admin;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
/**
* @author fansili
* @since 1.0
* @time 2024/4/13 17:44
*/
@Tag(name = "AI模块")
@RestController
@RequestMapping("/ai-api")
@Slf4j
@AllArgsConstructor
public class ChatController {
//
// @Autowired
// private ApplicationContext applicationContext;
// private QianWenChatClient qianWenChatClient;
//
// @PostMapping("/chat")
// @Operation(summary = "对话聊天", description = "简单的ai聊天")
// public CommonResult chat(@RequestBody @Validated AiChatReqVO reqVO) {
// ChatClient chatClient = getChatClient(reqVO.getAiModel());
// String res;
// try {
// res = chatClient.call(reqVO.getPrompt());
// } catch (Exception e) {
// res = e.getMessage();
// }
// return CommonResult.success(res);
// @GetMapping("/chat")
// public String chat(@RequestParam("prompt") String prompt) {
// return qianWenChatClient.call(prompt);
// }
//
// @PostMapping("/chatStream")
// @Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
// public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
// OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
// Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getPrompt()));
// chatResponse.subscribe(new Consumer<ChatResponse>() {
// @Override
// public void accept(ChatResponse chatResponse) {
// System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
// }
// });
// return CommonResult.success(null);
// }
//
// /**
// * 根据 ai模型 获取对于的 模型实现类
// *
// * @param aiModelEnum
// * @return
// */
// private ChatClient getChatClient(OpenAiModelEnum aiModelEnum) {
// if (OpenAiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
// return applicationContext.getBean(OpenAiChatClient.class);
// }
// // AI模型暂不支持
// throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODULE_NOT_SUPPORTED);
// }
//}
}

View File

@ -0,0 +1,77 @@
//package cn.iocoder.yudao.module.ai.controller.admin;
//
//import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
//import cn.iocoder.yudao.framework.common.pojo.CommonResult;
//import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
//import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO;
//import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
//import io.swagger.v3.oas.annotations.Operation;
//import io.swagger.v3.oas.annotations.tags.Tag;
//import jakarta.servlet.http.HttpServletResponse;
//import lombok.extern.slf4j.Slf4j;
//import org.springframework.ai.chat.ChatClient;
//import org.springframework.ai.chat.ChatResponse;
//import org.springframework.ai.chat.prompt.Prompt;
//import org.springframework.ai.openai.OpenAiChatClient;
//import org.springframework.beans.factory.annotation.Autowired;
//import org.springframework.context.ApplicationContext;
//import org.springframework.validation.annotation.Validated;
//import org.springframework.web.bind.annotation.PostMapping;
//import org.springframework.web.bind.annotation.RequestBody;
//import org.springframework.web.bind.annotation.RequestMapping;
//import org.springframework.web.bind.annotation.RestController;
//import reactor.core.publisher.Flux;
//
//import java.util.function.Consumer;
//
//// TODO done @fansili有了 swagger 注释就不用类注释了
//@Tag(name = "AI模块")
//@RestController
//@RequestMapping("/ai-api")
//@Slf4j
//public class ChatController {
//
// @Autowired
// private ApplicationContext applicationContext;
//
// @PostMapping("/chat")
// @Operation(summary = "对话聊天", description = "简单的ai聊天")
// public CommonResult chat(@RequestBody @Validated AiChatReqVO reqVO) {
// ChatClient chatClient = getChatClient(reqVO.getAiModel());
// String res;
// try {
// res = chatClient.call(reqVO.getPrompt());
// } catch (Exception e) {
// res = e.getMessage();
// }
// return CommonResult.success(res);
// }
//
// @PostMapping("/chatStream")
// @Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
// public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
// OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
// Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getPrompt()));
// chatResponse.subscribe(new Consumer<ChatResponse>() {
// @Override
// public void accept(ChatResponse chatResponse) {
// System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
// }
// });
// return CommonResult.success(null);
// }
//
// /**
// * 根据 ai模型 获取对于的 模型实现类
// *
// * @param aiModelEnum
// * @return
// */
// private ChatClient getChatClient(OpenAiModelEnum aiModelEnum) {
// if (OpenAiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
// return applicationContext.getBean(OpenAiChatClient.class);
// }
// // AI模型暂不支持
// throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODULE_NOT_SUPPORTED);
// }
//}

View File

@ -0,0 +1,38 @@
package cn.iocoder.yudao.framework.ai;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 讯飞星火 模型
*
* 文档地址https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
*
* 1tokens 约等于1.5个中文汉字 或者 0.8个英文单词
* 星火V1.5支持[搜索]内置插件星火V2.0V3.0和V3.5支持[搜索][天气][日期][诗词][字词][股票]六个内置插件
* 星火V3.5 现已支持systemFunction Call 功能
*
* author: fansili
* time: 2024/3/11 10:12
*/
@Getter
@AllArgsConstructor
public enum AiPlatformEnum {
YI_YAN("yiyan"),
QIAN_WEN("qianwen"),
XING_HUO("xinghuo"),
;
public static final Map<String, AiPlatformEnum> mapValues
= Arrays.stream(values()).collect(Collectors.toMap(AiPlatformEnum::name, o -> o));
private String value;
}

View File

@ -1,11 +1,29 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanOptions;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.exception.AiException;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.support.GenericApplicationContext;
import java.util.Map;
/**
* ai 自动配置
@ -18,15 +36,118 @@ import org.springframework.context.annotation.Bean;
@EnableConfigurationProperties(YudaoAiProperties.class)
public class YudaoAiAutoConfiguration {
// TODO @芋艿我看sharding jdbc 差不多这么玩的
@Bean
public XingHuoChatClient xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
return new XingHuoChatClient(
new XingHuoApi(
yudaoAiProperties.getXingHuo().getAppId(),
yudaoAiProperties.getXingHuo().getAppKey(),
yudaoAiProperties.getXingHuo().getSecretKey()
),
new XingHuoOptions().setChatModel(yudaoAiProperties.getXingHuo().getChatModel())
);
@ConditionalOnMissingBean(value = InitChatClient.class)
public InitChatClient initChatClient(YudaoAiProperties yudaoAiProperties) {
return new InitChatClient(yudaoAiProperties);
}
}
public static class InitChatClient implements InitializingBean, ApplicationContextAware {
private GenericApplicationContext applicationContext;
private YudaoAiProperties yudaoAiProperties;
public InitChatClient(YudaoAiProperties yudaoAiProperties) {
this.yudaoAiProperties = yudaoAiProperties;
}
@Override
public void afterPropertiesSet() throws Exception {
for (Map.Entry<String, Map<String, Object>> properties : yudaoAiProperties.entrySet()) {
String beanName = properties.getKey();
Map<String, Object> aiPlatformMap = properties.getValue();
// 检查平台类型是否正确
String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform"));
if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) {
throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!");
}
// 获取平台类型
AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform);
// 获取 chat properties
YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap);
// 创建客户端
registerChatClient(applicationContext, chatProperties, beanName);
// applicationContext.refresh();
}
System.err.println(applicationContext.getBean("qianWen"));
System.err.println(applicationContext.getBean("yiYan"));
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = (GenericApplicationContext) applicationContext;
}
}
private static void registerChatClient(GenericApplicationContext applicationContext, YudaoAiProperties.ChatProperties chatProperties, String beanName) {
ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory();
Object wrapperBean = null;
if (AiPlatformEnum.XING_HUO == chatProperties.getAiPlatform()) {
YudaoAiProperties.XingHuoProperties xingHuoProperties = (YudaoAiProperties.XingHuoProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(
new XingHuoChatClient(
new XingHuoApi(
xingHuoProperties.getAppId(),
xingHuoProperties.getAppKey(),
xingHuoProperties.getSecretKey()
),
new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
),
beanName
);
} else if (AiPlatformEnum.QIAN_WEN == chatProperties.getAiPlatform()) {
YudaoAiProperties.QianWenProperties qianWenProperties = (YudaoAiProperties.QianWenProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(new QianWenChatClient(
new QianWenApi(
qianWenProperties.getAccessKeyId(),
qianWenProperties.getAccessKeySecret(),
qianWenProperties.getAgentKey(),
qianWenProperties.getEndpoint()
),
new QianWenOptions()
.setAppId(qianWenProperties.getAppId())
),
beanName
);
} else if (AiPlatformEnum.YI_YAN == chatProperties.getAiPlatform()) {
YudaoAiProperties.YiYanProperties yiYanProperties = (YudaoAiProperties.YiYanProperties) chatProperties;
wrapperBean = beanFactory.initializeBean(new YiYanChatClient(
new YiYanApi(
yiYanProperties.getAppKey(),
yiYanProperties.getSecretKey(),
yiYanProperties.getChatModel(),
yiYanProperties.getRefreshTokenSecondTime()
),
new YiYanOptions().setMax_output_tokens(2048)),
beanName
);
}
if (wrapperBean != null) {
beanFactory.registerSingleton(beanName, wrapperBean);
}
}
private static YudaoAiProperties.ChatProperties getChatProperties(AiPlatformEnum aiPlatformEnum, Map<String, Object> aiPlatformMap) {
if (AiPlatformEnum.XING_HUO == aiPlatformEnum) {
YudaoAiProperties.XingHuoProperties xingHuoProperties = new YudaoAiProperties.XingHuoProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, xingHuoProperties, true);
return xingHuoProperties;
} else if (AiPlatformEnum.YI_YAN == aiPlatformEnum) {
YudaoAiProperties.YiYanProperties yiYanProperties = new YudaoAiProperties.YiYanProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, yiYanProperties, true);
return yiYanProperties;
} else if (AiPlatformEnum.QIAN_WEN == aiPlatformEnum) {
YudaoAiProperties.QianWenProperties qianWenProperties = new YudaoAiProperties.QianWenProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, qianWenProperties, true);
return qianWenProperties;
}
throw new AiException("不支持的Ai类型!");
}
}

View File

@ -1,11 +1,15 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.boot.context.properties.ConfigurationProperties;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* ai 自动配置
*
@ -15,16 +19,18 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
*/
@Data
@ConfigurationProperties(prefix = "yudao.ai")
public class YudaoAiProperties {
public class YudaoAiProperties extends LinkedHashMap<String, Map<String, Object>> {
private QianWenProperties qianWen;
private XingHuoProperties xingHuo;
private YiYanProperties yiYan;
// private QianWenProperties qianWen;
// private XingHuoProperties xingHuo;
// private YiYanProperties yiYan;
@Data
@Accessors(chain = true)
public static class ChatProperties {
private AiPlatformEnum aiPlatform;
private Float temperature;
private Float topP;
@ -48,9 +54,14 @@ public class YudaoAiProperties {
*/
private String accessKeySecret;
/**
* 阿里云agentKey(相当于应用id)
* 阿里云agentKey
*/
private String agentKey;
/**
* 阿里云agentKey(相当于应用id)
*/
private String appId;
}
@Data

View File

@ -0,0 +1,15 @@
package cn.iocoder.yudao.framework.ai.exception;
/**
* ai 异常
*
* @author fansili
* @time 2024/4/13 17:05
* @since 1.0
*/
public class AiException extends RuntimeException {
public AiException(String message) {
super(message);
}
}

View File

@ -224,20 +224,30 @@ wx:
# 芋道配置项,设置当前项目所有自定义的配置
yudao:
ai:
temperature: 1
topP: 1
topK: 1
qianWen:
aiPlatform: QIAN_WEN
temperature: 1
topP: 1
topK: 1
endpoint: bailian.cn-beijing.aliyuncs.com
accessKeyId: LTAI5tNTVhXW4fLKUjMrr98z
accessKeySecret: ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP
agentKey: f0c1088824594f589c8f10567ccd929f_p_efm
appId: 5f14955f201a44eb8dbe0c57250a32ce
xingHuo:
aiPlatform: XING_HUO
temperature: 1
topP: 1
topK: 1
appId: 13c8cca6
appKey: cb6415c19d6162cda07b47316fcb0416
secretKey: Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh
chatModel: XING_HUO_3_5
yiYan:
aiPlatform: YI_YAN
temperature: 1
topP: 1
topK: 1
appKey: x0cuLZ7XsaTCU08vuJWO87Lg
secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
refreshTokenSecondTime: 86400