mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-26 01:01:52 +08:00
【新增】AI:会话接入 API KEY 逻辑
This commit is contained in:
parent
6856f5f192
commit
b7180d3481
@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.dal.mysql.model;
|
|||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||||
|
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||||
import org.apache.ibatis.annotations.Mapper;
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
@ -23,4 +24,12 @@ public interface AiApiKeyMapper extends BaseMapperX<AiApiKeyDO> {
|
|||||||
.orderByDesc(AiApiKeyDO::getId));
|
.orderByDesc(AiApiKeyDO::getId));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
default AiApiKeyDO selectFirstByPlatformAndStatus(String platform, Integer status) {
|
||||||
|
return selectOne(new QueryWrapperX<AiApiKeyDO>()
|
||||||
|
.eq("platform", platform)
|
||||||
|
.eq("status", status)
|
||||||
|
.limitN(1)
|
||||||
|
.orderByAsc("id"));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -4,11 +4,13 @@ import cn.hutool.core.collection.CollUtil;
|
|||||||
import cn.hutool.core.util.ObjUtil;
|
import cn.hutool.core.util.ObjUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
|
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||||
@ -18,6 +20,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
|||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
||||||
|
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
||||||
@ -28,6 +31,8 @@ import org.springframework.ai.chat.StreamingChatClient;
|
|||||||
import org.springframework.ai.chat.messages.*;
|
import org.springframework.ai.chat.messages.*;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
|
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||||
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
@ -54,9 +59,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
@Resource
|
@Resource
|
||||||
private AiChatMessageMapper chatMessageMapper;
|
private AiChatMessageMapper chatMessageMapper;
|
||||||
|
|
||||||
@Resource
|
|
||||||
private AiClientFactory clientFactory;
|
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private AiChatConversationService chatConversationService;
|
private AiChatConversationService chatConversationService;
|
||||||
@Resource
|
@Resource
|
||||||
@ -168,11 +170,33 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
|
|
||||||
// 2. 构建 ChatOptions 对象
|
// 2. 构建 ChatOptions 对象
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
ChatOptions chatOptions = clientFactory.buildChatOptions(platform, model.getModel(),
|
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(),
|
||||||
conversation.getTemperature(), conversation.getMaxTokens());
|
conversation.getTemperature(), conversation.getMaxTokens());
|
||||||
return new Prompt(chatMessages, chatOptions);
|
return new Prompt(chatMessages, chatOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||||
|
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
||||||
|
//noinspection EnhancedSwitchMigration
|
||||||
|
switch (platform) {
|
||||||
|
case OPENAI:
|
||||||
|
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||||
|
case OLLAMA:
|
||||||
|
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
||||||
|
case YI_YAN:
|
||||||
|
// TODO @fan:增加一个 model
|
||||||
|
return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
|
||||||
|
case XING_HUO:
|
||||||
|
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
||||||
|
.setMaxTokens(maxTokens);
|
||||||
|
case QIAN_WEN:
|
||||||
|
// TODO @fan:增加 model、temperature 参数
|
||||||
|
return new QianWenOptions().setMaxTokens(maxTokens);
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
||||||
*
|
*
|
||||||
@ -183,7 +207,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
* @param sendReqVO 发送请求
|
* @param sendReqVO 发送请求
|
||||||
* @return 消息上下文
|
* @return 消息上下文
|
||||||
*/
|
*/
|
||||||
private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, AiChatConversationDO conversation, AiChatMessageSendReqVO sendReqVO) {
|
private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
|
||||||
|
AiChatConversationDO conversation,
|
||||||
|
AiChatMessageSendReqVO sendReqVO) {
|
||||||
if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
|
if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
|
@ -64,4 +64,5 @@ public interface AiImageService {
|
|||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO);
|
Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,6 @@ import cn.hutool.core.util.StrUtil;
|
|||||||
import cn.hutool.extra.spring.SpringUtil;
|
import cn.hutool.extra.spring.SpringUtil;
|
||||||
import cn.hutool.http.HttpUtil;
|
import cn.hutool.http.HttpUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
@ -23,6 +22,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyIma
|
|||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
||||||
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
||||||
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@ -57,7 +57,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|||||||
private FileApi fileApi;
|
private FileApi fileApi;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private AiClientFactory aiClientFactory;
|
private AiApiKeyService apiKeyService;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private MidjourneyProxyClient midjourneyProxyClient;
|
private MidjourneyProxyClient midjourneyProxyClient;
|
||||||
@ -82,17 +82,17 @@ public class AiImageServiceImpl implements AiImageService {
|
|||||||
.setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
.setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
||||||
imageMapper.insert(image);
|
imageMapper.insert(image);
|
||||||
// 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
|
// 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
|
||||||
getSelf().doDall(image, drawReqVO);
|
getSelf().executeDrawImage(image, drawReqVO);
|
||||||
return image.getId();
|
return image.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Async
|
@Async
|
||||||
public void doDall(AiImageDO image, AiImageDrawReqVO req) {
|
public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) {
|
||||||
try {
|
try {
|
||||||
// 1.1 构建请求
|
// 1.1 构建请求
|
||||||
ImageOptions request = buildImageOptions(req);
|
ImageOptions request = buildImageOptions(req);
|
||||||
// 1.2 执行请求
|
// 1.2 执行请求
|
||||||
ImageClient imageClient = aiClientFactory.getDefaultImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
|
ImageClient imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
|
||||||
ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request));
|
ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request));
|
||||||
|
|
||||||
// 2. 上传到文件服务
|
// 2. 上传到文件服务
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
package cn.iocoder.yudao.module.ai.service.model;
|
package cn.iocoder.yudao.module.ai.service.model;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||||
import jakarta.validation.Valid;
|
import jakarta.validation.Valid;
|
||||||
import org.springframework.ai.chat.StreamingChatClient;
|
import org.springframework.ai.chat.StreamingChatClient;
|
||||||
|
import org.springframework.ai.image.ImageClient;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@ -79,4 +81,14 @@ public interface AiApiKeyService {
|
|||||||
*/
|
*/
|
||||||
StreamingChatClient getStreamingChatClient(Long id);
|
StreamingChatClient getStreamingChatClient(Long id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获得 ImageClient 对象
|
||||||
|
*
|
||||||
|
* TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择
|
||||||
|
*
|
||||||
|
* @param platform 平台
|
||||||
|
* @return ImageClient 对象
|
||||||
|
*/
|
||||||
|
ImageClient getImageClient(AiPlatformEnum platform);
|
||||||
|
|
||||||
}
|
}
|
@ -11,6 +11,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
|||||||
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import org.springframework.ai.chat.StreamingChatClient;
|
import org.springframework.ai.chat.StreamingChatClient;
|
||||||
|
import org.springframework.ai.image.ImageClient;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
|
|
||||||
@ -101,4 +102,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
|
|||||||
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
|
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ImageClient getImageClient(AiPlatformEnum platform) {
|
||||||
|
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus());
|
||||||
|
if (apiKey == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.core.factory;
|
|||||||
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import org.springframework.ai.chat.StreamingChatClient;
|
import org.springframework.ai.chat.StreamingChatClient;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
|
||||||
import org.springframework.ai.image.ImageClient;
|
import org.springframework.ai.image.ImageClient;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -45,14 +44,15 @@ public interface AiClientFactory {
|
|||||||
ImageClient getDefaultImageClient(AiPlatformEnum platform);
|
ImageClient getDefaultImageClient(AiPlatformEnum platform);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 创建 Chat 参数
|
* 基于指定配置,获得 ImageClient 对象
|
||||||
|
*
|
||||||
|
* 如果不存在,则进行创建
|
||||||
*
|
*
|
||||||
* @param platform 平台
|
* @param platform 平台
|
||||||
* @param model 模型
|
* @param apiKey API KEY
|
||||||
* @param temperature 温度
|
* @param url API URL
|
||||||
* @param maxTokens 生成的最大 Token
|
* @return ImageClient 对象
|
||||||
* @return Chat 参数
|
|
||||||
*/
|
*/
|
||||||
ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens);
|
ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -11,29 +11,25 @@ import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
|
|||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
|
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
|
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
|
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
|
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
|
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
|
||||||
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
||||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||||
import org.springframework.ai.chat.StreamingChatClient;
|
import org.springframework.ai.chat.StreamingChatClient;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
|
||||||
import org.springframework.ai.image.ImageClient;
|
import org.springframework.ai.image.ImageClient;
|
||||||
import org.springframework.ai.ollama.OllamaChatClient;
|
import org.springframework.ai.ollama.OllamaChatClient;
|
||||||
import org.springframework.ai.ollama.api.OllamaApi;
|
import org.springframework.ai.ollama.api.OllamaApi;
|
||||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
|
||||||
import org.springframework.ai.openai.OpenAiChatClient;
|
import org.springframework.ai.openai.OpenAiChatClient;
|
||||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
|
||||||
import org.springframework.ai.openai.OpenAiImageClient;
|
import org.springframework.ai.openai.OpenAiImageClient;
|
||||||
import org.springframework.ai.openai.api.ApiUtils;
|
import org.springframework.ai.openai.api.ApiUtils;
|
||||||
import org.springframework.ai.openai.api.OpenAiApi;
|
import org.springframework.ai.openai.api.OpenAiApi;
|
||||||
|
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||||
import org.springframework.ai.stabilityai.StabilityAiImageClient;
|
import org.springframework.ai.stabilityai.StabilityAiImageClient;
|
||||||
|
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
||||||
|
import org.springframework.web.client.RestClient;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@ -100,6 +96,19 @@ public class AiClientFactoryImpl implements AiClientFactory {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
|
||||||
|
//noinspection EnhancedSwitchMigration
|
||||||
|
switch (platform) {
|
||||||
|
case OPENAI:
|
||||||
|
return buildOpenAiImageClient(apiKey, url);
|
||||||
|
case STABLE_DIFFUSION:
|
||||||
|
return buildStabilityAiImageClient(apiKey, url);
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
|
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
|
||||||
if (ArrayUtil.isEmpty(params)) {
|
if (ArrayUtil.isEmpty(params)) {
|
||||||
return clazz.getName();
|
return clazz.getName();
|
||||||
@ -107,29 +116,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
|
|||||||
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
|
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
|
||||||
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
|
||||||
//noinspection EnhancedSwitchMigration
|
|
||||||
switch (platform) {
|
|
||||||
case OPENAI:
|
|
||||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
||||||
case OLLAMA:
|
|
||||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
|
||||||
case YI_YAN:
|
|
||||||
// TODO @fan:增加一个 model
|
|
||||||
return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
|
|
||||||
case XING_HUO:
|
|
||||||
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
|
||||||
.setMaxTokens(maxTokens);
|
|
||||||
case QIAN_WEN:
|
|
||||||
// TODO @fan:增加 model、temperature 参数
|
|
||||||
return new QianWenOptions().setMaxTokens(maxTokens);
|
|
||||||
default:
|
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ========== 各种创建 spring-ai 客户端的方法 ==========
|
// ========== 各种创建 spring-ai 客户端的方法 ==========
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -182,7 +168,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
|
|||||||
return new QianWenChatClient(qianWenApi);
|
return new QianWenChatClient(qianWenApi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// private static VertexAiGeminiChatClient buildGoogleGemir(String key) {
|
// private static VertexAiGeminiChatClient buildGoogleGemir(String key) {
|
||||||
// List<String> keys = StrUtil.split(key, '|');
|
// List<String> keys = StrUtil.split(key, '|');
|
||||||
// Assert.equals(keys.size(), 2, "VertexAiGeminiChatClient 的密钥需要 (projectId|location) 格式");
|
// Assert.equals(keys.size(), 2, "VertexAiGeminiChatClient 的密钥需要 (projectId|location) 格式");
|
||||||
@ -190,4 +175,16 @@ public class AiClientFactoryImpl implements AiClientFactory {
|
|||||||
// return new VertexAiGeminiChatClient(vertexApi);
|
// return new VertexAiGeminiChatClient(vertexApi);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
private ImageClient buildOpenAiImageClient(String openAiToken, String url) {
|
||||||
|
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
|
||||||
|
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
|
||||||
|
return new OpenAiImageClient(openAiApi);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ImageClient buildStabilityAiImageClient(String apiKey, String url) {
|
||||||
|
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
|
||||||
|
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
|
||||||
|
return new StabilityAiImageClient(stabilityAiApi);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -161,7 +161,6 @@ spring:
|
|||||||
project-id: 1 # TODO 芋艿:缺配置
|
project-id: 1 # TODO 芋艿:缺配置
|
||||||
location: 2
|
location: 2
|
||||||
|
|
||||||
|
|
||||||
yudao.ai:
|
yudao.ai:
|
||||||
yiyan:
|
yiyan:
|
||||||
enable: true
|
enable: true
|
||||||
@ -193,11 +192,6 @@ yudao.ai:
|
|||||||
topP: 0.8
|
topP: 0.8
|
||||||
topK: 0
|
topK: 0
|
||||||
api-key: sk-Zsd81gZYg7
|
api-key: sk-Zsd81gZYg7
|
||||||
openAiImage:
|
|
||||||
enable: true
|
|
||||||
api-key: ${OPEN_AI_KEY}
|
|
||||||
model: dall_e_2
|
|
||||||
style: vivid
|
|
||||||
midjourney:
|
midjourney:
|
||||||
enable: true
|
enable: true
|
||||||
token: MTE4MjE3MjY2MjkxNTY3ODIzOA.GEV1SG.c49F8lZoGCUHwsj8O0UdodmM6nyQHvuD2fXflw
|
token: MTE4MjE3MjY2MjkxNTY3ODIzOA.GEV1SG.c49F8lZoGCUHwsj8O0UdodmM6nyQHvuD2fXflw
|
||||||
@ -206,6 +200,7 @@ yudao.ai:
|
|||||||
suno:
|
suno:
|
||||||
enable: true
|
enable: true
|
||||||
token: 16b4356581984d538652354b60d69ff0
|
token: 16b4356581984d538652354b60d69ff0
|
||||||
|
|
||||||
--- #################### 芋道相关配置 ####################
|
--- #################### 芋道相关配置 ####################
|
||||||
|
|
||||||
yudao:
|
yudao:
|
||||||
|
Loading…
Reference in New Issue
Block a user