处理todo

This commit is contained in:
cherishsince 2024-04-09 20:12:41 +08:00
parent 337ae04551
commit 32bc632947
26 changed files with 243 additions and 346 deletions

View File

@ -1,28 +0,0 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.Getter;
// TODO @fansili1类注释要加下2author time javadoc@author @since3@AllArgsConstructor 使用这个注解去掉构造方法4value 改成 model 字段然后注释都写下哈5message 改成 name然后注释都写下哈
/**
* author: fansili
* time: 2024/3/4 12:36
*/
@Getter
public enum AiModelEnum {
OPEN_AI_GPT_3_5("gpt-3.5-turbo", "GPT3.5"),
OPEN_AI_GPT_4("gpt-4-turbo", "GPT4")
;
AiModelEnum(String value, String message) {
this.value = value;
this.message = message;
}
// TODO @fan
private String value;
private String message;
}

View File

@ -0,0 +1,35 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
// TODO done @fansili1类注释要加下2author time javadoc@author @since3@AllArgsConstructor 使用这个注解去掉构造方法4value 改成 model 字段然后注释都写下哈5message 改成 name然后注释都写下哈
/**
* @author: fansili
* @time: 2024/3/4 12:36
*/
@Getter
@AllArgsConstructor
public enum OpenAiModelEnum {
/**
* open ai 3.5模型
*/
OPEN_AI_GPT_3_5("gpt-3.5-turbo", "GPT3.5"),
/**
* open ai 4.0 收费模型
*/
OPEN_AI_GPT_4("gpt-4-turbo", "GPT4")
;
/**
* 模型 - 用于参数传递
*/
private String model;
/**
* 模型名字 - 用于展示
*/
private String name;
}

View File

@ -4,7 +4,7 @@ import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO; import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO;
import cn.iocoder.yudao.module.ai.enums.AiModelEnum; import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
@ -13,7 +13,6 @@ import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
@ -23,16 +22,9 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.Scanner;
import java.util.function.Consumer; import java.util.function.Consumer;
// TODO @fansili有了 swagger 注释就不用类注释了 // TODO done @fansili有了 swagger 注释就不用类注释了
/**
* AI模块
*
* author: fansili
* time: 2024/3/3 20:28
*/
@Tag(name = "AI模块") @Tag(name = "AI模块")
@RestController @RestController
@RequestMapping("/ai-api") @RequestMapping("/ai-api")
@ -48,7 +40,7 @@ public class ChatController {
ChatClient chatClient = getChatClient(reqVO.getAiModel()); ChatClient chatClient = getChatClient(reqVO.getAiModel());
String res; String res;
try { try {
res = chatClient.call(reqVO.getInputText()); res = chatClient.call(reqVO.getPrompt());
} catch (Exception e) { } catch (Exception e) {
res = e.getMessage(); res = e.getMessage();
} }
@ -59,33 +51,14 @@ public class ChatController {
@Operation(summary = "对话聊天chatStream", description = "简单的ai聊天") @Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException { public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class); OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getInputText())); Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getPrompt()));
chatResponse.subscribe(new Consumer<ChatResponse>() { chatResponse.subscribe(new Consumer<ChatResponse>() {
@Override @Override
public void accept(ChatResponse chatResponse) { public void accept(ChatResponse chatResponse) {
System.err.println(chatResponse.getResults().get(0).getOutput().getContent()); System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
} }
}); });
return CommonResult.success("1"); return CommonResult.success(null);
}
public static void main(String[] args) {
OpenAiChatClient openAiChatClient = new OpenAiChatClient(new OpenAiApi("openkey"));
Flux<ChatResponse> responseFlux = openAiChatClient.stream(new Prompt("最好的编程语言!"));
long now = System.currentTimeMillis();
responseFlux.subscribe(new Consumer<ChatResponse>() {
@Override
public void accept(ChatResponse chatResponse) {
if (chatResponse.getResults().get(0).getOutput() == null) {
return;
}
System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
}
});
// 阻止退出
Scanner scanner = new Scanner(System.in);
scanner.nextLine();
} }
/** /**
@ -94,8 +67,8 @@ public class ChatController {
* @param aiModelEnum * @param aiModelEnum
* @return * @return
*/ */
private ChatClient getChatClient(AiModelEnum aiModelEnum) { private ChatClient getChatClient(OpenAiModelEnum aiModelEnum) {
if (AiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) { if (OpenAiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
return applicationContext.getBean(OpenAiChatClient.class); return applicationContext.getBean(OpenAiChatClient.class);
} }
// AI模型暂不支持 // AI模型暂不支持

View File

@ -1,27 +1,21 @@
package cn.iocoder.yudao.module.ai.controller.admin.vo; package cn.iocoder.yudao.module.ai.controller.admin.vo;
import cn.iocoder.yudao.module.ai.enums.AiModelEnum; import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
// TODO @fansili 1swagger 注释不太对2有了 swagger 注释就不用类注释了 // TODO done @fansili 1swagger 注释不太对2有了 swagger 注释就不用类注释了
/**
* ai 聊天 req
*
* author: fansili
* time: 2024/3/4 12:33
*/
@Schema(description = "用户 App - 上传文件 Request VO")
@Data @Data
@Schema(description = "用户 App - 上传文件 Request VO")
public class AiChatReqVO { public class AiChatReqVO {
@Schema(description = "输入内容", requiredMode = Schema.RequiredMode.REQUIRED) @Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED)
@NotNull(message = "输入内容不能为空") @NotNull(message = "提示词不能为空!")
private String inputText; private String prompt;
@Schema(description = "AI模型", requiredMode = Schema.RequiredMode.REQUIRED) @Schema(description = "AI模型", requiredMode = Schema.RequiredMode.REQUIRED)
@NotNull(message = "AI模型不能为空") @NotNull(message = "AI模型不能为空")
private AiModelEnum aiModel; private OpenAiModelEnum aiModel;
} }

View File

@ -149,6 +149,10 @@
<!-- </exclusion>--> <!-- </exclusion>-->
<!-- </exclusions>--> <!-- </exclusions>-->
</dependency> </dependency>
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-common</artifactId>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -6,6 +6,7 @@ import cn.iocoder.yudao.framework.ai.chat.*;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions; import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException; import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
import com.aliyun.broadscope.bailian.sdk.models.*; import com.aliyun.broadscope.bailian.sdk.models.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.framework.ai.chatqianwen; package cn.iocoder.yudao.framework.ai.chatqianwen.api;
import com.aliyun.broadscope.bailian.sdk.AccessTokenClient; import com.aliyun.broadscope.bailian.sdk.AccessTokenClient;
import com.aliyun.broadscope.bailian.sdk.ApplicationClient; import com.aliyun.broadscope.bailian.sdk.ApplicationClient;
@ -9,7 +9,7 @@ import org.springframework.http.HttpStatusCode;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
// TODO @fansili是不是挪到 api 包里按照 spring ai 的结构根目录只放 client options // TODO done @fansili是不是挪到 api 包里按照 spring ai 的结构根目录只放 client options
/** /**
* 阿里 通义千问 * 阿里 通义千问
* *

View File

@ -4,6 +4,7 @@ import cn.hutool.json.JSONUtil;
import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageRequest; import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageRequest;
import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageResponse; import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageResponse;
import cn.iocoder.yudao.framework.ai.util.JacksonUtil; import cn.iocoder.yudao.framework.ai.util.JacksonUtil;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpEntity; import org.apache.http.HttpEntity;
@ -55,7 +56,7 @@ public class OpenAiImageApi {
httpPost.setURI(URI.create(DEFAULT_BASE_URL.concat("/v1/images/generations"))); httpPost.setURI(URI.create(DEFAULT_BASE_URL.concat("/v1/images/generations")));
httpPost.setHeader("Content-Type", "application/json"); httpPost.setHeader("Content-Type", "application/json");
httpPost.setHeader("Authorization", "Bearer " + apiKey); httpPost.setHeader("Authorization", "Bearer " + apiKey);
httpPost.setEntity(new StringEntity(JacksonUtil.toJson(request), "UTF-8")); httpPost.setEntity(new StringEntity(JsonUtils.toJsonString(request), "UTF-8"));
CloseableHttpResponse response= null; CloseableHttpResponse response= null;
try { try {

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.framework.ai.midjourney; package cn.iocoder.yudao.framework.ai.midjourney;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
@ -7,7 +8,7 @@ import java.util.List;
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public class MjMessage { public class MidjourneyMessage {
/** /**
* id是一个重要的字段在同时生成多个的时候可以区分生成信息 * id是一个重要的字段在同时生成多个的时候可以区分生成信息
@ -41,7 +42,7 @@ public class MjMessage {
* 1等待 * 1等待
* 2进行中 * 2进行中
* 3完成 * 3完成
* {@link cn.iocoder.yudao.framework.ai.midjourney.constants.MjGennerateStatusEnum} * {@link MidjourneyGennerateStatusEnum}
*/ */
private String generateStatus; private String generateStatus;

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.framework.ai.midjourney.constants; package cn.iocoder.yudao.framework.ai.midjourney.constants;
public final class MjConstants { public final class MidjourneyConstants {
/** /**
* 消息 - 编号 * 消息 - 编号

View File

@ -0,0 +1,31 @@
package cn.iocoder.yudao.framework.ai.midjourney.constants;
import lombok.AllArgsConstructor;
import lombok.Getter;
// TODO done @fansili1Mj 缩写还是搞成全称虽然长一点但是感觉会相对清晰一些哈2lombok 相关的注解可以用用哈3value status
/**
* mj 生成状态
*
* author: fansili
* time: 2024/4/6 21:07
*/
@Getter
@AllArgsConstructor
public enum MidjourneyGennerateStatusEnum {
WAITING("waiting", "等待..."),
IN_PROGRESS("in_progress", "进行中"),
COMPLETED("completed", "完成"),
;
/**
* 状态
*/
private String status;
/**
* 状态信息
*/
private String message;
}

View File

@ -6,7 +6,7 @@ import lombok.Getter;
* MJ 命令 * MJ 命令
*/ */
@Getter @Getter
public enum MjInteractionsEnum { public enum MidjourneyInteractionsEnum {
IMAGINE("imagine", "生成图片"), IMAGINE("imagine", "生成图片"),
DESCRIBE("describe", "生成描述"), DESCRIBE("describe", "生成描述"),
@ -17,7 +17,7 @@ public enum MjInteractionsEnum {
; ;
MjInteractionsEnum(String value, String message) { MidjourneyInteractionsEnum(String value, String message) {
this.value =value; this.value =value;
this.message =message; this.message =message;
} }

View File

@ -1,7 +1,7 @@
package cn.iocoder.yudao.framework.ai.midjourney.constants; package cn.iocoder.yudao.framework.ai.midjourney.constants;
public enum MjMessageTypeEnum { public enum MidjourneyMessageTypeEnum {
/** /**
* 创建. * 创建.
*/ */
@ -15,7 +15,7 @@ public enum MjMessageTypeEnum {
*/ */
DELETE; DELETE;
public static MjMessageTypeEnum of(String type) { public static MidjourneyMessageTypeEnum of(String type) {
return switch (type) { return switch (type) {
case "MESSAGE_CREATE" -> CREATE; case "MESSAGE_CREATE" -> CREATE;
case "MESSAGE_UPDATE" -> UPDATE; case "MESSAGE_UPDATE" -> UPDATE;

View File

@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.midjourney.constants;
import lombok.experimental.UtilityClass; import lombok.experimental.UtilityClass;
@UtilityClass @UtilityClass
public final class MjNotifyCode { public final class MidjourneyNotifyCode {
/** /**
* 成功. * 成功.
*/ */

View File

@ -1,30 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.constants;
import lombok.Getter;
// TODO @fansili1Mj 缩写还是搞成全称虽然长一点但是感觉会相对清晰一些哈2lombok 相关的注解可以用用哈3value status
/**
* mj 生成状态
*
* author: fansili
* time: 2024/4/6 21:07
*/
@Getter
public enum MjGennerateStatusEnum {
WAITING("waiting", "等待..."),
IN_PROGRESS("in_progress", "进行中"),
COMPLETED("completed", "完成"),
;
MjGennerateStatusEnum(String value, String message) {
this.value = value;
this.message = message;
}
private String value;
private String message;
}

View File

@ -3,8 +3,8 @@ package cn.iocoder.yudao.framework.ai.midjourney.interactions;
import cn.hutool.core.util.IdUtil; import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjConstants; import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants;
import cn.iocoder.yudao.framework.ai.midjourney.util.MjUtil; import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil;
import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments; import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments;
import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe; import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe;
import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll; import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll;
@ -32,15 +32,17 @@ import java.util.HashMap;
* time: 2024/4/3 17:36 * time: 2024/4/3 17:36
*/ */
@Slf4j @Slf4j
public class MjInteractions { public class MidjourneyInteractions {
// TODO done @fansili静态变量放在最前面哈
private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
private final String url; private final String url;
private final MidjourneyConfig midjourneyConfig; private final MidjourneyConfig midjourneyConfig;
private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili优先级低后续搞到统一的管理 private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili优先级低后续搞到统一的管理
// TODO @fansili静态变量放在最前面哈
private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
public MjInteractions(MidjourneyConfig midjourneyConfig) {
public MidjourneyInteractions(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig; this.midjourneyConfig = midjourneyConfig;
this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions()); this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
} }
@ -57,7 +59,7 @@ public class MjInteractions {
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili建议用 uuid 之类的nextId 跨进程未必合适哈 requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili建议用 uuid 之类的nextId 跨进程未必合适哈
requestParams.put("prompt", prompt); requestParams.put("prompt", prompt);
// 解析 template 参数占位符 // 解析 template 参数占位符
String requestBody = MjUtil.parseTemplate(requestTemplate, requestParams); String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 获取 header // 获取 header
HttpHeaders httpHeaders = getHttpHeaders(); HttpHeaders httpHeaders = getHttpHeaders();
// 发送请求 // 发送请求
@ -65,14 +67,14 @@ public class MjInteractions {
String res = restTemplate.postForObject(url, requestEntity, String.class); String res = restTemplate.postForObject(url, requestEntity, String.class);
// 这个 res 只要不返回值就是成功! // 这个 res 只要不返回值就是成功!
// TODO @fansili可以直接 if (StrUtil.isBlank(res)) // TODO @fansili可以直接 if (StrUtil.isBlank(res))
boolean isSuccess = StrUtil.isBlank(res); if (StrUtil.isBlank(res)) {
if (isSuccess) {
return true; return true;
} else {
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
return false;
} }
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
return isSuccess;
} }
// TODO @fansili方法和方法之间空一行哈 // TODO done @fansili方法和方法之间空一行哈
public Boolean reRoll(ReRoll reRoll) { public Boolean reRoll(ReRoll reRoll) {
@ -89,7 +91,7 @@ public class MjInteractions {
// 获取 header // 获取 header
HttpHeaders httpHeaders = getHttpHeaders(); HttpHeaders httpHeaders = getHttpHeaders();
// 设置参数 // 设置参数
String requestBody = MjUtil.parseTemplate(requestTemplate, requestParams); String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 发送请求 // 发送请求
HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders); HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
String res = restTemplate.postForObject(url, requestEntity, String.class); String res = restTemplate.postForObject(url, requestEntity, String.class);
@ -123,7 +125,7 @@ public class MjInteractions {
httpHeaders.setContentType(MediaType.APPLICATION_JSON); httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.set("Authorization", midjourneyConfig.getToken()); httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage()); httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MjConstants.HTTP_COOKIE); httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId())); httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
// 创建HttpEntity对象包含表单数据和头部信息 // 创建HttpEntity对象包含表单数据和头部信息
HttpEntity<MultiValueMap<String, Object>> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders); HttpEntity<MultiValueMap<String, Object>> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders);
@ -132,16 +134,13 @@ public class MjInteractions {
String response = restTemplate.postForObject(midjourneyConfig.getServerUrl().concat(uri), multiValueMapHttpEntity, String.class); String response = restTemplate.postForObject(midjourneyConfig.getServerUrl().concat(uri), multiValueMapHttpEntity, String.class);
UploadAttachmentsRes uploadAttachmentsRes = JSON.parseObject(response, UploadAttachmentsRes.class); UploadAttachmentsRes uploadAttachmentsRes = JSON.parseObject(response, UploadAttachmentsRes.class);
// //
// 上传文件 // 上传文件
String uploadUrl = uploadAttachmentsRes.getAttachments().getFirst().getUploadUrl(); String uploadUrl = uploadAttachmentsRes.getAttachments().getFirst().getUploadUrl();
String uploadAttachmentsUrl = midjourneyConfig.getApiAttachmentsUpload().concat(uploadUrl);
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
HttpEntity<FileSystemResource> fileSystemResourceHttpEntity = new HttpEntity<>(attachments.getFileSystemResource(), httpHeaders); HttpEntity<FileSystemResource> fileSystemResourceHttpEntity = new HttpEntity<>(attachments.getFileSystemResource(), httpHeaders);
ResponseEntity<String> exchange = restTemplate.exchange(uploadUrl, HttpMethod.PUT, fileSystemResourceHttpEntity, String.class); ResponseEntity<String> exchange = restTemplate.exchange(uploadUrl, HttpMethod.PUT, fileSystemResourceHttpEntity, String.class);
String uploadRes = exchange.getBody(); String uploadRes = exchange.getBody();
return uploadAttachmentsRes; return uploadAttachmentsRes;
} }
@ -161,9 +160,9 @@ public class MjInteractions {
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); // 设置内容类型为JSON httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); // 设置内容类型为JSON
httpHeaders.set("Authorization", midjourneyConfig.getToken()); httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage()); httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MjConstants.HTTP_COOKIE); httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId())); httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
String requestBody = MjUtil.parseTemplate(requestTemplate, requestParams); String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 创建表单数据 // 创建表单数据
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>(); MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
formData.add("payload_json", requestBody); formData.add("payload_json", requestBody);
@ -185,7 +184,7 @@ public class MjInteractions {
httpHeaders.setContentType(MediaType.APPLICATION_JSON); // 设置内容类型为JSON httpHeaders.setContentType(MediaType.APPLICATION_JSON); // 设置内容类型为JSON
httpHeaders.set("Authorization", midjourneyConfig.getToken()); httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage()); httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MjConstants.HTTP_COOKIE); httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId())); httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
return httpHeaders; return httpHeaders;
} }

View File

@ -1,7 +1,7 @@
package cn.iocoder.yudao.framework.ai.midjourney.util; package cn.iocoder.yudao.framework.ai.midjourney.util;
import cn.hutool.core.text.CharSequenceUtil; import cn.hutool.core.text.CharSequenceUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MjMessage; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
import java.util.Map; import java.util.Map;
import java.util.regex.Matcher; import java.util.regex.Matcher;
@ -13,7 +13,7 @@ import java.util.regex.Pattern;
* author: fansili * author: fansili
* time: 2024/4/6 19:00 * time: 2024/4/6 19:00
*/ */
public class MjUtil { public class MidjourneyUtil {
/** /**
* content正则匹配prompt和进度. * content正则匹配prompt和进度.
*/ */
@ -26,12 +26,12 @@ public class MjUtil {
* @param content * @param content
* @return * @return
*/ */
public static MjMessage.Content parseContent(String content) { public static MidjourneyMessage.Content parseContent(String content) {
// 有三种格式 // 有三种格式
// 南极应该是什么样子 // 南极应该是什么样子
// "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)", // "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)",
// "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (fast, stealth)" // "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (fast, stealth)"
MjMessage.Content mjContent = new MjMessage.Content(); MidjourneyMessage.Content mjContent = new MidjourneyMessage.Content();
if (CharSequenceUtil.isBlank(content)) { if (CharSequenceUtil.isBlank(content)) {
return null; return null;
} }

View File

@ -4,9 +4,9 @@ package cn.iocoder.yudao.framework.ai.midjourney.webSocket;
import cn.hutool.core.text.CharSequenceUtil; import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.thread.ThreadUtil; import cn.hutool.core.thread.ThreadUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode; import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyNotifyCode;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MidjourneyWebSocketHandler;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.websocket.Constants; import org.apache.tomcat.websocket.Constants;
@ -20,11 +20,10 @@ import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.concurrent.TimeoutException;
// TODO @fansilimj 这块 websocket 有点小复杂虽然代码量 400 多行感觉可以考虑有没第三方 sdk通过它透明接入 mj // TODO @fansilimj 这块 websocket 有点小复杂虽然代码量 400 多行感觉可以考虑有没第三方 sdk通过它透明接入 mj
@Slf4j @Slf4j
public class MjWebSocketStarter implements WebSocketStarter { public class MidjourneyWebSocketStarter implements WebSocketStarter {
/** /**
* 链接重试次数 * 链接重试次数
*/ */
@ -36,7 +35,7 @@ public class MjWebSocketStarter implements WebSocketStarter {
/** /**
* mj 监听(所有message 都会 callback到这里) * mj 监听(所有message 都会 callback到这里)
*/ */
private final MjMessageListener userMessageListener; private final MidjourneyMessageListener userMessageListener;
/** /**
* wss 服务器 * wss 服务器
*/ */
@ -58,10 +57,10 @@ public class MjWebSocketStarter implements WebSocketStarter {
*/ */
private WebSocketSession webSocketSession = null; private WebSocketSession webSocketSession = null;
public MjWebSocketStarter(String wssServer, public MidjourneyWebSocketStarter(String wssServer,
String resumeWss, String resumeWss,
MidjourneyConfig midjourneyConfig, MidjourneyConfig midjourneyConfig,
MjMessageListener userMessageListener) { MidjourneyMessageListener userMessageListener) {
this.wssServer = wssServer; this.wssServer = wssServer;
this.resumeWss = resumeWss; this.resumeWss = resumeWss;
this.midjourneyConfig = midjourneyConfig; this.midjourneyConfig = midjourneyConfig;
@ -83,7 +82,7 @@ public class MjWebSocketStarter implements WebSocketStarter {
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits"); headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
headers.add("User-Agent", this.midjourneyConfig.getUserAage()); headers.add("User-Agent", this.midjourneyConfig.getUserAage());
// 创建 mjHeader // 创建 mjHeader
MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler( MidjourneyWebSocketHandler mjWebSocketHandler = new MidjourneyWebSocketHandler(
this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure); this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
// //
String gatewayUrl; String gatewayUrl;
@ -105,12 +104,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
socketSessionFuture.addCallback(new ListenableFutureCallback<>() { socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
@Override @Override
public void onFailure(@NotNull Throwable e) { public void onFailure(@NotNull Throwable e) {
onSocketFailure(MjWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage()); onSocketFailure(MidjourneyWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage());
} }
@Override @Override
public void onSuccess(WebSocketSession session) { public void onSuccess(WebSocketSession session) {
MjWebSocketStarter.this.webSocketSession = session; MidjourneyWebSocketStarter.this.webSocketSession = session;
} }
}); });
} }
@ -118,7 +117,7 @@ public class MjWebSocketStarter implements WebSocketStarter {
private void onSocketSuccess(String sessionId, Object sequence, String resumeGatewayUrl) { private void onSocketSuccess(String sessionId, Object sequence, String resumeGatewayUrl) {
this.resumeData = new ResumeData(sessionId, sequence, resumeGatewayUrl); this.resumeData = new ResumeData(sessionId, sequence, resumeGatewayUrl);
this.running = true; this.running = true;
notifyWssLock(MjNotifyCode.SUCCESS, ""); notifyWssLock(MidjourneyNotifyCode.SUCCESS, "");
} }
private void onSocketFailure(int code, String reason) { private void onSocketFailure(int code, String reason) {

View File

@ -8,7 +8,7 @@ import cn.hutool.http.useragent.UserAgentUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.FailureCallback; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.FailureCallback;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.SuccessCallback; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.SuccessCallback;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataArray; import net.dv8tion.jda.api.utils.data.DataArray;
@ -29,7 +29,7 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@Slf4j @Slf4j
public class MjWebSocketHandler implements WebSocketHandler { public class MidjourneyWebSocketHandler implements WebSocketHandler {
/** /**
* close 错误码重连 * close 错误码重连
*/ */
@ -49,7 +49,7 @@ public class MjWebSocketHandler implements WebSocketHandler {
/** /**
* mj 消息监听 * mj 消息监听
*/ */
private final MjMessageListener userMessageListener; private final MidjourneyMessageListener userMessageListener;
/** /**
* 成功回调 * 成功回调
*/ */
@ -85,10 +85,10 @@ public class MjWebSocketHandler implements WebSocketHandler {
*/ */
private final Decompressor decompressor = new ZlibDecompressor(2048); private final Decompressor decompressor = new ZlibDecompressor(2048);
public MjWebSocketHandler(MidjourneyConfig account, public MidjourneyWebSocketHandler(MidjourneyConfig account,
MjMessageListener userMessageListener, MidjourneyMessageListener userMessageListener,
SuccessCallback successCallback, SuccessCallback successCallback,
FailureCallback failureCallback) { FailureCallback failureCallback) {
this.midjourneyConfig = account; this.midjourneyConfig = account;
this.userMessageListener = userMessageListener; this.userMessageListener = userMessageListener;
this.successCallback = successCallback; this.successCallback = successCallback;

View File

@ -0,0 +1,83 @@
package cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyMessageTypeEnum;
import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject;
import java.util.List;
@Slf4j
public class MidjourneyMessageListener {
private MidjourneyConfig midjourneyConfig;
public MidjourneyMessageListener(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
}
public void onMessage(DataObject raw) {
MidjourneyMessageTypeEnum messageType = MidjourneyMessageTypeEnum.of(raw.getString("t"));
if (messageType == null || MidjourneyMessageTypeEnum.DELETE == messageType) {
return;
}
DataObject data = raw.getObject("d");
if (ignoreAndLogMessage(data, messageType)) {
return;
}
// 转换几个重要的信息
MidjourneyMessage mjMessage = new MidjourneyMessage();
mjMessage.setId(data.getString(MidjourneyConstants.MSG_ID));
mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE));
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
// 转换 components
if (!data.getArray(MidjourneyConstants.MSG_COMPONENTS).isEmpty()) {
String componentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_COMPONENTS).toJson(), "UTF-8");
List<MidjourneyMessage.ComponentType> components = JSON.parseArray(componentsJson, MidjourneyMessage.ComponentType.class);
mjMessage.setComponents(components);
}
// 转换附件
if (!data.getArray(MidjourneyConstants.MSG_ATTACHMENTS).isEmpty()) {
String attachmentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_ATTACHMENTS).toJson(), "UTF-8");
List<MidjourneyMessage.Attachment> attachments = JSON.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class);
mjMessage.setAttachments(attachments);
}
// 转换状态
convertGenerateStatus(mjMessage);
//
log.info("message 信息 {}", JSONUtil.toJsonPrettyStr(mjMessage));
System.err.println(JSONUtil.toJsonPrettyStr(mjMessage));
}
private void convertGenerateStatus(MidjourneyMessage mjMessage) {
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getValue());
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.IN_PROGRESS.getValue());
} else if (mjMessage.getType() == 0 && !CollUtil.isEmpty(mjMessage.getComponents())) {
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.COMPLETED.getValue());
}
}
private boolean ignoreAndLogMessage(DataObject data, MidjourneyMessageTypeEnum messageType) {
String channelId = data.getString(MidjourneyConstants.MSG_CHANNEL_ID);
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
return true;
}
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
return false;
}
}

View File

@ -1,83 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.MjMessage;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjConstants;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjGennerateStatusEnum;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjMessageTypeEnum;
import cn.iocoder.yudao.framework.ai.midjourney.util.MjUtil;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject;
import java.util.List;
@Slf4j
public class MjMessageListener {
private MidjourneyConfig midjourneyConfig;
public MjMessageListener(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
}
public void onMessage(DataObject raw) {
MjMessageTypeEnum messageType = MjMessageTypeEnum.of(raw.getString("t"));
if (messageType == null || MjMessageTypeEnum.DELETE == messageType) {
return;
}
DataObject data = raw.getObject("d");
if (ignoreAndLogMessage(data, messageType)) {
return;
}
// 转换几个重要的信息
MjMessage mjMessage = new MjMessage();
mjMessage.setId(data.getString(MjConstants.MSG_ID));
mjMessage.setType(data.getInt(MjConstants.MSG_TYPE));
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
mjMessage.setContent(MjUtil.parseContent(data.getString(MjConstants.MSG_CONTENT)));
// 转换 components
if (!data.getArray(MjConstants.MSG_COMPONENTS).isEmpty()) {
String componentsJson = StrUtil.str(data.getArray(MjConstants.MSG_COMPONENTS).toJson(), "UTF-8");
List<MjMessage.ComponentType> components = JSON.parseArray(componentsJson, MjMessage.ComponentType.class);
mjMessage.setComponents(components);
}
// 转换附件
if (!data.getArray(MjConstants.MSG_ATTACHMENTS).isEmpty()) {
String attachmentsJson = StrUtil.str(data.getArray(MjConstants.MSG_ATTACHMENTS).toJson(), "UTF-8");
List<MjMessage.Attachment> attachments = JSON.parseArray(attachmentsJson, MjMessage.Attachment.class);
mjMessage.setAttachments(attachments);
}
// 转换状态
convertGenerateStatus(mjMessage);
//
log.info("message 信息 {}", JSONUtil.toJsonPrettyStr(mjMessage));
System.err.println(JSONUtil.toJsonPrettyStr(mjMessage));
}
private void convertGenerateStatus(MjMessage mjMessage) {
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
mjMessage.setGenerateStatus(MjGennerateStatusEnum.WAITING.getValue());
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
mjMessage.setGenerateStatus(MjGennerateStatusEnum.IN_PROGRESS.getValue());
} else if (mjMessage.getType() == 0 && !CollUtil.isEmpty(mjMessage.getComponents())) {
mjMessage.setGenerateStatus(MjGennerateStatusEnum.COMPLETED.getValue());
}
}
private boolean ignoreAndLogMessage(DataObject data, MjMessageTypeEnum messageType) {
String channelId = data.getString(MjConstants.MSG_CHANNEL_ID);
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
return true;
}
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
return false;
}
}

View File

@ -1,80 +0,0 @@
package cn.iocoder.yudao.framework.ai.util;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import java.io.IOException;
// TODO @fansili看看能不能用 JsonUtils
/**
* Jackson工具类
*
* author: fansili
* time: 2024/3/17 10:13
*/
public class JacksonUtil {
private static final ObjectMapper objectMapper = new ObjectMapper();
/**
* 初始化 ObjectMapper 以美化输出即格式化JSON内容
*/
static {
// 美化输出缩进
objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
// 忽略值为 null 的属性
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
// 配置一个模块来将 Long 类型转换为 String 类型
SimpleModule module = new SimpleModule();
module.addSerializer(Long.class, ToStringSerializer.instance);
objectMapper.registerModule(module);
}
/**
* 将对象转换为 JSON 字符串
*
* @param obj 需要序列化的Java对象
* @return 序列化后的 JSON 字符串
* @throws JsonProcessingException JSON 序列化过程中出现错误时抛出异常
*/
public static String toJson(Object obj) {
try {
return objectMapper.writeValueAsString(obj);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* JSON 字符串反序列化为指定类型的对象
*
* @param json JSON 字符串
* @param clazz 目标类型 Class 对象
* @param <T> 泛型类型参数
* @return 反序列化后的 Java 对象
* @throws IOException JSON 解析过程中出现错误时抛出异常
*/
public static <T> T fromJson(String json, Class<T> clazz) {
try {
return objectMapper.readValue(json, clazz);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* 将对象转换为格式化的 JSON 字符串已启用 INDENT_OUTPUT 功能所以所有方法都会返回格式化后的 JSON
*
* @param obj 需要序列化的Java对象
* @return 格式化后的 JSON 字符串
* @throws JsonProcessingException JSON 序列化过程中出现错误时抛出异常
*/
public static String toFormattedJson(Object obj) {
// 已在类初始化时设置了 SerializationFeature.INDENT_OUTPUT此处无需额外操作
return toJson(obj);
}
}

View File

@ -1,10 +1,9 @@
package cn.iocoder.yudao.framework.ai.chat; package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenApi; import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient; import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions; import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
import com.aliyun.broadscope.bailian.sdk.models.CompletionsRequest;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;

View File

@ -1,8 +1,7 @@
package cn.iocoder.yudao.framework.ai.mj; package cn.iocoder.yudao.framework.ai.midjourney;
import cn.hutool.core.io.FileUtil; import cn.hutool.core.io.FileUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.interactions.MidjourneyInteractions;
import cn.iocoder.yudao.framework.ai.midjourney.interactions.MjInteractions;
import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments; import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments;
import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe; import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe;
import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll; import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll;
@ -23,7 +22,7 @@ import java.util.Map;
* author: fansili * author: fansili
* time: 2024/4/4 18:59 * time: 2024/4/4 18:59
*/ */
public class MjInteractionsTests { public class MidjourneyInteractionsTests {
private MidjourneyConfig midjourneyConfig; private MidjourneyConfig midjourneyConfig;
@Before @Before
@ -39,14 +38,14 @@ public class MjInteractionsTests {
@Test @Test
public void mjImageTest() { public void mjImageTest() {
MjInteractions mjImagineInteractions = new MjInteractions(midjourneyConfig); MidjourneyInteractions mjImagineInteractions = new MidjourneyInteractions(midjourneyConfig);
mjImagineInteractions.imagine("童话里应该是什么样子?"); mjImagineInteractions.imagine("童话里应该是什么样子?");
} }
@Test @Test
public void reRollTest() { public void reRollTest() {
MjInteractions mjImagineInteractions = new MjInteractions(midjourneyConfig); MidjourneyInteractions mjImagineInteractions = new MidjourneyInteractions(midjourneyConfig);
mjImagineInteractions.reRoll(new ReRoll() mjImagineInteractions.reRoll(new ReRoll()
.setMessageId("1226165117448753243") .setMessageId("1226165117448753243")
.setCustomId("MJ::JOB::upsample::3::2aeefbef-43e2-4057-bcf1-43b5f39ab6f7")); .setCustomId("MJ::JOB::upsample::3::2aeefbef-43e2-4057-bcf1-43b5f39ab6f7"));
@ -54,7 +53,7 @@ public class MjInteractionsTests {
@Test @Test
public void uploadAttachmentsTest() { public void uploadAttachmentsTest() {
MjInteractions mjImagineInteractions = new MjInteractions(midjourneyConfig); MidjourneyInteractions mjImagineInteractions = new MidjourneyInteractions(midjourneyConfig);
UploadAttachmentsRes res = mjImagineInteractions.uploadAttachments( UploadAttachmentsRes res = mjImagineInteractions.uploadAttachments(
new Attachments().setFileSystemResource( new Attachments().setFileSystemResource(
new FileSystemResource(new File("/Users/fansili/Downloads/DSC01402.JPG"))) new FileSystemResource(new File("/Users/fansili/Downloads/DSC01402.JPG")))
@ -64,7 +63,7 @@ public class MjInteractionsTests {
@Test @Test
public void describeTest() { public void describeTest() {
MjInteractions mjImagineInteractions = new MjInteractions(midjourneyConfig); MidjourneyInteractions mjImagineInteractions = new MidjourneyInteractions(midjourneyConfig);
mjImagineInteractions.describe(new Describe() mjImagineInteractions.describe(new Describe()
.setFileName("DSC01402.JPG") .setFileName("DSC01402.JPG")
.setFinalFileName("16826931-2873-45ec-8cfb-0ad81f1a075f/DSC01402.JPG") .setFinalFileName("16826931-2873-45ec-8cfb-0ad81f1a075f/DSC01402.JPG")

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.framework.ai.mj; package cn.iocoder.yudao.framework.ai.midjourney;
import cn.iocoder.yudao.framework.ai.midjourney.util.MjUtil; import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil;
import org.junit.Test; import org.junit.Test;
/** /**
@ -9,14 +9,14 @@ import org.junit.Test;
* author: fansili * author: fansili
* time: 2024/4/6 21:57 * time: 2024/4/6 21:57
*/ */
public class MjUtilTests { public class MidjourneyUtilTests {
@Test @Test
public void parseContentTest() { public void parseContentTest() {
String content1 = "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)"; String content1 = "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)";
String content2 = "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (fast, stealth)"; String content2 = "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (fast, stealth)";
System.err.println(MjUtil.parseContent(content1)); System.err.println(MidjourneyUtil.parseContent(content1));
System.err.println(MjUtil.parseContent(content2)); System.err.println(MidjourneyUtil.parseContent(content2));
} }
} }

View File

@ -1,9 +1,8 @@
package cn.iocoder.yudao.framework.ai.mj; package cn.iocoder.yudao.framework.ai.midjourney;
import cn.hutool.core.io.FileUtil; import cn.hutool.core.io.FileUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MjWebSocketStarter;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -17,7 +16,7 @@ import java.util.Scanner;
* author: fansili * author: fansili
* time: 2024/4/3 16:40 * time: 2024/4/3 16:40
*/ */
public class MjWebSocketTests { public class MidjourneyWebSocketTests {
private MidjourneyConfig midjourneyConfig; private MidjourneyConfig midjourneyConfig;
@ -35,8 +34,8 @@ public class MjWebSocketTests {
@Test @Test
public void startSocketTest() { public void startSocketTest() {
String wssUrl = "wss://gateway.discord.gg"; String wssUrl = "wss://gateway.discord.gg";
var messageListener = new MjMessageListener(midjourneyConfig); var messageListener = new MidjourneyMessageListener(midjourneyConfig);
var webSocketStarter = new MjWebSocketStarter(wssUrl, null, midjourneyConfig, messageListener); var webSocketStarter = new MidjourneyWebSocketStarter(wssUrl, null, midjourneyConfig, messageListener);
try { try {
webSocketStarter.start(); webSocketStarter.start();