【代码优化】AI:完善 YiYanChatTests 单测,方便大家快速体验

This commit is contained in:
YunaiV 2024-07-06 16:45:49 +08:00
parent 4f11d00cfd
commit 4daff93313
3 changed files with 67 additions and 83 deletions

View File

@ -58,7 +58,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
case TONG_YI: case TONG_YI:
return buildTongYiChatModel(apiKey); return buildTongYiChatModel(apiKey);
case YI_YAN: case YI_YAN:
return buildYiYanChatClient(apiKey); return buildYiYanChatModel(apiKey);
case XING_HUO: case XING_HUO:
return buildXingHuoChatClient(apiKey); return buildXingHuoChatClient(apiKey);
case DEEP_SEEK: case DEEP_SEEK:
@ -156,6 +156,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
} }
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private static QianFanChatModel buildYiYanChatModel(String key) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
return new QianFanChatModel(qianFanApi);
}
/** /**
* 可参考 {@link OpenAiAutoConfiguration} * 可参考 {@link OpenAiAutoConfiguration}
*/ */
@ -182,19 +194,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new OllamaChatModel(ollamaApi); return new OllamaChatModel(ollamaApi);
} }
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private static QianFanChatModel buildYiYanChatClient(String key) {
// TODO @xin貌似目前设置request 势必会报错看看能不能有办法参考 buildQianWenChatClient调用 QianFanAutoConfiguration#qianFanChatModel初始化当然 key 要用自己的哈
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
return new QianFanChatModel(qianFanApi);
}
/** /**
* 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)} * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
*/ */

View File

@ -27,9 +27,7 @@ public class AiUtils {
case OLLAMA: case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN: case YI_YAN:
// TODO @xin貌似 model 只要一设置就报错可以排查下 return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO: case XING_HUO:
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
case TONG_YI: case TONG_YI:

View File

@ -1,74 +1,61 @@
package cn.iocoder.yudao.framework.ai.chat; package cn.iocoder.yudao.framework.ai.chat;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient; import org.junit.jupiter.api.Disabled;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions; import org.junit.jupiter.api.Test;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi; import org.springframework.ai.chat.messages.Message;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel; import org.springframework.ai.chat.messages.UserMessage;
//import org.junit.Before; import org.springframework.ai.chat.model.ChatResponse;
//import org.junit.Test; import org.springframework.ai.chat.prompt.Prompt;
//import org.springframework.ai.chat.messages.Message; import org.springframework.ai.qianfan.QianFanChatModel;
//import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.qianfan.QianFanChatOptions;
//import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.qianfan.api.QianFanApi;
//import org.springframework.ai.chat.model.ChatResponse; import reactor.core.publisher.Flux;
//import org.springframework.ai.chat.prompt.Prompt;
//import reactor.core.publisher.Flux; import java.util.ArrayList;
// import java.util.List;
//import java.util.ArrayList;
//import java.util.List;
//import java.util.Scanner;
// TODO 芋艿整理单测
/** /**
* chat 文心一言 * {@link QianFanChatModel} 的集成测试
* <p> *
* author: fansili * @author fansili
* time: 2024/3/12 20:59
*/ */
public class YiYanChatTests { public class YiYanChatTests {
// private YiYanChatClient yiYanChatClient; private final QianFanApi qianFanApi = new QianFanApi(
// "qS8k8dYr2nXunagK4SSU8Xjj",
// @Before "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
// public void setup() { private final QianFanChatModel chatModel = new QianFanChatModel(qianFanApi,
// YiYanApi yiYanApi = new YiYanApi( QianFanChatOptions.builder().withModel(QianFanApi.ChatModel.ERNIE_Tiny_8K.getValue()).build()
// "x0cuLZ7XsaTCU08vuJWO87Lg", );
// "R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK",
// YiYanChatModel.ERNIE4_3_5_8K, @Test
// 86400 @Disabled
// ); public void testCall() {
// YiYanChatOptions yiYanOptions = new YiYanChatOptions(); // 准备参数
// yiYanOptions.setMaxOutputTokens(2048); List<Message> messages = new ArrayList<>();
// yiYanOptions.setTopP(0.6f); // TODO @芋艿文心一言只要带上 system message 就报错已经各种测试很莫名
// yiYanOptions.setTemperature(0.85f); // messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
// yiYanChatClient = new YiYanChatClient( messages.add(new UserMessage("1 + 1 = "));
// yiYanApi,
// yiYanOptions // 调用
// ); ChatResponse response = chatModel.call(new Prompt(messages));
// } // 打印结果
// System.out.println(response);
// @Test }
// public void callTest() {
// @Test
// // tip: 百度的message 有特殊规则(最后一个message为当前请求的信息前面的message为历史对话信息) @Disabled
// // tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11 public void testStream() {
// List<Message> messages = new ArrayList<>(); // 准备参数
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。")); List<Message> messages = new ArrayList<>();
// messages.add(new UserMessage("长沙怎么样?")); // TODO @芋艿文心一言只要带上 system message 就报错已经各种测试很莫名
// // messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
// ChatResponse call = yiYanChatClient.call(new Prompt(messages)); messages.add(new UserMessage("1 + 1 = "));
// System.err.println(call.getResult());
// } // 调用
// Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// @Test // 打印结果
// public void streamTest() { flux.doOnNext(System.out::println).then().block();
// List<Message> messages = new ArrayList<>(); }
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
// fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
// // 阻止退出
// Scanner scanner = new Scanner(System.in);
// scanner.nextLine();
// }
} }