【优化】ai chat client自动配置和初始化。

This commit is contained in:
cherishsince 2024-04-25 18:05:08 +08:00
parent 44f7c841de
commit 2adb5accc4
11 changed files with 226 additions and 212 deletions

View File

@ -1,37 +0,0 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* ai client 名字
*
* 这个需要根据配置文件起的来决定
*
* @author fansili
* @time 2024/4/14 16:02
* @since 1.0
*/
@AllArgsConstructor
@Getter
public enum AiClientNameEnum {
QIAN_WEN("qianWen", "千问模型!"),
YI_YAN_3_5_8K("yiYan3_5_8k", "文心一言(3.5-8k)"),
XING_HUO("xingHuo", "星火模型!"),
;
private String name;
private String message;
public static AiClientNameEnum valueOfName(String name) {
for (AiClientNameEnum nameEnum : AiClientNameEnum.values()) {
if (nameEnum.getName().equals(name)) {
return nameEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + name);
}
}

View File

@ -0,0 +1,48 @@
package cn.iocoder.yudao.module.ai.config;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chat.ChatClient;
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;
/**
* factory
*
* @author fansili
* @time 2024/4/25 17:36
* @since 1.0
*/
@Component
public class AiChatClientFactory {
@Autowired
private ApplicationContext applicationContext;
public ChatClient getChatClient(AiPlatformEnum platformEnum) {
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class);
}
throw new IllegalArgumentException("不支持的 chat client!");
}
// TODO yunai 要不再加一个接口让他们拥有 ChatClientStreamingChatClient 功能
public StreamingChatClient getStreamingChatClient(AiPlatformEnum platformEnum) {
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class);
}
throw new IllegalArgumentException("不支持的 chat client!");
}
}

View File

@ -1,14 +1,16 @@
package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chat.ChatClient;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
@ -38,7 +40,7 @@ import java.util.function.Consumer;
@AllArgsConstructor
public class AiChatServiceImpl implements AiChatService {
private final AiClient aiClient;
private final AiChatClientFactory aiChatClientFactory;
private final AiChatRoleMapper aiChatRoleMapper;
private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatConversationMapper aiChatConversationMapper;
@ -54,7 +56,7 @@ public class AiChatServiceImpl implements AiChatService {
public String chat(AiChatReq req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
// 获取对话信息
AiChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
// 保存 chat message
@ -67,7 +69,8 @@ public class AiChatServiceImpl implements AiChatService {
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
// 发送 call 调用
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
ChatResponse call = chatClient.call(prompt);
content = call.getResult().getOutput().getContent();
// 更新 conversation
@ -128,7 +131,7 @@ public class AiChatServiceImpl implements AiChatService {
public void chatStream(AiChatReq req, Utf8SseEmitter sseEmitter) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
// 获取对话信息
AiChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
// 创建 chat 需要的 Prompt
@ -138,7 +141,8 @@ public class AiChatServiceImpl implements AiChatService {
req.setTemperature(req.getTemperature());
// 保存 chat message
saveChatMessage(req, conversationRes, loginUserId);
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
StringBuffer contentBuffer = new StringBuffer();
streamResponse.subscribe(

View File

@ -16,7 +16,7 @@ tenant-id: 1
}
### chat call
GET {{baseUrl}}/ai/chat?modal=qianWen&conversationId=1781604279872581644&prompt=中国好看吗?
GET {{baseUrl}}/ai/chat?modal=qianwen&conversationId=1781604279872581644&prompt=中国好看吗?
Authorization: {{token}}

View File

@ -3,10 +3,6 @@ package cn.iocoder.yudao.framework.ai;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 讯飞星火 模型
*
@ -24,15 +20,22 @@ import java.util.stream.Collectors;
public enum AiPlatformEnum {
YI_YAN("yiyan"),
QIAN_WEN("qianwen"),
XING_HUO("xinghuo"),
YI_YAN("yiyan", "一言"),
QIAN_WEN("qianwen", "千问"),
XING_HUO("xinghuo", "星火"),
;
public static final Map<String, AiPlatformEnum> mapValues
= Arrays.stream(values()).collect(Collectors.toMap(AiPlatformEnum::name, o -> o));
private String platform;
private String name;
private String value;
public static AiPlatformEnum valueOfPlatform(String platform) {
for (AiPlatformEnum itemEnum : AiPlatformEnum.values()) {
if (itemEnum.getPlatform().equals(platform)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + platform);
}
}

View File

@ -1,19 +0,0 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
/**
* ai client传入
*
* @author fansili
* @time 2024/4/14 10:27
* @since 1.0
*/
public interface AiClient {
ChatResponse call(Prompt prompt, String clientName);
Flux<ChatResponse> stream(Prompt prompt, String clientName);
}

View File

@ -1,7 +1,5 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
@ -11,15 +9,11 @@ import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanOptions;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.exception.AiException;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import java.util.HashMap;
import java.util.Map;
/**
* ai 自动配置
*
@ -32,86 +26,46 @@ import java.util.Map;
public class YudaoAiAutoConfiguration {
@Bean
@ConditionalOnMissingBean(value = AiClient.class)
public AiClient aiClient(YudaoAiProperties yudaoAiProperties) {
Map<String, Object> chatClientMap = buildChatClientMap(yudaoAiProperties);
return new YudaoAiClient(chatClientMap);
@ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true")
public XingHuoChatClient xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.XingHuoProperties xingHuoProperties = yudaoAiProperties.getXinghuo();
return new XingHuoChatClient(
new XingHuoApi(
xingHuoProperties.getAppId(),
xingHuoProperties.getAppKey(),
xingHuoProperties.getSecretKey()
),
new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
);
}
public Map<String, Object> buildChatClientMap(YudaoAiProperties yudaoAiProperties) {
Map<String, Object> chatMap = new HashMap<>();
for (Map.Entry<String, Map<String, Object>> properties : yudaoAiProperties.entrySet()) {
String beanName = properties.getKey();
Map<String, Object> aiPlatformMap = properties.getValue();
// 检查平台类型是否正确
String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform"));
if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) {
throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!");
}
// 获取平台类型
AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform);
// 获取 chat properties
YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap);
// 创建客户端
Object chatClient = createChatClient(chatProperties);
chatMap.put(beanName, chatClient);
}
return chatMap;
@Bean
@ConditionalOnProperty(value = "yudao.ai.qianwen.enable", havingValue = "true")
public QianWenChatClient qianWenChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.QianWenProperties qianWenProperties = yudaoAiProperties.getQianwen();
return new QianWenChatClient(
new QianWenApi(
qianWenProperties.getAccessKeyId(),
qianWenProperties.getAccessKeySecret(),
qianWenProperties.getAgentKey(),
qianWenProperties.getEndpoint()
),
new QianWenOptions()
.setAppId(qianWenProperties.getAppId())
);
}
private static Object createChatClient(YudaoAiProperties.ChatProperties chatProperties) {
if (AiPlatformEnum.XING_HUO == chatProperties.getAiPlatform()) {
YudaoAiProperties.XingHuoProperties xingHuoProperties = (YudaoAiProperties.XingHuoProperties) chatProperties;
return new XingHuoChatClient(
new XingHuoApi(
xingHuoProperties.getAppId(),
xingHuoProperties.getAppKey(),
xingHuoProperties.getSecretKey()
),
new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
);
} else if (AiPlatformEnum.QIAN_WEN == chatProperties.getAiPlatform()) {
YudaoAiProperties.QianWenProperties qianWenProperties = (YudaoAiProperties.QianWenProperties) chatProperties;
return new QianWenChatClient(
new QianWenApi(
qianWenProperties.getAccessKeyId(),
qianWenProperties.getAccessKeySecret(),
qianWenProperties.getAgentKey(),
qianWenProperties.getEndpoint()
),
new QianWenOptions()
.setAppId(qianWenProperties.getAppId())
);
} else if (AiPlatformEnum.YI_YAN == chatProperties.getAiPlatform()) {
YudaoAiProperties.YiYanProperties yiYanProperties = (YudaoAiProperties.YiYanProperties) chatProperties;
return new YiYanChatClient(
new YiYanApi(
yiYanProperties.getAppKey(),
yiYanProperties.getSecretKey(),
yiYanProperties.getChatModel(),
yiYanProperties.getRefreshTokenSecondTime()
),
new YiYanOptions().setMax_output_tokens(2048));
}
throw new AiException("不支持的Ai类型!");
}
private static YudaoAiProperties.ChatProperties getChatProperties(AiPlatformEnum aiPlatformEnum, Map<String, Object> aiPlatformMap) {
if (AiPlatformEnum.XING_HUO == aiPlatformEnum) {
YudaoAiProperties.XingHuoProperties xingHuoProperties = new YudaoAiProperties.XingHuoProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, xingHuoProperties, true);
return xingHuoProperties;
} else if (AiPlatformEnum.YI_YAN == aiPlatformEnum) {
YudaoAiProperties.YiYanProperties yiYanProperties = new YudaoAiProperties.YiYanProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, yiYanProperties, true);
return yiYanProperties;
} else if (AiPlatformEnum.QIAN_WEN == aiPlatformEnum) {
YudaoAiProperties.QianWenProperties qianWenProperties = new YudaoAiProperties.QianWenProperties();
BeanUtil.fillBeanWithMap(aiPlatformMap, qianWenProperties, true);
return qianWenProperties;
}
throw new AiException("不支持的Ai类型!");
@Bean
@ConditionalOnProperty(value = "yudao.ai.yiyan.enable", havingValue = "true")
public YiYanChatClient yiYanChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.YiYanProperties yiYanProperties = yudaoAiProperties.getYiyan();
return new YiYanChatClient(
new YiYanApi(
yiYanProperties.getAppKey(),
yiYanProperties.getSecretKey(),
yiYanProperties.getChatModel(),
yiYanProperties.getRefreshTokenSecondTime()
),
new YiYanOptions().setMax_output_tokens(2048));
}
}

View File

@ -1,44 +0,0 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.chat.ChatClient;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.exception.AiException;
import reactor.core.publisher.Flux;
import java.util.Map;
/**
* yudao ai client
*
* @author fansili
* @time 2024/4/14 10:27
* @since 1.0
*/
public class YudaoAiClient implements AiClient {
protected Map<String, Object> chatClientMap;
public YudaoAiClient(Map<String, Object> chatClientMap) {
this.chatClientMap = chatClientMap;
}
@Override
public ChatResponse call(Prompt prompt, String clientName) {
if (!chatClientMap.containsKey(clientName)) {
throw new AiException("clientName不存在!");
}
ChatClient chatClient = (ChatClient) chatClientMap.get(clientName);
return chatClient.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt, String clientName) {
if (!chatClientMap.containsKey(clientName)) {
throw new AiException("clientName不存在!");
}
StreamingChatClient streamingChatClient = (StreamingChatClient) chatClientMap.get(clientName);
return streamingChatClient.stream(prompt);
}
}

View File

@ -0,0 +1,99 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.boot.context.properties.ConfigurationProperties;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* ai 自动配置
*
* @author fansili
* @time 2024/4/12 16:29
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class YudaoAiImageProperties extends LinkedHashMap<String, Map<String, Object>> {
private String initType;
private QianWenProperties qianwen;
private XingHuoOptions xinghuo;
private YiYanProperties yiyan;
@Data
@Accessors(chain = true)
public static class QianWenProperties extends ChatProperties {
/**
* 阿里云服务器接入点
*/
private String endpoint = "bailian.cn-beijing.aliyuncs.com";
/**
* 阿里云权限 accessKeyId
*/
private String accessKeyId;
/**
* 阿里云权限 accessKeySecret
*/
private String accessKeySecret;
/**
* 阿里云agentKey
*/
private String agentKey;
/**
* 阿里云agentKey(相当于应用id)
*/
private String appId;
}
@Data
@Accessors(chain = true)
public static class XingHuoProperties extends ChatProperties {
private String appId;
private String appKey;
private String secretKey;
private XingHuoChatModel chatModel;
}
@Data
@Accessors(chain = true)
public static class YiYanProperties extends ChatProperties {
/**
* appKey
*/
private String appKey;
/**
* secretKey
*/
private String secretKey;
/**
* 模型
*/
private YiYanChatModel chatModel = YiYanChatModel.ERNIE4_3_5_8K;
/**
* token 刷新时间(默认 86400 = 24小时)
*/
private int refreshTokenSecondTime = 86400;
}
@Data
@Accessors(chain = true)
public static class ChatProperties {
private AiPlatformEnum aiPlatform;
private Float temperature;
private Float topP;
private Integer topK;
}
}

View File

@ -7,9 +7,6 @@ import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.boot.context.properties.ConfigurationProperties;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* ai 自动配置
*
@ -18,17 +15,21 @@ import java.util.Map;
* @since 1.0
*/
@Data
@Accessors(chain = true)
@ConfigurationProperties(prefix = "yudao.ai")
public class YudaoAiProperties extends LinkedHashMap<String, Map<String, Object>> {
public class YudaoAiProperties {
// private QianWenProperties qianWen;
// private XingHuoProperties xingHuo;
// private YiYanProperties yiYan;
private String initSource;
private QianWenProperties qianwen;
private XingHuoProperties xinghuo;
private YiYanProperties yiyan;
@Data
@Accessors(chain = true)
public static class ChatProperties {
private boolean enable = false;
private AiPlatformEnum aiPlatform;
private Float temperature;

View File

@ -224,7 +224,9 @@ wx:
# 芋道配置项,设置当前项目所有自定义的配置
yudao:
ai:
qianWen:
initSource: yaml
qianwen:
enable: true
aiPlatform: QIAN_WEN
temperature: 1
topP: 1
@ -234,7 +236,8 @@ yudao:
accessKeySecret: ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP
agentKey: f0c1088824594f589c8f10567ccd929f_p_efm
appId: 5f14955f201a44eb8dbe0c57250a32ce
xingHuo:
xinghuo:
enable: true
aiPlatform: XING_HUO
temperature: 1
topP: 1
@ -243,7 +246,8 @@ yudao:
appKey: cb6415c19d6162cda07b47316fcb0416
secretKey: Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh
chatModel: XING_HUO_3_5
yiYan3_5_8k:
yiyan:
enable: true
aiPlatform: YI_YAN
temperature: 1
topP: 1
@ -252,6 +256,7 @@ yudao:
secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
refreshTokenSecondTime: 86400
chatModel: ERNIE4_3_5_8K
captcha:
enable: false # 本地环境,暂时关闭图片验证码,方便登录等接口的测试;
security: