【优化】AI Image 调整字段

This commit is contained in:
cherishsince 2024-05-27 15:26:44 +08:00
parent 3afb3089b5
commit 0cdd3423e5
9 changed files with 58 additions and 53 deletions

View File

@ -12,7 +12,7 @@ import lombok.Getter;
*/
@AllArgsConstructor
@Getter
public enum AiImageDrawingStatusEnum {
public enum AiImageStatusEnum {
SUBMIT("submit", "提交任务"),
WAITING("waiting", "等待"),
@ -27,8 +27,8 @@ public enum AiImageDrawingStatusEnum {
private String name;
public static AiImageDrawingStatusEnum valueOfStatus(String status) {
for (AiImageDrawingStatusEnum itemEnum : AiImageDrawingStatusEnum.values()) {
public static AiImageStatusEnum valueOfStatus(String status) {
for (AiImageStatusEnum itemEnum : AiImageStatusEnum.values()) {
if (itemEnum.getStatus().equals(status)) {
return itemEnum;
}

View File

@ -37,8 +37,8 @@ public class AiImageController {
}
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
@PostMapping("/dallDrawing")
public AiImageDallDrawingRespVO dallDrawing(@Validated @RequestBody AiImageDallDrawingReqVO req) {
@PostMapping("/dall")
public AiImageDallRespVO dallDrawing(@Validated @RequestBody AiImageDallReqVO req) {
return aiImageService.dallDrawing(req);
}

View File

@ -15,19 +15,18 @@ import lombok.experimental.Accessors;
*/
@Data
@Accessors(chain = true)
public class AiImageDallDrawingReqVO {
public class AiImageDallReqVO {
@Schema(description = "提示词")
@NotNull(message = "提示词不能为空!")
@Size(max = 1200, message = "提示词最大1200")
private String prompt;
@Schema(description = "模型")
@Schema(description = "模型(dall2、dall3)")
@NotNull(message = "模型不能为空")
private String modal;
private String model;
@Schema(description = "图像生成的风格。可为vivid生动或natural自然)")
@NotNull(message = "图像生成的风格,不能为空!")
private String style;
@Schema(description = "生成图像的尺寸大小。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。")

View File

@ -15,8 +15,7 @@ import lombok.experimental.Accessors;
*/
@Data
@Accessors(chain = true)
public class AiImageDallDrawingRespVO {
public class AiImageDallRespVO {
@Schema(description = "提示词")
@NotNull(message = "提示词不能为空!")
@ -25,7 +24,7 @@ public class AiImageDallDrawingRespVO {
@Schema(description = "模型")
@NotNull(message = "模型不能为空")
private String modal;
private String model;
@Schema(description = "风格")
private String style;
@ -33,8 +32,11 @@ public class AiImageDallDrawingRespVO {
@Schema(description = "图片size 1024x1024 ...")
private String size;
@Schema(description = "图片地址(自己服务器)")
private String picUrl;
@Schema(description = "可以访问图像的URL。")
private String url;
private String originalPicUrl;
@Schema(description = "图片base64。")
private String base64;

View File

@ -1,8 +1,8 @@
package cn.iocoder.yudao.module.ai.convert;
import org.springframework.ai.models.midjourney.MidjourneyMessage;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -29,7 +29,7 @@ public interface AiImageConvert {
* @param req
* @return
*/
AiImageDallDrawingRespVO convertAiImageDallDrawingRespVO(AiImageDallDrawingReqVO req);
AiImageDallRespVO convertAiImageDallDrawingRespVO(AiImageDallReqVO req);
/**
* 转换 - AiImageListRespVO

View File

@ -30,22 +30,28 @@ public class AiImageDO extends BaseDO {
private String prompt;
@Schema(description = "模型 dall2/dall3、MJ、NIJI")
private String modal;
private String model;
@Schema(description = "生成图像的尺寸大小。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。")
private String size;
@Schema(description = "风格")
private String style;
@Schema(description = "图片地址(自己服务器)")
private String imageUrl;
private String picUrl;
@Schema(description = "绘画状态:提交、排队、绘画中、绘画完成、绘画失败")
private String drawingStatus;
private String status;
@Schema(description = "绘画图片地址(绘画好的服务器)")
private String drawingImageUrl;
private String originalPicUrl;
@Schema(description = "绘画错误信息")
private String drawingErrorMessage;
private String errorMessage;
@Schema(description = "是否发布")
private String publicStatus;
// ============ mj 需要字段

View File

@ -3,8 +3,6 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import java.util.List;
/**
* ai 作图
*
@ -27,7 +25,7 @@ public interface AiImageService {
*
* @param req
*/
AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req);
AiImageDallRespVO dallDrawing(AiImageDallReqVO req);
/**
* midjourney 图片生成

View File

@ -22,7 +22,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import jakarta.annotation.PostConstruct;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@ -85,12 +85,12 @@ public class AiImageServiceImpl implements AiImageService {
}
@Override
public AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req) {
public AiImageDallRespVO dallDrawing(AiImageDallReqVO req) {
// 获取 model
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal());
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
// 转换 AiImageDallDrawingRespVO
AiImageDallDrawingRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req);
AiImageDallRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req);
try {
// 转换openai 参数
OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
@ -101,17 +101,17 @@ public class AiImageServiceImpl implements AiImageService {
// 发送
ImageGeneration imageGeneration = imageResponse.getResult();
// 保存数据库
doSave(req.getPrompt(), req.getSize(), req.getModal(),
imageGeneration.getOutput().getUrl(), AiImageDrawingStatusEnum.COMPLETE, null,
doSave(req.getPrompt(), req.getSize(), req.getModel(),
imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null,
null, null, null);
// 返回 flex
respVO.setUrl(imageGeneration.getOutput().getUrl());
respVO.setOriginalPicUrl(imageGeneration.getOutput().getUrl());
respVO.setBase64(imageGeneration.getOutput().getB64Json());
return respVO;
} catch (AiException aiException) {
// 保存数据库
doSave(req.getPrompt(), req.getSize(), req.getModal(),
null, AiImageDrawingStatusEnum.FAIL, aiException.getMessage(),
doSave(req.getPrompt(), req.getSize(), req.getModel(),
null, AiImageStatusEnum.FAIL, aiException.getMessage(),
null, null, null);
// 发送错误信息
respVO.setErrorMessage(aiException.getMessage());
@ -125,7 +125,7 @@ public class AiImageServiceImpl implements AiImageService {
// 保存数据库
String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
null, AiImageDrawingStatusEnum.SUBMIT, null,
null, AiImageStatusEnum.SUBMIT, null,
messageId, null, null);
// 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
@ -148,8 +148,8 @@ public class AiImageServiceImpl implements AiImageService {
// 获取 mjOperationName
String mjOperationName = midjourneyOperationsVO.getLabel();
// 保存一个 image 任务记录
doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModal(),
null, AiImageDrawingStatusEnum.SUBMIT, null,
doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(),
null, AiImageStatusEnum.SUBMIT, null,
req.getMessageId(), req.getOperateId(), mjOperationName);
// 提交操作
midjourneyInteractionsApi.reRoll(
@ -201,9 +201,9 @@ public class AiImageServiceImpl implements AiImageService {
private AiImageDO doSave(String prompt,
String size,
String model,
String drawingImageUrl,
AiImageDrawingStatusEnum drawingStatusEnum,
String drawingErrorMessage,
String originalPicUrl,
AiImageStatusEnum statusEnum,
String errorMessage,
String mjMessageId,
String mjOperationId,
String mjOperationName) {
@ -213,13 +213,13 @@ public class AiImageServiceImpl implements AiImageService {
aiImageDO.setId(null);
aiImageDO.setPrompt(prompt);
aiImageDO.setSize(size);
aiImageDO.setModal(model);
aiImageDO.setModel(model);
aiImageDO.setUserId(loginUserId);
// TODO @芋艿 如何上传到自己服务器
aiImageDO.setImageUrl(null);
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
aiImageDO.setDrawingImageUrl(drawingImageUrl);
aiImageDO.setDrawingErrorMessage(drawingErrorMessage);
aiImageDO.setPicUrl(null);
aiImageDO.setStatus(statusEnum.getStatus());
aiImageDO.setOriginalPicUrl(originalPicUrl);
aiImageDO.setErrorMessage(errorMessage);
//
aiImageDO.setMjNonceId(mjMessageId);
aiImageDO.setMjOperationId(mjOperationId);

View File

@ -11,7 +11,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOpe
import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@ -67,8 +67,8 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
String errorMessage = getErrorMessage(midjourneyMessage);
aiImageMapper.updateByMjNonce(nonceId,
new AiImageDO()
.setDrawingErrorMessage(errorMessage)
.setDrawingStatus(AiImageDrawingStatusEnum.FAIL.getStatus())
.setErrorMessage(errorMessage)
.setStatus(AiImageStatusEnum.FAIL.getStatus())
);
}
@ -90,22 +90,22 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
imageUrl = midjourneyMessage.getAttachments().get(0).getUrl();
}
// 转换状态
AiImageDrawingStatusEnum drawingStatusEnum = null;
AiImageStatusEnum drawingStatusEnum = null;
String generateStatus = midjourneyMessage.getGenerateStatus();
if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiImageDrawingStatusEnum.COMPLETE;
drawingStatusEnum = AiImageStatusEnum.COMPLETE;
} else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiImageDrawingStatusEnum.IN_PROGRESS;
drawingStatusEnum = AiImageStatusEnum.IN_PROGRESS;
} else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiImageDrawingStatusEnum.WAITING;
drawingStatusEnum = AiImageStatusEnum.WAITING;
}
// 获取 midjourneyOperations
List<AiImageMidjourneyOperationsVO> midjourneyOperations = getMidjourneyOperationsList(midjourneyMessage);
// 更新数据库
aiImageMapper.updateByMjNonce(nonceId,
new AiImageDO()
.setDrawingImageUrl(imageUrl)
.setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
.setOriginalPicUrl(imageUrl)
.setStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
.setMjNonceId(midjourneyMessage.getId())
.setMjOperations(JsonUtils.toJsonString(midjourneyOperations))
);