【代码优化】AI:完善 TongYiChatModelTests 单测,方便大家快速体验

This commit is contained in:
YunaiV 2024-07-06 15:45:18 +08:00
parent 0139317ac4
commit 4f11d00cfd
5 changed files with 97 additions and 126 deletions

View File

@ -14,8 +14,8 @@ public enum AiPlatformEnum {
// ========== 国内平台 ========== // ========== 国内平台 ==========
TONG_YI("TongYi", "通义千问"), // 阿里
YI_YAN("YiYan", "文心一言"), // 百度 YI_YAN("YiYan", "文心一言"), // 百度
QIAN_WEN("QianWen", "千问"), // 阿里
DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek
XING_HUO("XingHuo", "星火"), // 讯飞 XING_HUO("XingHuo", "星火"), // 讯飞

View File

@ -55,12 +55,12 @@ public class AiModelFactoryImpl implements AiModelFactory {
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> { return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case TONG_YI:
return buildTongYiChatModel(apiKey);
case YI_YAN: case YI_YAN:
return buildYiYanChatClient(apiKey); return buildYiYanChatClient(apiKey);
case XING_HUO: case XING_HUO:
return buildXingHuoChatClient(apiKey); return buildXingHuoChatClient(apiKey);
case QIAN_WEN:
return buildQianWenChatClient(apiKey);
case DEEP_SEEK: case DEEP_SEEK:
return buildDeepSeekChatClient(apiKey); return buildDeepSeekChatClient(apiKey);
case OPENAI: case OPENAI:
@ -77,16 +77,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ChatModel getDefaultChatModel(AiPlatformEnum platform) { public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OLLAMA: case TONG_YI:
return SpringUtil.getBean(OllamaChatModel.class); return SpringUtil.getBean(TongYiChatModel.class);
case YI_YAN: case YI_YAN:
return SpringUtil.getBean(QianFanChatModel.class); return SpringUtil.getBean(QianFanChatModel.class);
case XING_HUO: case XING_HUO:
return SpringUtil.getBean(XingHuoChatClient.class); return SpringUtil.getBean(XingHuoChatClient.class);
case QIAN_WEN:
return SpringUtil.getBean(TongYiChatModel.class);
case OPENAI: case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class); return SpringUtil.getBean(OpenAiChatModel.class);
case OLLAMA:
return SpringUtil.getBean(OllamaChatModel.class);
default: default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
} }
@ -142,6 +142,20 @@ public class AiModelFactoryImpl implements AiModelFactory {
// ========== 各种创建 spring-ai 客户端的方法 ========== // ========== 各种创建 spring-ai 客户端的方法 ==========
/**
* 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
*/
private static TongYiChatModel buildTongYiChatModel(String key) {
com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
// TODO @芋艿貌似 apiKey 是全局唯一的得测试下
// TODO @芋艿貌似阿里云不是增量返回的
// issue 进行跟进中 https://github.com/alibaba/spring-cloud-alibaba/issues/3790
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
connectionProperties.setApiKey(key);
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
}
/** /**
* 可参考 {@link OpenAiAutoConfiguration} * 可参考 {@link OpenAiAutoConfiguration}
*/ */
@ -196,19 +210,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new DeepSeekChatClient(apiKey); return new DeepSeekChatClient(apiKey);
} }
/**
* 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
*/
private static TongYiChatModel buildQianWenChatClient(String key) {
com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
// TODO @xin貌似 apiKey 是全局唯一的得测试下
// TODO @xin貌似阿里云不是增量返回的
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
connectionProperties.setApiKey(key);
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
}
private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) { private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url); StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);

View File

@ -32,7 +32,7 @@ public class AiUtils {
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO: case XING_HUO:
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
case QIAN_WEN: case TONG_YI:
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
case DEEP_SEEK: case DEEP_SEEK:
return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();

View File

@ -1,105 +0,0 @@
//package cn.iocoder.yudao.framework.ai.chat;
//
//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
//import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
//import com.alibaba.dashscope.aigc.generation.GenerationResult;
//import com.alibaba.dashscope.aigc.generation.models.QwenParam;
//import com.alibaba.dashscope.common.Message;
//import com.alibaba.dashscope.common.MessageManager;
//import com.alibaba.dashscope.common.Role;
//import com.alibaba.dashscope.exception.InputRequiredException;
//import com.alibaba.dashscope.exception.NoApiKeyException;
//import org.junit.Before;
//import org.junit.Test;
//import org.springframework.ai.chat.messages.SystemMessage;
//import org.springframework.ai.chat.messages.UserMessage;
//import org.springframework.ai.chat.model.ChatResponse;
//import org.springframework.ai.chat.prompt.Prompt;
//import reactor.core.publisher.Flux;
//
//import java.util.ArrayList;
//import java.util.List;
//import java.util.Scanner;
//import java.util.function.Consumer;
//
//// TODO 芋艿整理单测
///**
// * author: fansili
// * time: 2024/3/13 21:37
// */
//public class QianWenChatClientTests {
//
// private QianWenChatClient qianWenChatClient;
//
// @Before
// public void setup() {
// QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT);
// QianWenOptions qianWenOptions = new QianWenOptions();
// qianWenOptions.setTopP(0.8F);
//// qianWenOptions.setTopK(3); TODO 芋艿临时处理
//// qianWenOptions.setTemperature(0.6F); TODO 芋艿临时处理
// qianWenChatClient = new QianWenChatClient(
// qianWenApi,
// qianWenOptions
// );
// }
//
// @Test
// public void callTest() {
// List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
// messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// ChatResponse call = qianWenChatClient.call(new Prompt(messages));
// System.err.println(call.getResult());
// }
//
// @Test
// public void streamTest() {
// List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages));
// flux.subscribe(new Consumer<ChatResponse>() {
// @Override
// public void accept(ChatResponse chatResponse) {
// System.err.print(chatResponse.getResult().getOutput().getContent());
// }
// });
//
// // 阻止退出
// Scanner scanner = new Scanner(System.in);
// scanner.nextLine();
// }
//
// @Test
// public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException {
// com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation();
// MessageManager msgManager = new MessageManager(10);
// Message systemMsg =
// Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
// Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build();
// msgManager.add(systemMsg);
// msgManager.add(userMsg);
// QwenParam param =
// QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get())
// .resultFormat(QwenParam.ResultFormat.MESSAGE)
// .topP(0.8)
// /* set the random seed, optional, default to 1234 if not set */
// .seed(100)
// .apiKey("sk-Zsd81gZYg7")
// .build();
// GenerationResult result = gen.call(param);
// System.out.println(result);
// System.out.println("-----------------");
// System.out.println("-----------------");
// msgManager.add(result);
// param.setPrompt("能否缩短一些,只讲三点");
// param.setMessages(msgManager.get());
// result = gen.call(param);
// System.out.println(result);
// }
//}

View File

@ -0,0 +1,75 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.hutool.core.util.ReflectUtil;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.utils.Constants;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link TongYiChatModel} 集成测试类
*
* @author fansili
*/
public class TongYiChatModelTests {
private final Generation generation = new Generation();
private final TongYiChatModel chatModel = new TongYiChatModel(generation,
TongYiChatOptions.builder().withModel("qwen1.5-72b-chat").build());
static {
Constants.apiKey = "sk-Zsd81gZYg7";
}
@BeforeEach
public void before() {
// 防止 TongYiChatModel 调用空指针
ReflectUtil.setFieldValue(chatModel, "msgManager", new MessageManager());
}
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}