【增加】对接 Midjourney,增加nonce传递,更新Midjourney image 状态

This commit is contained in:
cherishsince 2024-04-29 22:10:12 +08:00
parent ae934e84e8
commit 03b4460eae
8 changed files with 140 additions and 15 deletions

View File

@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.service.AiImageService; import cn.iocoder.yudao.module.ai.service.AiImageService;
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq; import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@ -42,7 +41,8 @@ public class AiImageController {
@Operation(summary = "midjourney", description = "midjourney图片绘画流程1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") @Operation(summary = "midjourney", description = "midjourney图片绘画流程1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
@PostMapping("/midjourney") @PostMapping("/midjourney")
public CommonResult<AiImageMidjourneyRes> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) { public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
return CommonResult.success(aiImageService.midjourney(req)); aiImageService.midjourney(req);
return CommonResult.success(null);
} }
} }

View File

@ -28,5 +28,5 @@ public interface AiImageService {
* @param req * @param req
* @return * @return
*/ */
AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req); void midjourney(AiImageMidjourneyReq req);
} }

View File

@ -95,18 +95,15 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) { public void midjourney(AiImageMidjourneyReq req) {
// 保存数据库 // 保存数据库
doSave(req.getPrompt(), null, "midjoureny", AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
null, AiChatDrawingStatusEnum.SUBMIT, null); null, AiChatDrawingStatusEnum.SUBMIT, null);
// 提交 midjourney 任务 // 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt()); Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt());
if (!imagine) { if (!imagine) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL); throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
} }
//
return null;
} }
private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
@ -120,7 +117,7 @@ public class AiImageServiceImpl implements AiImageService {
} }
} }
private void doSave(String prompt, private AiImageDO doSave(String prompt,
String size, String size,
String model, String model,
String imageUrl, String imageUrl,
@ -138,5 +135,6 @@ public class AiImageServiceImpl implements AiImageService {
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus()); aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
aiImageDO.setDrawingError(drawingError); aiImageDO.setDrawingError(drawingError);
aiImageMapper.insert(aiImageDO); aiImageMapper.insert(aiImageDO);
return aiImageDO;
} }
} }

View File

@ -1,7 +1,15 @@
package cn.iocoder.yudao.module.ai.service.midjourneyHandler; package cn.iocoder.yudao.module.ai.service.midjourneyHandler;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
import com.alibaba.fastjson2.JSON;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -14,10 +22,51 @@ import org.springframework.stereotype.Component;
*/ */
@Component @Component
@Slf4j @Slf4j
@AllArgsConstructor
public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler { public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
private final AiImageMapper aiImageMapper;
@Override @Override
public void messageHandler(MidjourneyMessage midjourneyMessage) { public void messageHandler(MidjourneyMessage midjourneyMessage) {
log.info("yudao-midjourney-midjourney-message-handler", midjourneyMessage); log.info("yudao-midjourney-midjourney-message-handler {}", JSON.toJSONString(midjourneyMessage));
if (midjourneyMessage.getContent() != null) {
log.info("进度id {} 状态 {} 进度 {}",
midjourneyMessage.getNonce(),
midjourneyMessage.getGenerateStatus(),
midjourneyMessage.getContent().getProgress());
}
//
updateImage(midjourneyMessage);
}
private void updateImage(MidjourneyMessage midjourneyMessage) {
// Nonce 不存在不更新
if (StrUtil.isBlank(midjourneyMessage.getNonce())) {
return;
}
// 获取id
Long aiImageId = Long.valueOf(midjourneyMessage.getNonce());
// 获取生成 url
String imageUrl = null;
if (CollUtil.isNotEmpty(midjourneyMessage.getAttachments())) {
imageUrl = midjourneyMessage.getAttachments().get(0).getUrl();
}
// 转换状态
AiChatDrawingStatusEnum drawingStatusEnum = null;
String generateStatus = midjourneyMessage.getGenerateStatus();
if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.COMPLETE;
} else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.IN_PROGRESS;
} else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.WAITING;
}
aiImageMapper.updateById(
new AiImageDO()
.setId(aiImageId)
.setDrawingImageUrl(imageUrl)
.setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
);
} }
} }

View File

@ -14,6 +14,10 @@ public class MidjourneyMessage {
* id是一个重要的字段在同时生成多个的时候可以区分生成信息 * id是一个重要的字段在同时生成多个的时候可以区分生成信息
*/ */
private String id; private String id;
/**
* 提交id(nonce 可能会不存在系统提示的时候这个为空)
*/
private String nonce;
/** /**
* 现在已知 * 现在已知
* 0我们发送的消息和指令 * 0我们发送的消息和指令
@ -45,6 +49,14 @@ public class MidjourneyMessage {
* {@link MidjourneyGennerateStatusEnum} * {@link MidjourneyGennerateStatusEnum}
*/ */
private String generateStatus; private String generateStatus;
/**
* 一般用于提示信息
* - 错误
* - 并发队列满了
* - 账号违规了敏感词
* - 账号被封
*/
private List<Embed> embeds;
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
@ -123,4 +135,39 @@ public class MidjourneyMessage {
private String progress; private String progress;
private String status; private String status;
} }
/**
* embed 用于警告提示错误
*/
@Data
@Accessors(chain = true)
public static class Embed {
// 内容扫描版本号
private int contentScanVersion;
// 颜色值这里用Java的Color类来表示注意实际使用中可能需要自定义方法来从int转换为Color对象
private String color;
// 页脚信息包含文本
private Footer footer;
// 描述信息
private String description;
// 消息类型这里是富文本类型(这个区分不同提示类型)
private String type;
// 标题
private String title;
// Footer类作为嵌套类存在用来表示footer部分的JSON对象
@Data
@Accessors(chain = true)
public static class Footer {
// 页脚文本
private String text;
}
}
} }

View File

@ -38,11 +38,13 @@ public class MidjourneyInteractionsApi extends MidjourneyInteractions {
this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions()); this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
} }
public Boolean imagine(String prompt) { public Boolean imagine(Long id, String prompt) {
String nonce = String.valueOf(id);
// 获取请求模板 // 获取请求模板
String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine"); String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
// 设置参数 // 设置参数
HashMap<String, String> requestParams = getDefaultParams(); HashMap<String, String> requestParams = getDefaultParams();
requestParams.put("nonce", nonce);
requestParams.put("prompt", prompt); requestParams.put("prompt", prompt);
// 解析 template 参数占位符 // 解析 template 参数占位符
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);

View File

@ -6,6 +6,10 @@ public final class MidjourneyConstants {
* 消息 - 编号 * 消息 - 编号
*/ */
public static final String MSG_ID = "id"; public static final String MSG_ID = "id";
/**
* 用于区分操作唯一性
*/
public static final String MSG_NONCE = "nonce";
/** /**
* 消息 - 类型 * 消息 - 类型
* 现在已知 * 现在已知
@ -32,6 +36,10 @@ public final class MidjourneyConstants {
* 附件(生成中比较模糊的图片) * 附件(生成中比较模糊的图片)
*/ */
public static final String MSG_ATTACHMENTS = "attachments"; public static final String MSG_ATTACHMENTS = "attachments";
/**
* 一般用于提示
*/
public static final String MSG_EMBEDS = "embeds";
// //

View File

@ -42,9 +42,11 @@ public class MidjourneyMessageListener {
if (ignoreAndLogMessage(data, messageType)) { if (ignoreAndLogMessage(data, messageType)) {
return; return;
} }
log.info("socket message: {}", raw);
// 转换几个重要的信息 // 转换几个重要的信息
MidjourneyMessage mjMessage = new MidjourneyMessage(); MidjourneyMessage mjMessage = new MidjourneyMessage();
mjMessage.setId(data.getString(MidjourneyConstants.MSG_ID)); mjMessage.setId(getString(data, MidjourneyConstants.MSG_ID, ""));
mjMessage.setNonce(getString(data, MidjourneyConstants.MSG_NONCE, ""));
mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE)); mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE));
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8")); mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT))); mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
@ -60,6 +62,12 @@ public class MidjourneyMessageListener {
List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class); List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class);
mjMessage.setAttachments(attachments); mjMessage.setAttachments(attachments);
} }
// 转换 embeds 提示信息
if (!data.getArray(MidjourneyConstants.MSG_EMBEDS).isEmpty()) {
String embedJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_EMBEDS).toJson(), "UTF-8");
List<MidjourneyMessage.Embed> embeds = JsonUtils.parseArray(embedJson, MidjourneyMessage.Embed.class);
mjMessage.setEmbeds(embeds);
}
// 转换状态 // 转换状态
convertGenerateStatus(mjMessage); convertGenerateStatus(mjMessage);
// message handler 调用 // message handler 调用
@ -68,7 +76,20 @@ public class MidjourneyMessageListener {
} }
} }
private String getString(DataObject data, String key, String defaultValue) {
if (!data.hasKey(key)) {
return defaultValue;
}
return data.getString(key);
}
private void convertGenerateStatus(MidjourneyMessage mjMessage) { private void convertGenerateStatus(MidjourneyMessage mjMessage) {
//
// tip提示警告异常 content是没有内容的
// tip: 一般错误信息在 Embeds 只要 Embeds有值content就没信息
if (CollUtil.isNotEmpty(mjMessage.getEmbeds())) {
return;
}
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) { if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus()); mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus());
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) { } else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {