【优化】AI:依赖从 io.springboot.ai 调整为 org.springframework.ai

This commit is contained in:
YunaiV 2024-06-29 18:13:36 +08:00
parent 6225e18f70
commit 7dfa7a1573
16 changed files with 124 additions and 90 deletions

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.enums.model;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
// TODO @芋艿可以考虑清理掉
/** /**
* ai 模型 * ai 模型
* *

View File

@ -26,9 +26,9 @@ import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.OllamaOptions;
@ -44,8 +44,8 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST;
/** /**
* AI 聊天消息 Service 实现类 * AI 聊天消息 Service 实现类
@ -117,7 +117,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
// 1.3 获取用户头像角色头像 // 1.3 获取用户头像角色头像
AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null; AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null;
@ -164,7 +164,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
// 1.2 history message 历史消息 // 1.2 history message 历史消息
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO); List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent()))); contextMessages.forEach(message -> {
// TODO @芋艿看看有没优化空间
if (MessageType.USER.getValue().equals(message.getType())) {
chatMessages.add(new UserMessage(message.getContent()));
} else {
chatMessages.add(new AssistantMessage(message.getContent()));
}
});
// 1.3 user message 新发送消息 // 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); chatMessages.add(new UserMessage(sendReqVO.getContent()));

View File

@ -25,7 +25,7 @@ import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.infra.api.file.FileApi; import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.image.ImageClient; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
@ -97,7 +97,7 @@ public class AiImageServiceImpl implements AiImageService {
// 1.1 构建请求 // 1.1 构建请求
ImageOptions request = buildImageOptions(req); ImageOptions request = buildImageOptions(req);
// 1.2 执行请求 // 1.2 执行请求
ImageClient imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform())); ImageModel imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request)); ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request));
// 2. 上传到文件服务 // 2. 上传到文件服务

View File

@ -8,8 +8,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageClient; import org.springframework.ai.image.ImageModel;
import java.util.List; import java.util.List;
@ -81,7 +81,7 @@ public interface AiApiKeyService {
* @param id 编号 * @param id 编号
* @return StreamingChatClient 对象 * @return StreamingChatClient 对象
*/ */
StreamingChatClient getStreamingChatClient(Long id); StreamingChatModel getStreamingChatClient(Long id);
/** /**
* 获得 ImageClient 对象 * 获得 ImageClient 对象
@ -91,7 +91,7 @@ public interface AiApiKeyService {
* @param platform 平台 * @param platform 平台
* @return ImageClient 对象 * @return ImageClient 对象
*/ */
ImageClient getImageClient(AiPlatformEnum platform); ImageModel getImageClient(AiPlatformEnum platform);
/** /**
* 获得 MidjourneyApi 对象 * 获得 MidjourneyApi 对象

View File

@ -12,8 +12,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageClient; import org.springframework.ai.image.ImageModel;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
@ -98,14 +98,14 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
// ========== spring-ai 集成 ========== // ========== spring-ai 集成 ==========
@Override @Override
public StreamingChatClient getStreamingChatClient(Long id) { public StreamingChatModel getStreamingChatClient(Long id) {
AiApiKeyDO apiKey = validateApiKey(id); AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
} }
@Override @Override
public ImageClient getImageClient(AiPlatformEnum platform) { public ImageModel getImageClient(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus()); AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) { if (apiKey == null) {
return null; return null;

View File

@ -2,33 +2,38 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" <project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent> <parent>
<groupId>cn.iocoder.boot</groupId> <groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-module-ai</artifactId> <artifactId>yudao-module-ai</artifactId>
<version>${revision}</version> <version>${revision}</version>
</parent> </parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>yudao-spring-boot-starter-ai</artifactId> <artifactId>yudao-spring-boot-starter-ai</artifactId>
<name>${project.artifactId}</name>
<description>AI 大模型拓展,接入国内外大模型</description>
<properties>
<spring-ai.version>1.0.0-M1</spring-ai.version>
</properties>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>io.springboot.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId> <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
<version>1.0.3</version> <version>${spring-ai.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>io.springboot.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId> <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<version>1.0.3</version> <version>${spring-ai.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>io.springboot.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-stability-ai</artifactId> <artifactId>spring-ai-stability-ai-spring-boot-starter</artifactId>
<version>1.0.3</version> <version>${spring-ai.version}</version>
</dependency> </dependency>
<!-- <dependency>--> <!-- <dependency>-->
<!-- <groupId>io.springboot.ai</groupId>--> <!-- <groupId>org.springframework.ai</groupId>-->
<!-- <artifactId>spring-ai-vertex-ai-gemini</artifactId>--> <!-- <artifactId>spring-ai-vertex-ai-gemini</artifactId>-->
<!-- <version>1.0.3</version>--> <!-- <version>1.0.3</version>-->
<!-- </dependency>--> <!-- </dependency>-->

View File

@ -3,8 +3,8 @@ package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageClient; import org.springframework.ai.image.ImageModel;
/** /**
* AI 客户端工厂的接口类 * AI 客户端工厂的接口类
@ -23,7 +23,7 @@ public interface AiClientFactory {
* @param url API URL * @param url API URL
* @return StreamingChatClient 对象 * @return StreamingChatClient 对象
*/ */
StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url); StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url);
/** /**
* 基于默认配置获得 StreamingChatClient 对象 * 基于默认配置获得 StreamingChatClient 对象
@ -33,7 +33,7 @@ public interface AiClientFactory {
* @param platform 平台 * @param platform 平台
* @return StreamingChatClient 对象 * @return StreamingChatClient 对象
*/ */
StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform); StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform);
/** /**
* 基于默认配置获得 ImageClient 对象 * 基于默认配置获得 ImageClient 对象
@ -43,7 +43,7 @@ public interface AiClientFactory {
* @param platform 平台 * @param platform 平台
* @return ImageClient 对象 * @return ImageClient 对象
*/ */
ImageClient getDefaultImageClient(AiPlatformEnum platform); ImageModel getDefaultImageClient(AiPlatformEnum platform);
/** /**
* 基于指定配置获得 ImageClient 对象 * 基于指定配置获得 ImageClient 对象
@ -55,7 +55,7 @@ public interface AiClientFactory {
* @param url API URL * @param url API URL
* @return ImageClient 对象 * @return ImageClient 对象
*/ */
ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url); ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
/** /**
* 基于指定配置获得 MidjourneyApi 对象 * 基于指定配置获得 MidjourneyApi 对象

View File

@ -20,16 +20,16 @@ import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageClient; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiImageClient; import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.ApiUtils; import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageClient; import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi; import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.web.client.RestClient; import org.springframework.web.client.RestClient;
@ -43,9 +43,9 @@ import java.util.List;
public class AiClientFactoryImpl implements AiClientFactory { public class AiClientFactoryImpl implements AiClientFactory {
@Override @Override
public StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) { public StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(StreamingChatClient.class, platform, apiKey, url); String cacheKey = buildClientCacheKey(StreamingChatModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<StreamingChatClient>) () -> { return Singleton.get(cacheKey, (Func0<StreamingChatModel>) () -> {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI: case OPENAI:
@ -67,13 +67,13 @@ public class AiClientFactoryImpl implements AiClientFactory {
} }
@Override @Override
public StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform) { public StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI: case OPENAI:
return SpringUtil.getBean(OpenAiChatClient.class); return SpringUtil.getBean(OpenAiChatModel.class);
case OLLAMA: case OLLAMA:
return SpringUtil.getBean(OllamaChatClient.class); return SpringUtil.getBean(OllamaChatModel.class);
case YI_YAN: case YI_YAN:
return SpringUtil.getBean(YiYanChatClient.class); return SpringUtil.getBean(YiYanChatClient.class);
case XING_HUO: case XING_HUO:
@ -86,20 +86,20 @@ public class AiClientFactoryImpl implements AiClientFactory {
} }
@Override @Override
public ImageClient getDefaultImageClient(AiPlatformEnum platform) { public ImageModel getDefaultImageClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI: case OPENAI:
return SpringUtil.getBean(OpenAiImageClient.class); return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION: case STABLE_DIFFUSION:
return SpringUtil.getBean(StabilityAiImageClient.class); return SpringUtil.getBean(StabilityAiImageModel.class);
default: default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
} }
} }
@Override @Override
public ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) { public ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI: case OPENAI:
@ -138,18 +138,18 @@ public class AiClientFactoryImpl implements AiClientFactory {
/** /**
* 可参考 {@link OpenAiAutoConfiguration} * 可参考 {@link OpenAiAutoConfiguration}
*/ */
private static OpenAiChatClient buildOpenAiChatClient(String openAiToken, String url) { private static OpenAiChatModel buildOpenAiChatClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken); OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
return new OpenAiChatClient(openAiApi); return new OpenAiChatModel(openAiApi);
} }
/** /**
* 可参考 {@link OllamaAutoConfiguration} * 可参考 {@link OllamaAutoConfiguration}
*/ */
private static OllamaChatClient buildOllamaChatClient(String url) { private static OllamaChatModel buildOllamaChatClient(String url) {
OllamaApi ollamaApi = new OllamaApi(url); OllamaApi ollamaApi = new OllamaApi(url);
return new OllamaChatClient(ollamaApi); return new OllamaChatModel(ollamaApi);
} }
/** /**
@ -192,16 +192,16 @@ public class AiClientFactoryImpl implements AiClientFactory {
// return new VertexAiGeminiChatClient(vertexApi); // return new VertexAiGeminiChatClient(vertexApi);
// } // }
private ImageClient buildOpenAiImageClient(String openAiToken, String url) { private OpenAiImageModel buildOpenAiImageClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder()); OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
return new OpenAiImageClient(openAiApi); return new OpenAiImageModel(openAiApi);
} }
private ImageClient buildStabilityAiImageClient(String apiKey, String url) { private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url); StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
return new StabilityAiImageClient(stabilityAiApi); return new StabilityAiImageModel(stabilityAiApi);
} }
} }

View File

@ -1,11 +1,7 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi; package cn.iocoder.yudao.framework.ai.core.model.tongyi;
import cn.hutool.core.util.NumberUtil;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException; import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi; import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import org.springframework.ai.chat.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException; import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import com.alibaba.dashscope.aigc.generation.GenerationResult; import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam; import com.alibaba.dashscope.aigc.generation.models.QwenParam;
@ -14,6 +10,12 @@ import com.google.common.collect.Lists;
import io.reactivex.Flowable; import io.reactivex.Flowable;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext; import org.springframework.retry.RetryContext;
@ -35,7 +37,7 @@ import java.util.stream.Collectors;
* time: 2024/3/13 21:06 * time: 2024/3/13 21:06
*/ */
@Slf4j @Slf4j
public class QianWenChatClient implements ChatClient, StreamingChatClient { public class QianWenChatClient implements ChatModel, StreamingChatModel {
private QianWenApi qianWenApi; private QianWenApi qianWenApi;
@ -90,6 +92,12 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
}); });
} }
@Override
public ChatOptions getDefaultOptions() {
// TODO 芋艿需要跟进下
throw new UnsupportedOperationException();
}
private QwenParam createRequest(Prompt prompt, boolean stream) { private QwenParam createRequest(Prompt prompt, boolean stream) {
// 获取 ChatOptions // 获取 ChatOptions
QianWenOptions chatOptions = getChatOptions(prompt); QianWenOptions chatOptions = getChatOptions(prompt);

View File

@ -6,10 +6,13 @@ import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletion; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletion;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletionRequest; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletionRequest;
import org.springframework.ai.chat.*; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext; import org.springframework.retry.RetryContext;
@ -29,7 +32,7 @@ import java.util.stream.Collectors;
* time: 2024/3/11 10:19 * time: 2024/3/11 10:19
*/ */
@Slf4j @Slf4j
public class XingHuoChatClient implements ChatClient, StreamingChatClient { public class XingHuoChatClient implements ChatModel, StreamingChatModel {
private XingHuoApi xingHuoApi; private XingHuoApi xingHuoApi;
@ -64,7 +67,6 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
@Override @Override
public ChatResponse call(Prompt prompt) { public ChatResponse call(Prompt prompt) {
return this.retryTemplate.execute(ctx -> { return this.retryTemplate.execute(ctx -> {
// ctx 会有重试的信息 // ctx 会有重试的信息
// 获取 chatOptions 属性 // 获取 chatOptions 属性
@ -78,6 +80,12 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
}); });
} }
@Override
public ChatOptions getDefaultOptions() {
// TODO 芋艿需要跟进下
throw new UnsupportedOperationException();
}
@Override @Override
public Flux<ChatResponse> stream(Prompt prompt) { public Flux<ChatResponse> stream(Prompt prompt) {
// 获取 chatOptions 属性 // 获取 chatOptions 属性

View File

@ -7,12 +7,13 @@ import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionReq
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionResponse; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionResponse;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException; import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
@ -33,7 +34,7 @@ import java.util.stream.Collectors;
* @author fansili * @author fansili
*/ */
@Slf4j @Slf4j
public class YiYanChatClient implements ChatClient, StreamingChatClient { public class YiYanChatClient implements ChatModel, StreamingChatModel {
private final YiYanApi yiYanApi; private final YiYanApi yiYanApi;
@ -86,6 +87,12 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
}); });
} }
@Override
public ChatOptions getDefaultOptions() {
// TODO 芋艿需要跟进下
throw new UnsupportedOperationException();
}
@Override @Override
public Flux<ChatResponse> stream(Prompt prompt) { public Flux<ChatResponse> stream(Prompt prompt) {
YiYanChatCompletionRequest request = this.createRequest(prompt, true); YiYanChatCompletionRequest request = this.createRequest(prompt, true);
@ -99,8 +106,6 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
}); });
} }
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) { private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 文档system 是独立字段 // 参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 文档system 是独立字段
// 1.1 获取 user assistant // 1.1 获取 user assistant

View File

@ -1,9 +1,5 @@
package cn.iocoder.yudao.framework.ai.chat; package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
@ -17,6 +13,10 @@ import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException; import com.alibaba.dashscope.exception.NoApiKeyException;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.ArrayList; import java.util.ArrayList;

View File

@ -1,16 +1,16 @@
package cn.iocoder.yudao.framework.ai.chat; package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.ArrayList; import java.util.ArrayList;

View File

@ -1,16 +1,16 @@
package cn.iocoder.yudao.framework.ai.chat; package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse; import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import org.junit.Before;
import org.junit.Test;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.ArrayList; import java.util.ArrayList;

View File

@ -4,7 +4,7 @@ import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.ai.openai.OpenAiImageClient; import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.ai.openai.api.OpenAiImageApi;
import javax.imageio.ImageIO; import javax.imageio.ImageIO;
@ -23,12 +23,12 @@ import java.util.Scanner;
public class OpenAiImageClientTests { public class OpenAiImageClientTests {
private OpenAiImageClient openAiImageClient; private OpenAiImageModel openAiImageClient;
@Before @Before
public void setup() { public void setup() {
// 初始化 openAiImageClient // 初始化 openAiImageClient
this.openAiImageClient = new OpenAiImageClient( this.openAiImageClient = new OpenAiImageModel(
new OpenAiImageApi("") new OpenAiImageApi("")
// new OpenAiImageOptions().setResponseFormat(OpenAiImageOptions.ResponseFormatEnum.URL.getValue()) TODO 芋艿临时处理 // new OpenAiImageOptions().setResponseFormat(OpenAiImageOptions.ResponseFormatEnum.URL.getValue()) TODO 芋艿临时处理
); );