【修改】AI Image dall 请求返回结构优化

This commit is contained in:
cherishsince 2024-05-27 18:16:44 +08:00
parent 0cdd3423e5
commit 63a8cc244d
4 changed files with 83 additions and 24 deletions

View File

@ -32,8 +32,8 @@ public class AiImageController {
@Operation(summary = "获取image列表", description = "dall3、midjourney") @Operation(summary = "获取image列表", description = "dall3、midjourney")
@GetMapping("/list") @GetMapping("/list")
public PageResult<AiImageListRespVO> list(@Validated @RequestBody AiImageListReqVO req) { public CommonResult<PageResult<AiImageListRespVO>> list(@Validated @ModelAttribute AiImageListReqVO req) {
return aiImageService.list(req); return CommonResult.success(aiImageService.list(req));
} }
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")

View File

@ -1,6 +1,9 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo; package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
@ -15,4 +18,50 @@ import lombok.experimental.Accessors;
@Accessors(chain = true) @Accessors(chain = true)
public class AiImageListRespVO extends PageParam { public class AiImageListRespVO extends PageParam {
private Long id;
@Schema(description = "用户id")
private Long userId;
@Schema(description = "提示词")
private String prompt;
@Schema(description = "模型 dall2/dall3、MJ、NIJI")
private String model;
@Schema(description = "生成图像的尺寸大小。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。")
private String size;
@Schema(description = "风格")
private String style;
@Schema(description = "图片地址(自己服务器)")
private String picUrl;
@Schema(description = "绘画状态:提交、排队、绘画中、绘画完成、绘画失败")
private String status;
@Schema(description = "绘画图片地址(绘画好的服务器)")
private String originalPicUrl;
@Schema(description = "绘画错误信息")
private String errorMessage;
@Schema(description = "是否发布")
private String publicStatus;
// ============ mj 需要字段
@Schema(description = "用户操作的Nonce编号(MJ返回)")
private String mjNonceId;
@Schema(description = "用户操作的操作编号(MJ返回)")
private String mjOperationId;
@Schema(description = "用户操作的操作名字(MJ返回)")
private String mjOperationName;
@Schema(description = "mj图片生产成功保存的 components json 数组")
private String mjOperations;
} }

View File

@ -23,6 +23,14 @@ public interface AiImageConvert {
AiImageConvert INSTANCE = Mappers.getMapper(AiImageConvert.class); AiImageConvert INSTANCE = Mappers.getMapper(AiImageConvert.class);
/**
* 转换 - AiImageDallDrawingRespVO
*
* @param req
* @return
*/
AiImageDallRespVO convertAiImageDallDrawingRespVO(AiImageDO req);
/** /**
* 转换 - AiImageDallDrawingRespVO * 转换 - AiImageDallDrawingRespVO
* *

View File

@ -2,16 +2,10 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.util.IdUtil; import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.exception.AiException; import cn.hutool.http.HttpUtil;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi; import cn.iocoder.yudao.framework.ai.core.exception.AiException;
import org.springframework.ai.models.midjourney.api.req.ReRollReq;
import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
import org.springframework.ai.models.midjourney.webSocket.WssNotify;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils; import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
@ -23,9 +17,17 @@ import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.PostConstruct; import jakarta.annotation.PostConstruct;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
import org.springframework.ai.models.midjourney.api.req.ReRollReq;
import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
import org.springframework.ai.models.midjourney.webSocket.WssNotify;
import org.springframework.ai.openai.OpenAiImageClient; import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -47,6 +49,7 @@ import java.util.List;
public class AiImageServiceImpl implements AiImageService { public class AiImageServiceImpl implements AiImageService {
private final AiImageMapper aiImageMapper; private final AiImageMapper aiImageMapper;
private final FileApi fileApi;
private final OpenAiImageClient openAiImageClient; private final OpenAiImageClient openAiImageClient;
private final MidjourneyWebSocketStarter midjourneyWebSocketStarter; private final MidjourneyWebSocketStarter midjourneyWebSocketStarter;
private final MidjourneyInteractionsApi midjourneyInteractionsApi; private final MidjourneyInteractionsApi midjourneyInteractionsApi;
@ -89,8 +92,6 @@ public class AiImageServiceImpl implements AiImageService {
// 获取 model // 获取 model
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel()); OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle()); OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
// 转换 AiImageDallDrawingRespVO
AiImageDallRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req);
try { try {
// 转换openai 参数 // 转换openai 参数
OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions(); OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
@ -100,22 +101,21 @@ public class AiImageServiceImpl implements AiImageService {
ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions)); ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
// 发送 // 发送
ImageGeneration imageGeneration = imageResponse.getResult(); ImageGeneration imageGeneration = imageResponse.getResult();
// 图片保存到服务器
String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl()));
// 保存数据库 // 保存数据库
doSave(req.getPrompt(), req.getSize(), req.getModel(), AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null, filePath, imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null,
null, null, null); null, null, null);
// 返回 flex // 转换 AiImageDallDrawingRespVO
respVO.setOriginalPicUrl(imageGeneration.getOutput().getUrl()); return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
respVO.setBase64(imageGeneration.getOutput().getB64Json());
return respVO;
} catch (AiException aiException) { } catch (AiException aiException) {
// 保存数据库 // 保存数据库
doSave(req.getPrompt(), req.getSize(), req.getModel(), AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
null, AiImageStatusEnum.FAIL, aiException.getMessage(), null, null, AiImageStatusEnum.FAIL, aiException.getMessage(),
null, null, null); null, null, null);
// 发送错误信息 // 发送错误信息
respVO.setErrorMessage(aiException.getMessage()); return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
return respVO;
} }
} }
@ -125,7 +125,7 @@ public class AiImageServiceImpl implements AiImageService {
// 保存数据库 // 保存数据库
String messageId = String.valueOf(IdUtil.getSnowflakeNextId()); String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny", AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
null, AiImageStatusEnum.SUBMIT, null, null, null, AiImageStatusEnum.SUBMIT, null,
messageId, null, null); messageId, null, null);
// 提交 midjourney 任务 // 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt()); Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
@ -149,7 +149,7 @@ public class AiImageServiceImpl implements AiImageService {
String mjOperationName = midjourneyOperationsVO.getLabel(); String mjOperationName = midjourneyOperationsVO.getLabel();
// 保存一个 image 任务记录 // 保存一个 image 任务记录
doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(), doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(),
null, AiImageStatusEnum.SUBMIT, null, null, null, AiImageStatusEnum.SUBMIT, null,
req.getMessageId(), req.getOperateId(), mjOperationName); req.getMessageId(), req.getOperateId(), mjOperationName);
// 提交操作 // 提交操作
midjourneyInteractionsApi.reRoll( midjourneyInteractionsApi.reRoll(
@ -201,6 +201,7 @@ public class AiImageServiceImpl implements AiImageService {
private AiImageDO doSave(String prompt, private AiImageDO doSave(String prompt,
String size, String size,
String model, String model,
String picUrl,
String originalPicUrl, String originalPicUrl,
AiImageStatusEnum statusEnum, AiImageStatusEnum statusEnum,
String errorMessage, String errorMessage,
@ -218,6 +219,7 @@ public class AiImageServiceImpl implements AiImageService {
// TODO @芋艿 如何上传到自己服务器 // TODO @芋艿 如何上传到自己服务器
aiImageDO.setPicUrl(null); aiImageDO.setPicUrl(null);
aiImageDO.setStatus(statusEnum.getStatus()); aiImageDO.setStatus(statusEnum.getStatus());
aiImageDO.setPicUrl(picUrl);
aiImageDO.setOriginalPicUrl(originalPicUrl); aiImageDO.setOriginalPicUrl(originalPicUrl);
aiImageDO.setErrorMessage(errorMessage); aiImageDO.setErrorMessage(errorMessage);
// //