From 63a8cc244d054040ad227a9b5651e0382a65d7e9 Mon Sep 17 00:00:00 2001 From: cherishsince Date: Mon, 27 May 2024 18:16:44 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BF=AE=E6=94=B9=E3=80=91AI=20Image?= =?UTF-8?q?=20dall=20=E8=AF=B7=E6=B1=82=E8=BF=94=E5=9B=9E=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../admin/image/AiImageController.java | 4 +- .../admin/image/vo/AiImageListRespVO.java | 49 +++++++++++++++++++ .../module/ai/convert/AiImageConvert.java | 8 +++ .../ai/service/image/AiImageServiceImpl.java | 46 ++++++++--------- 4 files changed, 83 insertions(+), 24 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java index 2dbac1a9b..d072822eb 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java @@ -32,8 +32,8 @@ public class AiImageController { @Operation(summary = "获取image列表", description = "dall3、midjourney") @GetMapping("/list") - public PageResult list(@Validated @RequestBody AiImageListReqVO req) { - return aiImageService.list(req); + public CommonResult> list(@Validated @ModelAttribute AiImageListReqVO req) { + return CommonResult.success(aiImageService.list(req)); } @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageListRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageListRespVO.java index 9222a4ec4..1ddec8d7b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageListRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageListRespVO.java @@ -1,6 +1,9 @@ package cn.iocoder.yudao.module.ai.controller.admin.image.vo; import cn.iocoder.yudao.framework.common.pojo.PageParam; +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; import lombok.experimental.Accessors; @@ -15,4 +18,50 @@ import lombok.experimental.Accessors; @Accessors(chain = true) public class AiImageListRespVO extends PageParam { + private Long id; + + @Schema(description = "用户id") + private Long userId; + + @Schema(description = "提示词") + private String prompt; + + @Schema(description = "模型 dall2/dall3、MJ、NIJI") + private String model; + + @Schema(description = "生成图像的尺寸大小。对于dall-e-2模型,尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型,尺寸可为1024x1024, 1792x1024, 或 1024x1792。") + private String size; + + @Schema(description = "风格") + private String style; + + @Schema(description = "图片地址(自己服务器)") + private String picUrl; + + @Schema(description = "绘画状态:提交、排队、绘画中、绘画完成、绘画失败") + private String status; + + @Schema(description = "绘画图片地址(绘画好的服务器)") + private String originalPicUrl; + + @Schema(description = "绘画错误信息") + private String errorMessage; + + @Schema(description = "是否发布") + private String publicStatus; + + // ============ mj 需要字段 + + @Schema(description = "用户操作的Nonce编号(MJ返回)") + private String mjNonceId; + + @Schema(description = "用户操作的操作编号(MJ返回)") + private String mjOperationId; + + @Schema(description = "用户操作的操作名字(MJ返回)") + private String mjOperationName; + + @Schema(description = "mj图片生产成功保存的 components json 数组") + private String mjOperations; + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java index f5e6699d7..df8647a5b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java @@ -23,6 +23,14 @@ public interface AiImageConvert { AiImageConvert INSTANCE = Mappers.getMapper(AiImageConvert.class); + /** + * 转换 - AiImageDallDrawingRespVO + * + * @param req + * @return + */ + AiImageDallRespVO convertAiImageDallDrawingRespVO(AiImageDO req); + /** * 转换 - AiImageDallDrawingRespVO * diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 01476a4a3..18659377c 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -2,16 +2,10 @@ package cn.iocoder.yudao.module.ai.service.image; import cn.hutool.core.util.IdUtil; import cn.hutool.core.util.StrUtil; -import cn.iocoder.yudao.framework.ai.core.exception.AiException; -import org.springframework.ai.image.ImageGeneration; -import org.springframework.ai.image.ImagePrompt; -import org.springframework.ai.image.ImageResponse; +import cn.hutool.http.HttpUtil; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum; -import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi; -import org.springframework.ai.models.midjourney.api.req.ReRollReq; -import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter; -import org.springframework.ai.models.midjourney.webSocket.WssNotify; +import cn.iocoder.yudao.framework.ai.core.exception.AiException; import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.json.JsonUtils; @@ -23,9 +17,17 @@ import cn.iocoder.yudao.module.ai.convert.AiImageConvert; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum; +import cn.iocoder.yudao.module.infra.api.file.FileApi; import jakarta.annotation.PostConstruct; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi; +import org.springframework.ai.models.midjourney.api.req.ReRollReq; +import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter; +import org.springframework.ai.models.midjourney.webSocket.WssNotify; import org.springframework.ai.openai.OpenAiImageClient; import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.stereotype.Service; @@ -47,6 +49,7 @@ import java.util.List; public class AiImageServiceImpl implements AiImageService { private final AiImageMapper aiImageMapper; + private final FileApi fileApi; private final OpenAiImageClient openAiImageClient; private final MidjourneyWebSocketStarter midjourneyWebSocketStarter; private final MidjourneyInteractionsApi midjourneyInteractionsApi; @@ -89,8 +92,6 @@ public class AiImageServiceImpl implements AiImageService { // 获取 model OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel()); OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle()); - // 转换 AiImageDallDrawingRespVO - AiImageDallRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req); try { // 转换openai 参数 OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions(); @@ -100,22 +101,21 @@ public class AiImageServiceImpl implements AiImageService { ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions)); // 发送 ImageGeneration imageGeneration = imageResponse.getResult(); + // 图片保存到服务器 + String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl())); // 保存数据库 - doSave(req.getPrompt(), req.getSize(), req.getModel(), - imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null, + AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(), + filePath, imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null, null, null, null); - // 返回 flex - respVO.setOriginalPicUrl(imageGeneration.getOutput().getUrl()); - respVO.setBase64(imageGeneration.getOutput().getB64Json()); - return respVO; + // 转换 AiImageDallDrawingRespVO + return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO); } catch (AiException aiException) { // 保存数据库 - doSave(req.getPrompt(), req.getSize(), req.getModel(), - null, AiImageStatusEnum.FAIL, aiException.getMessage(), + AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(), + null, null, AiImageStatusEnum.FAIL, aiException.getMessage(), null, null, null); // 发送错误信息 - respVO.setErrorMessage(aiException.getMessage()); - return respVO; + return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO); } } @@ -125,7 +125,7 @@ public class AiImageServiceImpl implements AiImageService { // 保存数据库 String messageId = String.valueOf(IdUtil.getSnowflakeNextId()); AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny", - null, AiImageStatusEnum.SUBMIT, null, + null, null, AiImageStatusEnum.SUBMIT, null, messageId, null, null); // 提交 midjourney 任务 Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt()); @@ -149,7 +149,7 @@ public class AiImageServiceImpl implements AiImageService { String mjOperationName = midjourneyOperationsVO.getLabel(); // 保存一个 image 任务记录 doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(), - null, AiImageStatusEnum.SUBMIT, null, + null, null, AiImageStatusEnum.SUBMIT, null, req.getMessageId(), req.getOperateId(), mjOperationName); // 提交操作 midjourneyInteractionsApi.reRoll( @@ -201,6 +201,7 @@ public class AiImageServiceImpl implements AiImageService { private AiImageDO doSave(String prompt, String size, String model, + String picUrl, String originalPicUrl, AiImageStatusEnum statusEnum, String errorMessage, @@ -218,6 +219,7 @@ public class AiImageServiceImpl implements AiImageService { // TODO @芋艿 如何上传到自己服务器 aiImageDO.setPicUrl(null); aiImageDO.setStatus(statusEnum.getStatus()); + aiImageDO.setPicUrl(picUrl); aiImageDO.setOriginalPicUrl(originalPicUrl); aiImageDO.setErrorMessage(errorMessage); //