From 03f124a82b5e2a083c06b39ea44be9c0373e3d36 Mon Sep 17 00:00:00 2001 From: cherishsince Date: Wed, 10 Apr 2024 20:02:06 +0800 Subject: [PATCH] =?UTF-8?q?1=E3=80=81=E8=AF=B7=E6=B1=82=E5=85=AC=E5=85=B1?= =?UTF-8?q?=E9=83=A8=E5=88=86=E6=8A=BD=E7=A6=BB=202=E3=80=81=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E4=B8=BAapi=E5=92=8Cspring=20ai=E7=BB=93=E6=9E=84?= =?UTF-8?q?=E4=BF=9D=E6=8C=81=E4=B8=80=E8=87=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api/MidjourneyInteractions.java | 83 +++++++++++++++++++ .../MidjourneyInteractionsApi.java} | 80 ++++-------------- 2 files changed, 101 insertions(+), 62 deletions(-) create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractions.java rename yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/{interactions/MidjourneyInteractions.java => api/MidjourneyInteractionsApi.java} (59%) diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractions.java new file mode 100644 index 000000000..4077912a5 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractions.java @@ -0,0 +1,83 @@ +package cn.iocoder.yudao.framework.ai.midjourney.api; + +import cn.hutool.core.util.IdUtil; +import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; +import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants; +import com.google.common.collect.Maps; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; + +import java.util.HashMap; + +// TODO @fansili:按照 spring ai 的封装习惯,这个类是不是 MidjourneyApi + +/** + * 图片生成 + * + * author: fansili + * time: 2024/4/3 17:36 + */ +@Slf4j +public abstract class MidjourneyInteractions { + + // TODO done @fansili:静态变量,放在最前面哈; + /** + * header - referer 头信息 + */ + private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s"; + /** + * mj配置文件 + */ + protected final MidjourneyConfig midjourneyConfig; + + protected MidjourneyInteractions(MidjourneyConfig midjourneyConfig) { + this.midjourneyConfig = midjourneyConfig; + } + + /** + * 获取headers - application json + * + * @return + */ + protected HttpHeaders getHeadersOfAppJson() { + // 设置header值 + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.setContentType(MediaType.APPLICATION_JSON); + httpHeaders.set("Authorization", midjourneyConfig.getToken()); + httpHeaders.set("User-Agent", midjourneyConfig.getUserAage()); + httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE); + httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId())); + return httpHeaders; + } + + /** + * 获取headers - http form data + * + * @return + */ + protected HttpHeaders getHeadersOfFormData() { + // 设置header值 + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); + httpHeaders.set("Authorization", midjourneyConfig.getToken()); + httpHeaders.set("User-Agent", midjourneyConfig.getUserAage()); + httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE); + httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId())); + return httpHeaders; + } + + /** + * 获取 - 默认参数 + * @return + */ + protected HashMap getDefaultParams() { + HashMap requestParams = Maps.newHashMap(); + // TODO @fansili:感觉参数的组装,可以搞成一个公用的方法;就是 config + 入参的感觉; + requestParams.put("guild_id", midjourneyConfig.getGuildId()); + requestParams.put("channel_id", midjourneyConfig.getChannelId()); + requestParams.put("session_id", midjourneyConfig.getSessionId()); + requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili:建议用 uuid 之类的;nextId 跨进程未必合适哈; + return requestParams; + } +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/interactions/MidjourneyInteractions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java similarity index 59% rename from yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/interactions/MidjourneyInteractions.java rename to yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java index 104e1dd2c..f71a81667 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/interactions/MidjourneyInteractions.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java @@ -1,20 +1,16 @@ -package cn.iocoder.yudao.framework.ai.midjourney.interactions; +package cn.iocoder.yudao.framework.ai.midjourney.api; -import cn.hutool.core.util.IdUtil; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; -import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants; import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil; -import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments; -import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe; -import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll; -import cn.iocoder.yudao.framework.ai.midjourney.vo.UploadAttachmentsRes; +import cn.iocoder.yudao.framework.ai.midjourney.api.req.AttachmentsReq; +import cn.iocoder.yudao.framework.ai.midjourney.api.req.DescribeReq; +import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq; +import cn.iocoder.yudao.framework.ai.midjourney.api.res.UploadAttachmentsRes; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import lombok.extern.slf4j.Slf4j; -import org.jetbrains.annotations.NotNull; import org.springframework.core.io.FileSystemResource; import org.springframework.http.*; import org.springframework.util.LinkedMultiValueMap; @@ -32,18 +28,13 @@ import java.util.HashMap; * time: 2024/4/3 17:36 */ @Slf4j -public class MidjourneyInteractions { - - // TODO done @fansili:静态变量,放在最前面哈; - private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s"; +public class MidjourneyInteractionsApi extends MidjourneyInteractions { private final String url; - private final MidjourneyConfig midjourneyConfig; private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili:优先级低:后续搞到统一的管理 - - public MidjourneyInteractions(MidjourneyConfig midjourneyConfig) { - this.midjourneyConfig = midjourneyConfig; + public MidjourneyInteractionsApi(MidjourneyConfig midjourneyConfig) { + super(midjourneyConfig); this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions()); } @@ -51,17 +42,12 @@ public class MidjourneyInteractions { // 获取请求模板 String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine"); // 设置参数 - HashMap requestParams = Maps.newHashMap(); - // TODO @fansili:感觉参数的组装,可以搞成一个公用的方法;就是 config + 入参的感觉; - requestParams.put("guild_id", midjourneyConfig.getGuildId()); - requestParams.put("channel_id", midjourneyConfig.getChannelId()); - requestParams.put("session_id", midjourneyConfig.getSessionId()); - requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili:建议用 uuid 之类的;nextId 跨进程未必合适哈; + HashMap requestParams = getDefaultParams(); requestParams.put("prompt", prompt); // 解析 template 参数占位符 String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); // 获取 header - HttpHeaders httpHeaders = getHttpHeaders(); + HttpHeaders httpHeaders = getHeadersOfAppJson(); // 发送请求 HttpEntity requestEntity = new HttpEntity<>(requestBody, httpHeaders); String res = restTemplate.postForObject(url, requestEntity, String.class); @@ -77,19 +63,15 @@ public class MidjourneyInteractions { // TODO done @fansili:方法和方法之间,空一行哈; - public Boolean reRoll(ReRoll reRoll) { + public Boolean reRoll(ReRollReq reRoll) { // 获取请求模板 String requestTemplate = midjourneyConfig.getRequestTemplates().get("reroll"); // 设置参数 - HashMap requestParams = Maps.newHashMap(); - requestParams.put("guild_id", midjourneyConfig.getGuildId()); - requestParams.put("channel_id", midjourneyConfig.getChannelId()); - requestParams.put("session_id", midjourneyConfig.getSessionId()); - requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); + HashMap requestParams = getDefaultParams(); requestParams.put("custom_id", reRoll.getCustomId()); requestParams.put("message_id", reRoll.getMessageId()); // 获取 header - HttpHeaders httpHeaders = getHttpHeaders(); + HttpHeaders httpHeaders = getHeadersOfAppJson(); // 设置参数 String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); // 发送请求 @@ -105,7 +87,7 @@ public class MidjourneyInteractions { } // TODO @fansili:搞成私有方法,可能会好点; - public UploadAttachmentsRes uploadAttachments(Attachments attachments) { + public UploadAttachmentsRes uploadAttachments(AttachmentsReq attachments) { // file JSONObject fileObj = new JSONObject(); fileObj.put("id", "0"); @@ -120,13 +102,7 @@ public class MidjourneyInteractions { MultiValueMap multipartRequest = new LinkedMultiValueMap<>(); multipartRequest.put("files", Lists.newArrayList(fileObj)); // 设置header值 - HttpHeaders httpHeaders = new HttpHeaders(); - // TODO @fansili:通用的 header 构建,抽一个方法哈; - httpHeaders.setContentType(MediaType.APPLICATION_JSON); - httpHeaders.set("Authorization", midjourneyConfig.getToken()); - httpHeaders.set("User-Agent", midjourneyConfig.getUserAage()); - httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE); - httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId())); + HttpHeaders httpHeaders = getHeadersOfAppJson(); // 创建HttpEntity对象,包含表单数据和头部信息 HttpEntity> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders); // 发送POST请求并接收响应 @@ -144,24 +120,15 @@ public class MidjourneyInteractions { return uploadAttachmentsRes; } - public Boolean describe(Describe describe) { + public Boolean describe(DescribeReq describe) { // 获取请求模板 String requestTemplate = midjourneyConfig.getRequestTemplates().get("describe"); // 设置参数 - HashMap requestParams = Maps.newHashMap(); - requestParams.put("guild_id", midjourneyConfig.getGuildId()); - requestParams.put("channel_id", midjourneyConfig.getChannelId()); - requestParams.put("session_id", midjourneyConfig.getSessionId()); - requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); + HashMap requestParams = getDefaultParams(); requestParams.put("file_name", describe.getFileName()); requestParams.put("final_file_name", describe.getFinalFileName()); // 设置 header - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); // 设置内容类型为JSON - httpHeaders.set("Authorization", midjourneyConfig.getToken()); - httpHeaders.set("User-Agent", midjourneyConfig.getUserAage()); - httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE); - httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId())); + HttpHeaders httpHeaders = getHeadersOfFormData(); String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); // 创建表单数据 MultiValueMap formData = new LinkedMultiValueMap<>(); @@ -178,15 +145,4 @@ public class MidjourneyInteractions { return isSuccess; } - @NotNull - private HttpHeaders getHttpHeaders() { - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.APPLICATION_JSON); // 设置内容类型为JSON - httpHeaders.set("Authorization", midjourneyConfig.getToken()); - httpHeaders.set("User-Agent", midjourneyConfig.getUserAage()); - httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE); - httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId())); - return httpHeaders; - } - }