From 1feff2b12b4f3c7b5711ce980ae94492bbad44c2 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sat, 6 Jul 2024 14:11:21 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E5=AE=8C=E5=96=84=20OpenAIChatModelTests?= =?UTF-8?q?=E3=80=81OpenAiImageModelTests=20=E5=8D=95=E6=B5=8B=EF=BC=8C?= =?UTF-8?q?=E6=96=B9=E4=BE=BF=E5=A4=A7=E5=AE=B6=E5=BF=AB=E9=80=9F=E4=BD=93?= =?UTF-8?q?=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/service/model/AiApiKeyServiceImpl.java | 2 +- .../ai/core/factory/AiModelFactory.java | 2 +- .../ai/core/factory/AiModelFactoryImpl.java | 2 +- .../ai/chat/OpenAIChatModelTests.java | 64 +++++++++++++++++ .../ai/image/OpenAiImageClientTests.java | 70 ------------------- .../ai/image/OpenAiImageModelTests.java | 41 +++++++++++ 6 files changed, 108 insertions(+), 73 deletions(-) create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/OpenAIChatModelTests.java delete mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java index 6f8f06076..7404df83c 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java @@ -101,7 +101,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { public ChatModel getChatModel(Long id) { AiApiKeyDO apiKey = validateApiKey(id); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); - return modelFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); + return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl()); } @Override diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java index e1d3aba63..b6d7b3dd0 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java @@ -23,7 +23,7 @@ public interface AiModelFactory { * @param url API URL * @return ChatModel 对象 */ - ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url); + ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url); /** * 基于默认配置,获得 ChatModel 对象 diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index ca01a6611..f561dacb5 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -50,7 +50,7 @@ import java.util.List; public class AiModelFactoryImpl implements AiModelFactory { @Override - public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) { + public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) { String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url); return Singleton.get(cacheKey, (Func0) () -> { //noinspection EnhancedSwitchMigration diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/OpenAIChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/OpenAIChatModelTests.java new file mode 100644 index 000000000..77a77e364 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/OpenAIChatModelTests.java @@ -0,0 +1,64 @@ +package cn.iocoder.yudao.framework.ai.chat; + +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; + +/** + * {@link XingHuoChatClient} 集成测试 + * + * @author 芋道源码 + */ +public class OpenAIChatModelTests { + + private final OpenAiApi openAiApi = new OpenAiApi( + "https://api.holdai.top", + "sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf"); + private final OpenAiChatModel chatModel = new OpenAiChatModel(openAiApi, + OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O).build()); + + @Test + @Disabled + public void testCall() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + ChatResponse response = chatModel.call(new Prompt(messages)); + // 打印结果 + System.out.println(response); + System.out.println(response.getResult().getOutput()); + } + + @Test + @Disabled + public void testStream() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + Flux flux = chatModel.stream(new Prompt(messages)); + // 打印结果 + flux.doOnNext(response -> { +// System.out.println(response); + System.out.println(response.getResult().getOutput()); + }).then().block(); + } + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java deleted file mode 100644 index f942f25bf..000000000 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java +++ /dev/null @@ -1,70 +0,0 @@ -package cn.iocoder.yudao.framework.ai.image; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.springframework.ai.image.ImagePrompt; -import org.springframework.ai.image.ImageResponse; -import org.springframework.ai.openai.OpenAiImageModel; -import org.springframework.ai.openai.api.OpenAiImageApi; - -import javax.imageio.ImageIO; -import javax.swing.*; -import java.awt.image.BufferedImage; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.util.Base64; -import java.util.Scanner; - -// TODO 芋艿:整理单测 -/** - * author: fansili - * time: 2024/3/17 10:40 - */ -public class OpenAiImageClientTests { - - - private OpenAiImageModel openAiImageClient; - - @BeforeEach - public void setup() { - // 初始化 openAiImageClient - this.openAiImageClient = new OpenAiImageModel( - new OpenAiImageApi("") -// new OpenAiImageOptions().setResponseFormat(OpenAiImageOptions.ResponseFormatEnum.URL.getValue()) TODO 芋艿:临时处理 - ); - } - - @Test - public void callTest() { - ImageResponse call = openAiImageClient.call(new ImagePrompt("中国长城!")); - System.err.println("url: " + call.getResult().getOutput().getUrl()); - System.err.println("base64: " + call.getResult().getOutput().getB64Json()); - - String base64String = call.getResult().getOutput().getB64Json(); - ImageIcon imageIcon = new ImageIcon(decodeBase64ToImage(base64String)); - JLabel label = new JLabel(imageIcon); - - JFrame frame = new JFrame("Base64 Image Display"); - frame.getContentPane().add(label); - frame.pack(); - frame.setVisible(true); - - // 阻止退出 - Scanner scanner = new Scanner(System.in); - scanner.nextLine(); - } - - - // 将Base64解码为BufferedImage - private static BufferedImage decodeBase64ToImage(String base64String) { - try { - byte[] decodedBytes = Base64.getDecoder().decode(base64String); - ByteArrayInputStream bis = new ByteArrayInputStream(decodedBytes); - return ImageIO.read(bis); - } catch (IOException e) { - System.out.println("Error decoding the base64 image: " + e.getMessage()); - return null; - } - } - -} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java new file mode 100644 index 000000000..7b2919d1e --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java @@ -0,0 +1,41 @@ +package cn.iocoder.yudao.framework.ai.image; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.openai.OpenAiImageModel; +import org.springframework.ai.openai.OpenAiImageOptions; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.web.client.RestClient; + +/** + * {@link OpenAiImageModel} 集成测试类 + * + * @author fansili + */ +public class OpenAiImageModelTests { + + private final OpenAiImageApi imageApi = new OpenAiImageApi( + "https://api.holdai.top", + "sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf", + RestClient.builder()); + private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi); + + @Test + @Disabled + public void testCall() { + // 准备参数 + ImageOptions options = OpenAiImageOptions.builder() + .withModel(OpenAiImageApi.ImageModel.DALL_E_2.getValue()) // 这个模型比较便宜 + .withHeight(256).withWidth(256) + .build(); + ImagePrompt prompt = new ImagePrompt("中国长城!", options); + + // 方法调用 + ImageResponse response = imageClient.call(prompt); + System.out.println(response); + } + +}