diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/image/ImageClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/image/ImageClient.java index 00bd3e176..98fc44ff6 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/image/ImageClient.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/image/ImageClient.java @@ -19,7 +19,6 @@ package cn.iocoder.yudao.framework.ai.image; import cn.iocoder.yudao.framework.ai.model.ModelClient; -@FunctionalInterface public interface ImageClient extends ModelClient { /** diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageApi.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageApi.java index bcc03f2ee..0f651ac09 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageApi.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageApi.java @@ -1,14 +1,25 @@ package cn.iocoder.yudao.framework.ai.imageopenai; +import cn.hutool.json.JSONUtil; import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageRequest; import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageResponse; import cn.iocoder.yudao.framework.ai.util.JacksonUtil; import io.netty.channel.ChannelOption; +import lombok.extern.slf4j.Slf4j; +import org.apache.http.HttpEntity; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.util.EntityUtils; import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; import reactor.netty.http.client.HttpClient; +import java.io.IOException; +import java.net.URI; import java.time.Duration; /** @@ -17,6 +28,7 @@ import java.time.Duration; * author: fansili * time: 2024/3/17 09:53 */ +@Slf4j public class OpenAiImageApi { private static final String DEFAULT_BASE_URL = "https://api.openai.com"; @@ -24,6 +36,8 @@ public class OpenAiImageApi { // 发送请求 webClient private final WebClient webClient; + private CloseableHttpClient httpclient = HttpClients.createDefault(); + public OpenAiImageApi(String apiKey) { this.apiKey = apiKey; // 创建一个HttpClient实例并设置超时 @@ -37,18 +51,40 @@ public class OpenAiImageApi { } public OpenAiImageResponse createImage(OpenAiImageRequest request) { - String res = webClient.post() - .uri(uriBuilder -> uriBuilder.path("/v1/images/generations").build()) - .header("Content-Type", "application/json") - .header("Authorization", "Bearer " + apiKey) - // 设置请求体(这里假设jsonStr是一个JSON格式的字符串) - .body(BodyInserters.fromValue(JacksonUtil.toJson(request))) - // 发送请求并获取响应体 - .retrieve() - // 转换响应体为String类型 - .bodyToMono(String.class) - .block(); - // TODO: 2024/3/17 这里发送请求会失败! - return null; + HttpPost httpPost = new HttpPost(); + httpPost.setURI(URI.create(DEFAULT_BASE_URL.concat("/v1/images/generations"))); + httpPost.setHeader("Content-Type", "application/json"); + httpPost.setHeader("Authorization", "Bearer " + apiKey); + httpPost.setEntity(new StringEntity(JacksonUtil.toJson(request), "UTF-8")); + + CloseableHttpResponse response= null; + try { + response = httpclient.execute(httpPost); + HttpEntity entity = response.getEntity(); + String resultJson = EntityUtils.toString(entity); + log.info("openai 图片生成结果: {}", resultJson); + return JSONUtil.toBean(resultJson, OpenAiImageResponse.class); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + if (response != null) { + try { + response.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } +// String res = webClient.post() +// .uri(uriBuilder -> uriBuilder.path("/v1/images/generations").build()) +// .header("Content-Type", "application/json") +// .header("Authorization", "Bearer " + apiKey) +// // 设置请求体(这里假设jsonStr是一个JSON格式的字符串) +// .body(BodyInserters.fromValue(JacksonUtil.toJson(request))) +// // 发送请求并获取响应体 +// .retrieve() +// // 转换响应体为String类型 +// .bodyToMono(String.class) +// .block(); } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageClient.java index a1bb59db1..a1083cd18 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageClient.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageClient.java @@ -1,15 +1,13 @@ package cn.iocoder.yudao.framework.ai.imageopenai; import cn.hutool.core.bean.BeanUtil; +import cn.hutool.core.codec.Base64; +import cn.hutool.http.HttpUtil; import cn.iocoder.yudao.framework.ai.chat.ChatException; import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException; -import cn.iocoder.yudao.framework.ai.image.ImageClient; -import cn.iocoder.yudao.framework.ai.image.ImageOptions; -import cn.iocoder.yudao.framework.ai.image.ImagePrompt; -import cn.iocoder.yudao.framework.ai.image.ImageResponse; +import cn.iocoder.yudao.framework.ai.image.*; import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageRequest; import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageResponse; -import jdk.jfr.Frequency; import lombok.extern.slf4j.Slf4j; import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryContext; @@ -74,9 +72,15 @@ public class OpenAiImageClient implements ImageClient { // 创建请求 OpenAiImageRequest request = new OpenAiImageRequest(); BeanUtil.copyProperties(openAiImageOptions, request); + request.setPrompt(imagePrompt.getInstructions().get(0).getText()); // 发送请求 OpenAiImageResponse response = openAiImageApi.createImage(request); - return null; + return new ImageResponse(response.getData().stream().map(res -> { + byte[] bytes = HttpUtil.downloadBytes(res.getUrl()); + String base64 = Base64.encode(bytes); + return new ImageGeneration(new Image(res.getUrl(), base64)); + }).toList()); }); } + } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageOptions.java index f18ccd298..c80396521 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageOptions.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageOptions.java @@ -2,6 +2,7 @@ package cn.iocoder.yudao.framework.ai.imageopenai; import cn.iocoder.yudao.framework.ai.image.ImageOptions; import lombok.Data; +import lombok.Getter; import lombok.experimental.Accessors; /** @@ -47,6 +48,21 @@ public class OpenAiImageOptions implements ImageOptions { // 代表您的终端用户的唯一标识符,有助于OpenAI监控并检测滥用行为。了解更多信息请参考官方文档。 private String endUserId; + @Getter + public enum ResponseFormatEnum { + + URL("url"), + BASE64("b64_json"), + + ; + + ResponseFormatEnum(String value) { + this.value = value; + } + + private String value; + } + // // 适配 spring ai diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/api/OpenAiImageResponse.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/api/OpenAiImageResponse.java index 04de1494d..1f4ab6152 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/api/OpenAiImageResponse.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/api/OpenAiImageResponse.java @@ -23,6 +23,7 @@ public class OpenAiImageResponse { public static class Item { private String url; + private String b64_json; } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java index 83d556930..8a6fb5f23 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java @@ -6,6 +6,14 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; import org.junit.Before; import org.junit.Test; +import javax.imageio.ImageIO; +import javax.swing.*; +import java.awt.image.BufferedImage; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.Base64; +import java.util.Scanner; + /** * author: fansili * time: 2024/3/17 10:40 @@ -20,12 +28,40 @@ public class OpenAiImageClientTests { // 初始化 openAiImageClient this.openAiImageClient = new OpenAiImageClient( new OpenAiImageApi(""), - new OpenAiImageOptions() + new OpenAiImageOptions().setResponseFormat(OpenAiImageOptions.ResponseFormatEnum.URL.getValue()) ); } @Test public void callTest() { - openAiImageClient.call(new ImagePrompt("我和我的小狗,一起在北极和企鹅玩排球。")); + ImageResponse call = openAiImageClient.call(new ImagePrompt("我和我的小狗,一起在北极和企鹅玩排球。")); + System.err.println("url: " + call.getResult().getOutput().getUrl()); + System.err.println("base64: " + call.getResult().getOutput().getB64Json()); + + String base64String = call.getResult().getOutput().getB64Json(); + ImageIcon imageIcon = new ImageIcon(decodeBase64ToImage(base64String)); + JLabel label = new JLabel(imageIcon); + + JFrame frame = new JFrame("Base64 Image Display"); + frame.getContentPane().add(label); + frame.pack(); + frame.setVisible(true); + + // 阻止退出 + Scanner scanner = new Scanner(System.in); + scanner.nextLine(); + } + + + // 将Base64解码为BufferedImage + private static BufferedImage decodeBase64ToImage(String base64String) { + try { + byte[] decodedBytes = Base64.getDecoder().decode(base64String); + ByteArrayInputStream bis = new ByteArrayInputStream(decodedBytes); + return ImageIO.read(bis); + } catch (IOException e) { + System.out.println("Error decoding the base64 image: " + e.getMessage()); + return null; + } } }