【代码优化】AI:完善 LlamaChatModelTests 单测,方便大家快速体验

This commit is contained in:
YunaiV 2024-07-06 14:47:56 +08:00
parent 1feff2b12b
commit 0139317ac4
2 changed files with 66 additions and 3 deletions

View File

@ -55,8 +55,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> { return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OLLAMA:
return buildOllamaChatClient(url);
case YI_YAN: case YI_YAN:
return buildYiYanChatClient(apiKey); return buildYiYanChatClient(apiKey);
case XING_HUO: case XING_HUO:
@ -67,6 +65,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildDeepSeekChatClient(apiKey); return buildDeepSeekChatClient(apiKey);
case OPENAI: case OPENAI:
return buildOpenAiChatModel(apiKey, url); return buildOpenAiChatModel(apiKey, url);
case OLLAMA:
return buildOllamaChatModel(url);
default: default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
} }
@ -163,7 +163,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
/** /**
* 可参考 {@link OllamaAutoConfiguration} * 可参考 {@link OllamaAutoConfiguration}
*/ */
private static OllamaChatModel buildOllamaChatClient(String url) { private static OllamaChatModel buildOllamaChatModel(String url) {
OllamaApi ollamaApi = new OllamaApi(url); OllamaApi ollamaApi = new OllamaApi(url);
return new OllamaChatModel(ollamaApi); return new OllamaChatModel(ollamaApi);
} }

View File

@ -0,0 +1,63 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link OllamaChatModel} 集成测试
*
* @author 芋道源码
*/
public class LlamaChatModelTests {
private final OllamaApi ollamaApi = new OllamaApi(
"http://127.0.0.1:11434");
private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}