diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index 087a6b727..da4c5e4b2 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -58,7 +58,7 @@ public class AiModelFactoryImpl implements AiModelFactory { case TONG_YI: return buildTongYiChatModel(apiKey); case YI_YAN: - return buildYiYanChatClient(apiKey); + return buildYiYanChatModel(apiKey); case XING_HUO: return buildXingHuoChatClient(apiKey); case DEEP_SEEK: @@ -156,6 +156,18 @@ public class AiModelFactoryImpl implements AiModelFactory { return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); } + /** + * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} + */ + private static QianFanChatModel buildYiYanChatModel(String key) { + List keys = StrUtil.split(key, '|'); + Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式"); + String appKey = keys.get(0); + String secretKey = keys.get(1); + QianFanApi qianFanApi = new QianFanApi(appKey, secretKey); + return new QianFanChatModel(qianFanApi); + } + /** * 可参考 {@link OpenAiAutoConfiguration} */ @@ -182,19 +194,6 @@ public class AiModelFactoryImpl implements AiModelFactory { return new OllamaChatModel(ollamaApi); } - /** - * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} - */ - private static QianFanChatModel buildYiYanChatClient(String key) { - // TODO @xin:貌似目前设置,request 势必会报错;看看能不能有办法,参考 buildQianWenChatClient,调用 QianFanAutoConfiguration#qianFanChatModel初始化,当然 key 要用自己的哈 - List keys = StrUtil.split(key, '|'); - Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式"); - String appKey = keys.get(0); - String secretKey = keys.get(1); - QianFanApi qianFanApi = new QianFanApi(appKey, secretKey); - return new QianFanChatModel(qianFanApi); - } - /** * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)} */ diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java index 9e8433d9e..b6c55c95f 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java @@ -27,9 +27,7 @@ public class AiUtils { case OLLAMA: return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); case YI_YAN: - // TODO @xin:貌似 model 只要一设置,就报错;可以排查下 -// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); case XING_HUO: return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); case TONG_YI: diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java index e0e708105..14031a195 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java @@ -1,74 +1,61 @@ package cn.iocoder.yudao.framework.ai.chat; -//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient; -//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions; -//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi; -//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel; -//import org.junit.Before; -//import org.junit.Test; -//import org.springframework.ai.chat.messages.Message; -//import org.springframework.ai.chat.messages.SystemMessage; -//import org.springframework.ai.chat.messages.UserMessage; -//import org.springframework.ai.chat.model.ChatResponse; -//import org.springframework.ai.chat.prompt.Prompt; -//import reactor.core.publisher.Flux; -// -//import java.util.ArrayList; -//import java.util.List; -//import java.util.Scanner; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.qianfan.QianFanChatModel; +import org.springframework.ai.qianfan.QianFanChatOptions; +import org.springframework.ai.qianfan.api.QianFanApi; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; -// TODO 芋艿:整理单测 /** - * chat 文心一言 - *

- * author: fansili - * time: 2024/3/12 20:59 + * {@link QianFanChatModel} 的集成测试 + * + * @author fansili */ public class YiYanChatTests { -// private YiYanChatClient yiYanChatClient; -// -// @Before -// public void setup() { -// YiYanApi yiYanApi = new YiYanApi( -// "x0cuLZ7XsaTCU08vuJWO87Lg", -// "R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK", -// YiYanChatModel.ERNIE4_3_5_8K, -// 86400 -// ); -// YiYanChatOptions yiYanOptions = new YiYanChatOptions(); -// yiYanOptions.setMaxOutputTokens(2048); -// yiYanOptions.setTopP(0.6f); -// yiYanOptions.setTemperature(0.85f); -// yiYanChatClient = new YiYanChatClient( -// yiYanApi, -// yiYanOptions -// ); -// } -// -// @Test -// public void callTest() { -// -// // tip: 百度的message 有特殊规则(最后一个message为当前请求的信息,前面的message为历史对话信息) -// // tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11 -// List messages = new ArrayList<>(); -// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。")); -// messages.add(new UserMessage("长沙怎么样?")); -// -// ChatResponse call = yiYanChatClient.call(new Prompt(messages)); -// System.err.println(call.getResult()); -// } -// -// @Test -// public void streamTest() { -// List messages = new ArrayList<>(); -// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。")); -// messages.add(new UserMessage("长沙怎么样?")); -// -// Flux fluxResponse = yiYanChatClient.stream(new Prompt(messages)); -// fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent())); -// // 阻止退出 -// Scanner scanner = new Scanner(System.in); -// scanner.nextLine(); -// } + private final QianFanApi qianFanApi = new QianFanApi( + "qS8k8dYr2nXunagK4SSU8Xjj", + "pHGbx51ql2f0hOyabQvSZezahVC3hh3e"); + private final QianFanChatModel chatModel = new QianFanChatModel(qianFanApi, + QianFanChatOptions.builder().withModel(QianFanApi.ChatModel.ERNIE_Tiny_8K.getValue()).build() + ); + + @Test + @Disabled + public void testCall() { + // 准备参数 + List messages = new ArrayList<>(); + // TODO @芋艿:文心一言,只要带上 system message 就报错,已经各种测试,很莫名! +// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + ChatResponse response = chatModel.call(new Prompt(messages)); + // 打印结果 + System.out.println(response); + } + + @Test + @Disabled + public void testStream() { + // 准备参数 + List messages = new ArrayList<>(); + // TODO @芋艿:文心一言,只要带上 system message 就报错,已经各种测试,很莫名! +// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + Flux flux = chatModel.stream(new Prompt(messages)); + // 打印结果 + flux.doOnNext(System.out::println).then().block(); + } + }