From 38a9c1a7ee0ddd718b25751fd27b1219d4e6bfd5 Mon Sep 17 00:00:00 2001 From: cherishsince Date: Wed, 8 May 2024 15:58:03 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=A2=9E=E5=8A=A0=E3=80=91mj=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E5=A4=84=E7=90=86=E6=88=90=E5=8A=9F=E6=B6=88=E6=81=AF?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=8A=A0component=E6=93=8D=E4=BD=9C=EF=BC=8C?= =?UTF-8?q?=E5=A4=84=E7=90=86error=E4=BF=A1=E6=81=AF=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../module/ai/convert/AiImageConvert.java | 10 ++++ .../YuDaoMidjourneyMessageHandler.java | 59 ++++++++++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java index 089af76b5..27bf11136 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java @@ -1,8 +1,10 @@ package cn.iocoder.yudao.module.ai.convert; +import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListRespVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import org.mapstruct.Mapper; import org.mapstruct.factory.Mappers; @@ -36,4 +38,12 @@ public interface AiImageConvert { * @return */ List convertAiImageListRespVO(List list); + + /** + * 转换 - AiImageMidjourneyOperationsVO + * + * @param component + * @return + */ + AiImageMidjourneyOperationsVO convertAiImageMidjourneyOperationsVO(MidjourneyMessage.Component component); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java index 3870ba0dc..43ca13b6a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java @@ -5,14 +5,21 @@ import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage; import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler; +import cn.iocoder.yudao.framework.common.util.json.JsonUtils; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO; +import cn.iocoder.yudao.module.ai.convert.AiImageConvert; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; -import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum; import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper; +import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum; import com.alibaba.fastjson2.JSON; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + /** * yudao message handler * @@ -45,6 +52,36 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler { if (StrUtil.isBlank(midjourneyMessage.getNonce())) { return; } + // 根据 Embeds 来判断是否异常 + if (CollUtil.isEmpty(midjourneyMessage.getEmbeds())) { + successHandler(midjourneyMessage); + } else { + errorHandler(midjourneyMessage); + } + } + + private void errorHandler(MidjourneyMessage midjourneyMessage) { + // image 编号 + Long aiImageId = Long.valueOf(midjourneyMessage.getNonce()); + // 获取 error message + String errorMessage = getErrorMessage(midjourneyMessage); + aiImageMapper.updateById( + new AiImageDO() + .setId(aiImageId) + .setDrawingErrorMessage(errorMessage) + .setDrawingStatus(AiImageDrawingStatusEnum.FAIL.getStatus()) + ); + } + + private String getErrorMessage(MidjourneyMessage midjourneyMessage) { + StringBuilder errorMessage = new StringBuilder(); + for (MidjourneyMessage.Embed embed : midjourneyMessage.getEmbeds()) { + errorMessage.append(embed.getDescription()); + } + return errorMessage.toString(); + } + + private void successHandler(MidjourneyMessage midjourneyMessage) { // 获取id Long aiImageId = Long.valueOf(midjourneyMessage.getNonce()); // 获取生成 url @@ -59,14 +96,32 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler { drawingStatusEnum = AiImageDrawingStatusEnum.COMPLETE; } else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) { drawingStatusEnum = AiImageDrawingStatusEnum.IN_PROGRESS; - } else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) { + } else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) { drawingStatusEnum = AiImageDrawingStatusEnum.WAITING; } + // 获取 midjourneyOperations + List midjourneyOperations = getMidjourneyOperationsList(midjourneyMessage); + // 更新数据库 aiImageMapper.updateById( new AiImageDO() .setId(aiImageId) .setDrawingImageUrl(imageUrl) .setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus()) + .setMjMessageId(midjourneyMessage.getId()) + .setMjOperations(JsonUtils.toJsonString(midjourneyOperations)) ); } + + private List getMidjourneyOperationsList(MidjourneyMessage midjourneyMessage) { + // 为空直接返回 + if (CollUtil.isEmpty(midjourneyMessage.getComponents())) { + return Collections.emptyList(); + } + // 将 component 转成 AiImageMidjourneyOperationsVO + return midjourneyMessage.getComponents().stream() + .map(componentType -> componentType.getComponents().stream() + .map(AiImageConvert.INSTANCE::convertAiImageMidjourneyOperationsVO) + .collect(Collectors.toList())) + .toList().stream().flatMap(List::stream).toList(); + } }