From 4f11d00cfd5d344c8b06c6f733a09b04c39b7016 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sat, 6 Jul 2024 15:45:18 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E5=AE=8C=E5=96=84=20TongYiChatModelTests?= =?UTF-8?q?=20=E5=8D=95=E6=B5=8B=EF=BC=8C=E6=96=B9=E4=BE=BF=E5=A4=A7?= =?UTF-8?q?=E5=AE=B6=E5=BF=AB=E9=80=9F=E4=BD=93=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/core/enums/AiPlatformEnum.java | 2 +- .../ai/core/factory/AiModelFactoryImpl.java | 39 +++---- .../yudao/framework/ai/core/util/AiUtils.java | 2 +- .../ai/chat/QianWenChatClientTests.java | 105 ------------------ .../ai/chat/TongYiChatModelTests.java | 75 +++++++++++++ 5 files changed, 97 insertions(+), 126 deletions(-) delete mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/TongYiChatModelTests.java diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java index 044630add..4aeaee3d9 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java @@ -14,8 +14,8 @@ public enum AiPlatformEnum { // ========== 国内平台 ========== + TONG_YI("TongYi", "通义千问"), // 阿里 YI_YAN("YiYan", "文心一言"), // 百度 - QIAN_WEN("QianWen", "千问"), // 阿里 DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek XING_HUO("XingHuo", "星火"), // 讯飞 diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index b95982ed3..087a6b727 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -55,12 +55,12 @@ public class AiModelFactoryImpl implements AiModelFactory { return Singleton.get(cacheKey, (Func0) () -> { //noinspection EnhancedSwitchMigration switch (platform) { + case TONG_YI: + return buildTongYiChatModel(apiKey); case YI_YAN: return buildYiYanChatClient(apiKey); case XING_HUO: return buildXingHuoChatClient(apiKey); - case QIAN_WEN: - return buildQianWenChatClient(apiKey); case DEEP_SEEK: return buildDeepSeekChatClient(apiKey); case OPENAI: @@ -77,16 +77,16 @@ public class AiModelFactoryImpl implements AiModelFactory { public ChatModel getDefaultChatModel(AiPlatformEnum platform) { //noinspection EnhancedSwitchMigration switch (platform) { - case OLLAMA: - return SpringUtil.getBean(OllamaChatModel.class); + case TONG_YI: + return SpringUtil.getBean(TongYiChatModel.class); case YI_YAN: return SpringUtil.getBean(QianFanChatModel.class); case XING_HUO: return SpringUtil.getBean(XingHuoChatClient.class); - case QIAN_WEN: - return SpringUtil.getBean(TongYiChatModel.class); case OPENAI: return SpringUtil.getBean(OpenAiChatModel.class); + case OLLAMA: + return SpringUtil.getBean(OllamaChatModel.class); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } @@ -142,6 +142,20 @@ public class AiModelFactoryImpl implements AiModelFactory { // ========== 各种创建 spring-ai 客户端的方法 ========== + /** + * 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)} + */ + private static TongYiChatModel buildTongYiChatModel(String key) { + com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class); + TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class); + // TODO @芋艿:貌似 apiKey 是全局唯一的???得测试下 + // TODO @芋艿:貌似阿里云不是增量返回的 + // 该 issue 进行跟进中 https://github.com/alibaba/spring-cloud-alibaba/issues/3790 + TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties(); + connectionProperties.setApiKey(key); + return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); + } + /** * 可参考 {@link OpenAiAutoConfiguration} */ @@ -196,19 +210,6 @@ public class AiModelFactoryImpl implements AiModelFactory { return new DeepSeekChatClient(apiKey); } - /** - * 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)} - */ - private static TongYiChatModel buildQianWenChatClient(String key) { - com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class); - TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class); - // TODO @xin:貌似 apiKey 是全局唯一的???得测试下 - // TODO @xin:貌似阿里云不是增量返回的 - TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties(); - connectionProperties.setApiKey(key); - return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); - } - private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) { url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java index 9b5a760f1..9e8433d9e 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java @@ -32,7 +32,7 @@ public class AiUtils { return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); case XING_HUO: return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); - case QIAN_WEN: + case TONG_YI: return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); case DEEP_SEEK: return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java deleted file mode 100644 index c5ceb50cf..000000000 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java +++ /dev/null @@ -1,105 +0,0 @@ -//package cn.iocoder.yudao.framework.ai.chat; -// -//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; -//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; -//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; -//import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi; -//import com.alibaba.dashscope.aigc.generation.GenerationResult; -//import com.alibaba.dashscope.aigc.generation.models.QwenParam; -//import com.alibaba.dashscope.common.Message; -//import com.alibaba.dashscope.common.MessageManager; -//import com.alibaba.dashscope.common.Role; -//import com.alibaba.dashscope.exception.InputRequiredException; -//import com.alibaba.dashscope.exception.NoApiKeyException; -//import org.junit.Before; -//import org.junit.Test; -//import org.springframework.ai.chat.messages.SystemMessage; -//import org.springframework.ai.chat.messages.UserMessage; -//import org.springframework.ai.chat.model.ChatResponse; -//import org.springframework.ai.chat.prompt.Prompt; -//import reactor.core.publisher.Flux; -// -//import java.util.ArrayList; -//import java.util.List; -//import java.util.Scanner; -//import java.util.function.Consumer; -// -//// TODO 芋艿:整理单测 -///** -// * author: fansili -// * time: 2024/3/13 21:37 -// */ -//public class QianWenChatClientTests { -// -// private QianWenChatClient qianWenChatClient; -// -// @Before -// public void setup() { -// QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT); -// QianWenOptions qianWenOptions = new QianWenOptions(); -// qianWenOptions.setTopP(0.8F); -//// qianWenOptions.setTopK(3); TODO 芋艿:临时处理 -//// qianWenOptions.setTemperature(0.6F); TODO 芋艿:临时处理 -// qianWenChatClient = new QianWenChatClient( -// qianWenApi, -// qianWenOptions -// ); -// } -// -// @Test -// public void callTest() { -// List messages = new ArrayList<>(); -// messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。")); -// messages.add(new UserMessage("长沙怎么样?")); -// -// ChatResponse call = qianWenChatClient.call(new Prompt(messages)); -// System.err.println(call.getResult()); -// } -// -// @Test -// public void streamTest() { -// List messages = new ArrayList<>(); -// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); -// messages.add(new UserMessage("长沙怎么样?")); -// -// Flux flux = qianWenChatClient.stream(new Prompt(messages)); -// flux.subscribe(new Consumer() { -// @Override -// public void accept(ChatResponse chatResponse) { -// System.err.print(chatResponse.getResult().getOutput().getContent()); -// } -// }); -// -// // 阻止退出 -// Scanner scanner = new Scanner(System.in); -// scanner.nextLine(); -// } -// -// @Test -// public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException { -// com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation(); -// MessageManager msgManager = new MessageManager(10); -// Message systemMsg = -// Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build(); -// Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build(); -// msgManager.add(systemMsg); -// msgManager.add(userMsg); -// QwenParam param = -// QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get()) -// .resultFormat(QwenParam.ResultFormat.MESSAGE) -// .topP(0.8) -// /* set the random seed, optional, default to 1234 if not set */ -// .seed(100) -// .apiKey("sk-Zsd81gZYg7") -// .build(); -// GenerationResult result = gen.call(param); -// System.out.println(result); -// System.out.println("-----------------"); -// System.out.println("-----------------"); -// msgManager.add(result); -// param.setPrompt("能否缩短一些,只讲三点"); -// param.setMessages(msgManager.get()); -// result = gen.call(param); -// System.out.println(result); -// } -//} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/TongYiChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/TongYiChatModelTests.java new file mode 100644 index 000000000..00bc2b900 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/TongYiChatModelTests.java @@ -0,0 +1,75 @@ +package cn.iocoder.yudao.framework.ai.chat; + +import cn.hutool.core.util.ReflectUtil; +import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel; +import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; +import com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.common.MessageManager; +import com.alibaba.dashscope.utils.Constants; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; + +/** + * {@link TongYiChatModel} 集成测试类 + * + * @author fansili + */ +public class TongYiChatModelTests { + + private final Generation generation = new Generation(); + private final TongYiChatModel chatModel = new TongYiChatModel(generation, + TongYiChatOptions.builder().withModel("qwen1.5-72b-chat").build()); + + static { + Constants.apiKey = "sk-Zsd81gZYg7"; + } + + @BeforeEach + public void before() { + // 防止 TongYiChatModel 调用空指针 + ReflectUtil.setFieldValue(chatModel, "msgManager", new MessageManager()); + } + + @Test + @Disabled + public void testCall() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + ChatResponse response = chatModel.call(new Prompt(messages)); + // 打印结果 + System.out.println(response); + System.out.println(response.getResult().getOutput()); + } + + @Test + @Disabled + public void testStream() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + Flux flux = chatModel.stream(new Prompt(messages)); + // 打印结果 + flux.doOnNext(response -> { +// System.out.println(response); + System.out.println(response.getResult().getOutput()); + }).then().block(); + } + +}