1、请求公共部分抽离

2、修改为api和spring ai结构保持一致
This commit is contained in:
cherishsince 2024-04-10 20:02:06 +08:00
parent 0e277b71e1
commit 03f124a82b
2 changed files with 101 additions and 62 deletions

View File

@ -0,0 +1,83 @@
package cn.iocoder.yudao.framework.ai.midjourney.api;
import cn.hutool.core.util.IdUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants;
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import java.util.HashMap;
// TODO @fansili按照 spring ai 的封装习惯这个类是不是 MidjourneyApi
/**
* 图片生成
*
* author: fansili
* time: 2024/4/3 17:36
*/
@Slf4j
public abstract class MidjourneyInteractions {
// TODO done @fansili静态变量放在最前面哈
/**
* header - referer 头信息
*/
private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
/**
* mj配置文件
*/
protected final MidjourneyConfig midjourneyConfig;
protected MidjourneyInteractions(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
}
/**
* 获取headers - application json
*
* @return
*/
protected HttpHeaders getHeadersOfAppJson() {
// 设置header值
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
return httpHeaders;
}
/**
* 获取headers - http form data
*
* @return
*/
protected HttpHeaders getHeadersOfFormData() {
// 设置header值
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
return httpHeaders;
}
/**
* 获取 - 默认参数
* @return
*/
protected HashMap<String, String> getDefaultParams() {
HashMap<String, String> requestParams = Maps.newHashMap();
// TODO @fansili感觉参数的组装可以搞成一个公用的方法就是 config + 入参的感觉
requestParams.put("guild_id", midjourneyConfig.getGuildId());
requestParams.put("channel_id", midjourneyConfig.getChannelId());
requestParams.put("session_id", midjourneyConfig.getSessionId());
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili建议用 uuid 之类的nextId 跨进程未必合适哈
return requestParams;
}
}

View File

@ -1,20 +1,16 @@
package cn.iocoder.yudao.framework.ai.midjourney.interactions; package cn.iocoder.yudao.framework.ai.midjourney.api;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants;
import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil; import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil;
import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments; import cn.iocoder.yudao.framework.ai.midjourney.api.req.AttachmentsReq;
import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe; import cn.iocoder.yudao.framework.ai.midjourney.api.req.DescribeReq;
import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll; import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq;
import cn.iocoder.yudao.framework.ai.midjourney.vo.UploadAttachmentsRes; import cn.iocoder.yudao.framework.ai.midjourney.api.res.UploadAttachmentsRes;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.FileSystemResource;
import org.springframework.http.*; import org.springframework.http.*;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
@ -32,18 +28,13 @@ import java.util.HashMap;
* time: 2024/4/3 17:36 * time: 2024/4/3 17:36
*/ */
@Slf4j @Slf4j
public class MidjourneyInteractions { public class MidjourneyInteractionsApi extends MidjourneyInteractions {
// TODO done @fansili静态变量放在最前面哈
private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
private final String url; private final String url;
private final MidjourneyConfig midjourneyConfig;
private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili优先级低后续搞到统一的管理 private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili优先级低后续搞到统一的管理
public MidjourneyInteractionsApi(MidjourneyConfig midjourneyConfig) {
public MidjourneyInteractions(MidjourneyConfig midjourneyConfig) { super(midjourneyConfig);
this.midjourneyConfig = midjourneyConfig;
this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions()); this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
} }
@ -51,17 +42,12 @@ public class MidjourneyInteractions {
// 获取请求模板 // 获取请求模板
String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine"); String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
// 设置参数 // 设置参数
HashMap<String, String> requestParams = Maps.newHashMap(); HashMap<String, String> requestParams = getDefaultParams();
// TODO @fansili感觉参数的组装可以搞成一个公用的方法就是 config + 入参的感觉
requestParams.put("guild_id", midjourneyConfig.getGuildId());
requestParams.put("channel_id", midjourneyConfig.getChannelId());
requestParams.put("session_id", midjourneyConfig.getSessionId());
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili建议用 uuid 之类的nextId 跨进程未必合适哈
requestParams.put("prompt", prompt); requestParams.put("prompt", prompt);
// 解析 template 参数占位符 // 解析 template 参数占位符
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 获取 header // 获取 header
HttpHeaders httpHeaders = getHttpHeaders(); HttpHeaders httpHeaders = getHeadersOfAppJson();
// 发送请求 // 发送请求
HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders); HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
String res = restTemplate.postForObject(url, requestEntity, String.class); String res = restTemplate.postForObject(url, requestEntity, String.class);
@ -77,19 +63,15 @@ public class MidjourneyInteractions {
// TODO done @fansili方法和方法之间空一行哈 // TODO done @fansili方法和方法之间空一行哈
public Boolean reRoll(ReRoll reRoll) { public Boolean reRoll(ReRollReq reRoll) {
// 获取请求模板 // 获取请求模板
String requestTemplate = midjourneyConfig.getRequestTemplates().get("reroll"); String requestTemplate = midjourneyConfig.getRequestTemplates().get("reroll");
// 设置参数 // 设置参数
HashMap<String, String> requestParams = Maps.newHashMap(); HashMap<String, String> requestParams = getDefaultParams();
requestParams.put("guild_id", midjourneyConfig.getGuildId());
requestParams.put("channel_id", midjourneyConfig.getChannelId());
requestParams.put("session_id", midjourneyConfig.getSessionId());
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
requestParams.put("custom_id", reRoll.getCustomId()); requestParams.put("custom_id", reRoll.getCustomId());
requestParams.put("message_id", reRoll.getMessageId()); requestParams.put("message_id", reRoll.getMessageId());
// 获取 header // 获取 header
HttpHeaders httpHeaders = getHttpHeaders(); HttpHeaders httpHeaders = getHeadersOfAppJson();
// 设置参数 // 设置参数
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 发送请求 // 发送请求
@ -105,7 +87,7 @@ public class MidjourneyInteractions {
} }
// TODO @fansili搞成私有方法可能会好点 // TODO @fansili搞成私有方法可能会好点
public UploadAttachmentsRes uploadAttachments(Attachments attachments) { public UploadAttachmentsRes uploadAttachments(AttachmentsReq attachments) {
// file // file
JSONObject fileObj = new JSONObject(); JSONObject fileObj = new JSONObject();
fileObj.put("id", "0"); fileObj.put("id", "0");
@ -120,13 +102,7 @@ public class MidjourneyInteractions {
MultiValueMap<String, Object> multipartRequest = new LinkedMultiValueMap<>(); MultiValueMap<String, Object> multipartRequest = new LinkedMultiValueMap<>();
multipartRequest.put("files", Lists.newArrayList(fileObj)); multipartRequest.put("files", Lists.newArrayList(fileObj));
// 设置header值 // 设置header值
HttpHeaders httpHeaders = new HttpHeaders(); HttpHeaders httpHeaders = getHeadersOfAppJson();
// TODO @fansili通用的 header 构建抽一个方法哈
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
// 创建HttpEntity对象包含表单数据和头部信息 // 创建HttpEntity对象包含表单数据和头部信息
HttpEntity<MultiValueMap<String, Object>> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders); HttpEntity<MultiValueMap<String, Object>> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders);
// 发送POST请求并接收响应 // 发送POST请求并接收响应
@ -144,24 +120,15 @@ public class MidjourneyInteractions {
return uploadAttachmentsRes; return uploadAttachmentsRes;
} }
public Boolean describe(Describe describe) { public Boolean describe(DescribeReq describe) {
// 获取请求模板 // 获取请求模板
String requestTemplate = midjourneyConfig.getRequestTemplates().get("describe"); String requestTemplate = midjourneyConfig.getRequestTemplates().get("describe");
// 设置参数 // 设置参数
HashMap<String, String> requestParams = Maps.newHashMap(); HashMap<String, String> requestParams = getDefaultParams();
requestParams.put("guild_id", midjourneyConfig.getGuildId());
requestParams.put("channel_id", midjourneyConfig.getChannelId());
requestParams.put("session_id", midjourneyConfig.getSessionId());
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
requestParams.put("file_name", describe.getFileName()); requestParams.put("file_name", describe.getFileName());
requestParams.put("final_file_name", describe.getFinalFileName()); requestParams.put("final_file_name", describe.getFinalFileName());
// 设置 header // 设置 header
HttpHeaders httpHeaders = new HttpHeaders(); HttpHeaders httpHeaders = getHeadersOfFormData();
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); // 设置内容类型为JSON
httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 创建表单数据 // 创建表单数据
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>(); MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
@ -178,15 +145,4 @@ public class MidjourneyInteractions {
return isSuccess; return isSuccess;
} }
@NotNull
private HttpHeaders getHttpHeaders() {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.setContentType(MediaType.APPLICATION_JSON); // 设置内容类型为JSON
httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
return httpHeaders;
}
} }