From 0139317ac4ee3a05be5ab9d432191460c1107ae3 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sat, 6 Jul 2024 14:47:56 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E5=AE=8C=E5=96=84=20LlamaChatModelTests=20?= =?UTF-8?q?=E5=8D=95=E6=B5=8B=EF=BC=8C=E6=96=B9=E4=BE=BF=E5=A4=A7=E5=AE=B6?= =?UTF-8?q?=E5=BF=AB=E9=80=9F=E4=BD=93=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/core/factory/AiModelFactoryImpl.java | 6 +- .../ai/chat/LlamaChatModelTests.java | 63 +++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index f561dacb5..b95982ed3 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -55,8 +55,6 @@ public class AiModelFactoryImpl implements AiModelFactory { return Singleton.get(cacheKey, (Func0) () -> { //noinspection EnhancedSwitchMigration switch (platform) { - case OLLAMA: - return buildOllamaChatClient(url); case YI_YAN: return buildYiYanChatClient(apiKey); case XING_HUO: @@ -67,6 +65,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return buildDeepSeekChatClient(apiKey); case OPENAI: return buildOpenAiChatModel(apiKey, url); + case OLLAMA: + return buildOllamaChatModel(url); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } @@ -163,7 +163,7 @@ public class AiModelFactoryImpl implements AiModelFactory { /** * 可参考 {@link OllamaAutoConfiguration} */ - private static OllamaChatModel buildOllamaChatClient(String url) { + private static OllamaChatModel buildOllamaChatModel(String url) { OllamaApi ollamaApi = new OllamaApi(url); return new OllamaChatModel(ollamaApi); } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java new file mode 100644 index 000000000..c6b99f287 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java @@ -0,0 +1,63 @@ +package cn.iocoder.yudao.framework.ai.chat; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; + +/** + * {@link OllamaChatModel} 集成测试 + * + * @author 芋道源码 + */ +public class LlamaChatModelTests { + + private final OllamaApi ollamaApi = new OllamaApi( + "http://127.0.0.1:11434"); + private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi, + OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName())); + + @Test + @Disabled + public void testCall() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + ChatResponse response = chatModel.call(new Prompt(messages)); + // 打印结果 + System.out.println(response); + System.out.println(response.getResult().getOutput()); + } + + @Test + @Disabled + public void testStream() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + Flux flux = chatModel.stream(new Prompt(messages)); + // 打印结果 + flux.doOnNext(response -> { +// System.out.println(response); + System.out.println(response.getResult().getOutput()); + }).then().block(); + } + +}