mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-30 11:11:55 +08:00
【增加】对接 Midjourney,增加nonce传递,更新Midjourney image 状态
This commit is contained in:
parent
ae934e84e8
commit
03b4460eae
@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
|||||||
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
||||||
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
|
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
|
||||||
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
|
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
|
||||||
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
|
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
@ -42,7 +41,8 @@ public class AiImageController {
|
|||||||
|
|
||||||
@Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
|
@Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
|
||||||
@PostMapping("/midjourney")
|
@PostMapping("/midjourney")
|
||||||
public CommonResult<AiImageMidjourneyRes> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
|
public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
|
||||||
return CommonResult.success(aiImageService.midjourney(req));
|
aiImageService.midjourney(req);
|
||||||
|
return CommonResult.success(null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -28,5 +28,5 @@ public interface AiImageService {
|
|||||||
* @param req
|
* @param req
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req);
|
void midjourney(AiImageMidjourneyReq req);
|
||||||
}
|
}
|
||||||
|
@ -95,18 +95,15 @@ public class AiImageServiceImpl implements AiImageService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) {
|
public void midjourney(AiImageMidjourneyReq req) {
|
||||||
// 保存数据库
|
// 保存数据库
|
||||||
doSave(req.getPrompt(), null, "midjoureny",
|
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
|
||||||
null, AiChatDrawingStatusEnum.SUBMIT, null);
|
null, AiChatDrawingStatusEnum.SUBMIT, null);
|
||||||
// 提交 midjourney 任务
|
// 提交 midjourney 任务
|
||||||
Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt());
|
Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt());
|
||||||
if (!imagine) {
|
if (!imagine) {
|
||||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
|
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
|
||||||
}
|
}
|
||||||
//
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
|
private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
|
||||||
@ -120,7 +117,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void doSave(String prompt,
|
private AiImageDO doSave(String prompt,
|
||||||
String size,
|
String size,
|
||||||
String model,
|
String model,
|
||||||
String imageUrl,
|
String imageUrl,
|
||||||
@ -138,5 +135,6 @@ public class AiImageServiceImpl implements AiImageService {
|
|||||||
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
|
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
|
||||||
aiImageDO.setDrawingError(drawingError);
|
aiImageDO.setDrawingError(drawingError);
|
||||||
aiImageMapper.insert(aiImageDO);
|
aiImageMapper.insert(aiImageDO);
|
||||||
|
return aiImageDO;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,15 @@
|
|||||||
package cn.iocoder.yudao.module.ai.service.midjourneyHandler;
|
package cn.iocoder.yudao.module.ai.service.midjourneyHandler;
|
||||||
|
|
||||||
|
import cn.hutool.core.collection.CollUtil;
|
||||||
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
|
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
|
||||||
|
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
|
||||||
|
import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
|
||||||
|
import com.alibaba.fastjson2.JSON;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
@ -14,10 +22,51 @@ import org.springframework.stereotype.Component;
|
|||||||
*/
|
*/
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@AllArgsConstructor
|
||||||
public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
|
public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
|
||||||
|
|
||||||
|
private final AiImageMapper aiImageMapper;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void messageHandler(MidjourneyMessage midjourneyMessage) {
|
public void messageHandler(MidjourneyMessage midjourneyMessage) {
|
||||||
log.info("yudao-midjourney-midjourney-message-handler", midjourneyMessage);
|
log.info("yudao-midjourney-midjourney-message-handler {}", JSON.toJSONString(midjourneyMessage));
|
||||||
|
if (midjourneyMessage.getContent() != null) {
|
||||||
|
log.info("进度id {} 状态 {} 进度 {}",
|
||||||
|
midjourneyMessage.getNonce(),
|
||||||
|
midjourneyMessage.getGenerateStatus(),
|
||||||
|
midjourneyMessage.getContent().getProgress());
|
||||||
|
}
|
||||||
|
//
|
||||||
|
updateImage(midjourneyMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateImage(MidjourneyMessage midjourneyMessage) {
|
||||||
|
// Nonce 不存在不更新
|
||||||
|
if (StrUtil.isBlank(midjourneyMessage.getNonce())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// 获取id
|
||||||
|
Long aiImageId = Long.valueOf(midjourneyMessage.getNonce());
|
||||||
|
// 获取生成 url
|
||||||
|
String imageUrl = null;
|
||||||
|
if (CollUtil.isNotEmpty(midjourneyMessage.getAttachments())) {
|
||||||
|
imageUrl = midjourneyMessage.getAttachments().get(0).getUrl();
|
||||||
|
}
|
||||||
|
// 转换状态
|
||||||
|
AiChatDrawingStatusEnum drawingStatusEnum = null;
|
||||||
|
String generateStatus = midjourneyMessage.getGenerateStatus();
|
||||||
|
if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) {
|
||||||
|
drawingStatusEnum = AiChatDrawingStatusEnum.COMPLETE;
|
||||||
|
} else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
|
||||||
|
drawingStatusEnum = AiChatDrawingStatusEnum.IN_PROGRESS;
|
||||||
|
} else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
|
||||||
|
drawingStatusEnum = AiChatDrawingStatusEnum.WAITING;
|
||||||
|
}
|
||||||
|
aiImageMapper.updateById(
|
||||||
|
new AiImageDO()
|
||||||
|
.setId(aiImageId)
|
||||||
|
.setDrawingImageUrl(imageUrl)
|
||||||
|
.setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,10 @@ public class MidjourneyMessage {
|
|||||||
* id是一个重要的字段,在同时生成多个的时候,可以区分生成信息
|
* id是一个重要的字段,在同时生成多个的时候,可以区分生成信息
|
||||||
*/
|
*/
|
||||||
private String id;
|
private String id;
|
||||||
|
/**
|
||||||
|
* 提交id(nonce 可能会不存在,系统提示的时候,这个为空)
|
||||||
|
*/
|
||||||
|
private String nonce;
|
||||||
/**
|
/**
|
||||||
* 现在已知:
|
* 现在已知:
|
||||||
* 0:我们发送的消息,和指令
|
* 0:我们发送的消息,和指令
|
||||||
@ -45,6 +49,14 @@ public class MidjourneyMessage {
|
|||||||
* {@link MidjourneyGennerateStatusEnum}
|
* {@link MidjourneyGennerateStatusEnum}
|
||||||
*/
|
*/
|
||||||
private String generateStatus;
|
private String generateStatus;
|
||||||
|
/**
|
||||||
|
* 一般用于提示信息
|
||||||
|
* - 错误
|
||||||
|
* - 并发队列满了
|
||||||
|
* - 账号违规了、敏感词
|
||||||
|
* - 账号被封
|
||||||
|
*/
|
||||||
|
private List<Embed> embeds;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Accessors(chain = true)
|
@Accessors(chain = true)
|
||||||
@ -123,4 +135,39 @@ public class MidjourneyMessage {
|
|||||||
private String progress;
|
private String progress;
|
||||||
private String status;
|
private String status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* embed 用于警告、提示、错误
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Accessors(chain = true)
|
||||||
|
public static class Embed {
|
||||||
|
|
||||||
|
// 内容扫描版本号
|
||||||
|
private int contentScanVersion;
|
||||||
|
|
||||||
|
// 颜色值,这里用Java的Color类来表示,注意实际使用中可能需要自定义方法来从int转换为Color对象
|
||||||
|
private String color;
|
||||||
|
|
||||||
|
// 页脚信息,包含文本
|
||||||
|
private Footer footer;
|
||||||
|
|
||||||
|
// 描述信息
|
||||||
|
private String description;
|
||||||
|
|
||||||
|
// 消息类型,这里是富文本类型(这个区分不同提示类型)
|
||||||
|
private String type;
|
||||||
|
|
||||||
|
// 标题
|
||||||
|
private String title;
|
||||||
|
|
||||||
|
// Footer类,作为嵌套类存在,用来表示footer部分的JSON对象
|
||||||
|
@Data
|
||||||
|
@Accessors(chain = true)
|
||||||
|
public static class Footer {
|
||||||
|
// 页脚文本
|
||||||
|
private String text;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,11 +38,13 @@ public class MidjourneyInteractionsApi extends MidjourneyInteractions {
|
|||||||
this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
|
this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Boolean imagine(String prompt) {
|
public Boolean imagine(Long id, String prompt) {
|
||||||
|
String nonce = String.valueOf(id);
|
||||||
// 获取请求模板
|
// 获取请求模板
|
||||||
String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
|
String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
|
||||||
// 设置参数
|
// 设置参数
|
||||||
HashMap<String, String> requestParams = getDefaultParams();
|
HashMap<String, String> requestParams = getDefaultParams();
|
||||||
|
requestParams.put("nonce", nonce);
|
||||||
requestParams.put("prompt", prompt);
|
requestParams.put("prompt", prompt);
|
||||||
// 解析 template 参数占位符
|
// 解析 template 参数占位符
|
||||||
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
|
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
|
||||||
|
@ -6,6 +6,10 @@ public final class MidjourneyConstants {
|
|||||||
* 消息 - 编号
|
* 消息 - 编号
|
||||||
*/
|
*/
|
||||||
public static final String MSG_ID = "id";
|
public static final String MSG_ID = "id";
|
||||||
|
/**
|
||||||
|
* 用于区分操作唯一性
|
||||||
|
*/
|
||||||
|
public static final String MSG_NONCE = "nonce";
|
||||||
/**
|
/**
|
||||||
* 消息 - 类型
|
* 消息 - 类型
|
||||||
* 现在已知:
|
* 现在已知:
|
||||||
@ -32,6 +36,10 @@ public final class MidjourneyConstants {
|
|||||||
* 附件(生成中比较模糊的图片)
|
* 附件(生成中比较模糊的图片)
|
||||||
*/
|
*/
|
||||||
public static final String MSG_ATTACHMENTS = "attachments";
|
public static final String MSG_ATTACHMENTS = "attachments";
|
||||||
|
/**
|
||||||
|
* 一般用于提示
|
||||||
|
*/
|
||||||
|
public static final String MSG_EMBEDS = "embeds";
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -42,9 +42,11 @@ public class MidjourneyMessageListener {
|
|||||||
if (ignoreAndLogMessage(data, messageType)) {
|
if (ignoreAndLogMessage(data, messageType)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
log.info("socket message: {}", raw);
|
||||||
// 转换几个重要的信息
|
// 转换几个重要的信息
|
||||||
MidjourneyMessage mjMessage = new MidjourneyMessage();
|
MidjourneyMessage mjMessage = new MidjourneyMessage();
|
||||||
mjMessage.setId(data.getString(MidjourneyConstants.MSG_ID));
|
mjMessage.setId(getString(data, MidjourneyConstants.MSG_ID, ""));
|
||||||
|
mjMessage.setNonce(getString(data, MidjourneyConstants.MSG_NONCE, ""));
|
||||||
mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE));
|
mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE));
|
||||||
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
|
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
|
||||||
mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
|
mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
|
||||||
@ -60,6 +62,12 @@ public class MidjourneyMessageListener {
|
|||||||
List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class);
|
List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class);
|
||||||
mjMessage.setAttachments(attachments);
|
mjMessage.setAttachments(attachments);
|
||||||
}
|
}
|
||||||
|
// 转换 embeds 提示信息
|
||||||
|
if (!data.getArray(MidjourneyConstants.MSG_EMBEDS).isEmpty()) {
|
||||||
|
String embedJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_EMBEDS).toJson(), "UTF-8");
|
||||||
|
List<MidjourneyMessage.Embed> embeds = JsonUtils.parseArray(embedJson, MidjourneyMessage.Embed.class);
|
||||||
|
mjMessage.setEmbeds(embeds);
|
||||||
|
}
|
||||||
// 转换状态
|
// 转换状态
|
||||||
convertGenerateStatus(mjMessage);
|
convertGenerateStatus(mjMessage);
|
||||||
// message handler 调用
|
// message handler 调用
|
||||||
@ -68,7 +76,20 @@ public class MidjourneyMessageListener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private String getString(DataObject data, String key, String defaultValue) {
|
||||||
|
if (!data.hasKey(key)) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
return data.getString(key);
|
||||||
|
}
|
||||||
|
|
||||||
private void convertGenerateStatus(MidjourneyMessage mjMessage) {
|
private void convertGenerateStatus(MidjourneyMessage mjMessage) {
|
||||||
|
//
|
||||||
|
// tip:提示、警告、异常 content是没有内容的
|
||||||
|
// tip: 一般错误信息在 Embeds 只要 Embeds有值,content就没信息。
|
||||||
|
if (CollUtil.isNotEmpty(mjMessage.getEmbeds())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
|
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
|
||||||
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus());
|
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus());
|
||||||
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
|
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
|
||||||
|
Loading…
Reference in New Issue
Block a user