From a1f738dd81dac8149a0004cc1003c76bb3236297 Mon Sep 17 00:00:00 2001 From: cherishsince Date: Thu, 30 May 2024 16:29:28 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=A2=9E=E5=8A=A0=E3=80=91=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=20midjourney=20=E6=8F=90=E4=BA=A4=E4=BB=BB=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../client/vo/MidjourneySubmitCodeEnum.java | 36 +++++++ .../ai/client/vo/MidjourneySubmitRespVO.java | 4 +- .../admin/image/AiImageController.java | 37 ++++---- ...ava => AiImageMidjourneyImagineReqVO.java} | 17 ++-- .../ai/service/image/AiImageService.java | 5 +- .../ai/service/image/AiImageServiceImpl.java | 95 ++++++++++++------- 6 files changed, 126 insertions(+), 68 deletions(-) create mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitCodeEnum.java rename yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/{AiImageMidjourneyReqVO.java => AiImageMidjourneyImagineReqVO.java} (51%) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitCodeEnum.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitCodeEnum.java new file mode 100644 index 000000000..5bf571929 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitCodeEnum.java @@ -0,0 +1,36 @@ +package cn.iocoder.yudao.module.ai.client.vo; + +import com.google.common.collect.Lists; +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.List; + +/** + * Midjourney 提交任务 code 枚举 + * + * @author fansili + * @time 2024/5/30 14:33 + * @since 1.0 + */ +@Getter +@AllArgsConstructor +public enum MidjourneySubmitCodeEnum { + + // 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) + SUBMIT_SUCCESS("1", "提交成功"), + ALREADY_EXISTS("1", "已存在"), + QUEUING("22", "排队中"), + + ; + + public static final List SUCCESS_CODES = Lists.newArrayList( + SUBMIT_SUCCESS.code, + ALREADY_EXISTS.code, + QUEUING.code + ); + + private String code; + private String name; + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitRespVO.java index c9a430d50..d689b2bd7 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitRespVO.java @@ -3,6 +3,8 @@ package cn.iocoder.yudao.module.ai.client.vo; import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; +import java.util.Map; + /** * Midjourney:Imagine 请求 * @@ -20,7 +22,7 @@ public class MidjourneySubmitRespVO { private String description; @Schema(description = "扩展字段") - private String properties; + private Map properties; @Schema(description = "任务ID") private String result; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java index 4dc66a644..6841b1c9c 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java @@ -3,13 +3,17 @@ package cn.iocoder.yudao.module.ai.controller.admin.image; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageMyRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.service.image.AiImageService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.annotation.Resource; +import jakarta.servlet.http.HttpServletRequest; import lombok.extern.slf4j.Slf4j; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.*; @@ -49,32 +53,17 @@ public class AiImageController { } // TODO @fan:建议把 dallDrawing、midjourney 融合成一个 draw 接口,异步绘制;然后返回一个 id 给前端;前端通过 get 接口轮询,直到获取到生成成功 + // TODO @芋艿: 参数差异较大 @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") @PostMapping("/dall") public CommonResult dall(@Validated @RequestBody AiImageDallReqVO req) { return success(aiImageService.dall(getLoginUserId(), req)); } - @Operation(summary = "midjourney绘画", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") - @PostMapping("/midjourney") - public CommonResult midjourney(@Validated @RequestBody AiImageMidjourneyReqVO req) { - aiImageService.midjourney(req); - return success(null); - } - - @Operation(summary = "midjourney绘画操作", description = "一般有选择图片、放大、换一批...") - @PostMapping("/midjourney-operate") - public CommonResult midjourneyOperate(@Validated @RequestBody AiImageMidjourneyOperateReqVO req) { - aiImageService.midjourneyOperate(req); - return success(null); - } - - // TODO @fan:要不先不要 midjourneyOperate、cancelMidjourney 接口哈 - @Operation(summary = "取消 midjourney 绘画", description = "取消 midjourney 绘画") - @PostMapping("/cancel-midjourney") - public CommonResult cancelMidjourney(@RequestParam("id") Long id) { - // @范 这里实现mj取消逻辑 - return success(null); + @Operation(summary = "midjourney-imagine 绘画", description = "...") + @PostMapping("/midjourney/imagine") + public CommonResult midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) { + return success(aiImageService.midjourneyImagine(getLoginUserId(), req)); } @Operation(summary = "删除【我的】绘画记录") @@ -83,4 +72,10 @@ public class AiImageController { public CommonResult deleteIdMy(@RequestParam("id") Long id) { return success(aiImageService.deleteIdMy(id, getLoginUserId())); } + + @Operation(summary = "删除【我的】绘画记录") + @RequestMapping("/midjourney-notify") + public CommonResult midjourneyNotify(HttpServletRequest request) { + return success(true); + } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageMidjourneyReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageMidjourneyImagineReqVO.java similarity index 51% rename from yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageMidjourneyReqVO.java rename to yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageMidjourneyImagineReqVO.java index e6cb0dcec..07cd10ab6 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageMidjourneyReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageMidjourneyImagineReqVO.java @@ -1,9 +1,12 @@ package cn.iocoder.yudao.module.ai.controller.admin.image.vo; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; import lombok.Data; import lombok.experimental.Accessors; +import java.util.List; + /** * midjourney req * @@ -13,17 +16,15 @@ import lombok.experimental.Accessors; */ @Data @Accessors(chain = true) -public class AiImageMidjourneyReqVO { +public class AiImageMidjourneyImagineReqVO { @Schema(description = "提示词") + @NotNull(message = "提示词不能为空!") private String prompt; - @Schema(description = "绘画比例 1:1、3:4、4:3、9:16、16:9") - private String size; + @Schema(description = "模型(midjourney、niji)") + private String model; - @Schema(description = "风格") - private String style; - - @Schema(description = "参考图") - private String referImage; + @Schema(description = "垫图(参考图)base64数组") + private List base64Array; } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java index 6a342a804..b07ba7dd8 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java @@ -3,8 +3,8 @@ package cn.iocoder.yudao.module.ai.service.image; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperateReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; /** @@ -44,10 +44,11 @@ public interface AiImageService { /** * midjourney 图片生成 * + * @param loginUserId * @param req * @return */ - void midjourney(AiImageMidjourneyReqVO req); + Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req); /** * midjourney 操作(u1、u2、放大、换一批...) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 25467c9af..d6da654f5 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -1,7 +1,7 @@ package cn.iocoder.yudao.module.ai.service.image; -import cn.hutool.core.util.IdUtil; import cn.hutool.http.HttpUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum; import cn.iocoder.yudao.framework.ai.core.exception.AiException; @@ -11,6 +11,10 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.module.ai.AiCommonConstants; import cn.iocoder.yudao.module.ai.ErrorCodeConstants; +import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient; +import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO; +import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitCodeEnum; +import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; @@ -18,21 +22,22 @@ import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum; import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum; import cn.iocoder.yudao.module.infra.api.file.FileApi; import com.google.common.collect.ImmutableMap; -import jakarta.annotation.PostConstruct; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; -import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi; -import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter; import org.springframework.ai.openai.OpenAiImageClient; import org.springframework.ai.openai.OpenAiImageOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; @@ -59,28 +64,11 @@ public class AiImageServiceImpl implements AiImageService { private FileApi fileApi; @Resource private OpenAiImageClient openAiImageClient; - @Resource - private MidjourneyWebSocketStarter midjourneyWebSocketStarter; - @Resource - private MidjourneyInteractionsApi midjourneyInteractionsApi; + @Autowired + private MidjourneyProxyClient midjourneyProxyClient; - // TODO @fan:接 mj proxy - @PostConstruct - public void startMidjourney() { - // todo @fan 暂时注释掉 -// log.info("midjourney web socket starter..."); -// midjourneyWebSocketStarter.start(new WssNotify() { -// @Override -// public void notify(int code, String message) { -// log.info("code: {}, message: {}", code, message); -// if (message.contains("Authentication failed")) { -// // TODO 芋艿,这里看怎么处理,token无效的时候会认证失败! -// // 认证失败 -// log.error("midjourney socket 认证失败,检查token是否失效!"); -// } -// } -// }); - } + @Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}") + private String midjourneyNotifyUrl; @Override public PageResult getImagePageMy(Long loginUserId, AiImageListReqVO req) { @@ -143,18 +131,53 @@ public class AiImageServiceImpl implements AiImageService { @Override @Transactional(rollbackFor = Exception.class) - public void midjourney(AiImageMidjourneyReqVO req) { - // 保存数据库 - String messageId = String.valueOf(IdUtil.getSnowflakeNextId()); - // todo -// AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny", -// null, null, AiImageStatusEnum.SUBMIT, null, -// messageId, null, null); - // 提交 midjourney 任务 - Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt()); - if (!imagine) { - throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL); + public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) { + + // 1、构建 AiImageDO + AiImageDO aiImageDO = new AiImageDO(); + aiImageDO.setId(null); + aiImageDO.setUserId(loginUserId); + aiImageDO.setPrompt(req.getPrompt()); + aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()); + // todo @范 平台需要转换(mj 模型一般分版本) + aiImageDO.setModel(null); + aiImageDO.setWidth(null); + aiImageDO.setHeight(null); + aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); + aiImageDO.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus()); + aiImageDO.setPicUrl(null); + aiImageDO.setOriginalPicUrl(null); + aiImageDO.setDrawRequest(null); + aiImageDO.setDrawResponse(null); + aiImageDO.setErrorMessage(null); + + // 2、保存 image + imageMapper.insert(aiImageDO); + + // 3、调用 MidjourneyProxy 提交任务 + MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class); + imagineReqVO.setNotifyHook(midjourneyNotifyUrl); + imagineReqVO.setState(String.valueOf(aiImageDO.getId())); + MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO); + + // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)) + String updateStatus = null; + String errorMessage = null; + Map drawResponse = new HashMap<>(); + + if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) { + updateStatus = AiImageStatusEnum.FAIL.getStatus(); + errorMessage = submitRespVO.getDescription(); + } else { + drawResponse.put("jobId", submitRespVO.getResult()); } + imageMapper.updateById(new AiImageDO() + .setId(aiImageDO.getId()) + .setStatus(updateStatus) + .setErrorMessage(errorMessage) + .setDrawResponse(drawResponse) + ); + return aiImageDO.getId(); } @Transactional(rollbackFor = Exception.class)