【优化】AI Image 调整字段

This commit is contained in:
cherishsince 2024-05-27 15:26:44 +08:00
parent 3afb3089b5
commit 0cdd3423e5
9 changed files with 58 additions and 53 deletions

View File

@ -12,7 +12,7 @@ import lombok.Getter;
*/ */
@AllArgsConstructor @AllArgsConstructor
@Getter @Getter
public enum AiImageDrawingStatusEnum { public enum AiImageStatusEnum {
SUBMIT("submit", "提交任务"), SUBMIT("submit", "提交任务"),
WAITING("waiting", "等待"), WAITING("waiting", "等待"),
@ -27,8 +27,8 @@ public enum AiImageDrawingStatusEnum {
private String name; private String name;
public static AiImageDrawingStatusEnum valueOfStatus(String status) { public static AiImageStatusEnum valueOfStatus(String status) {
for (AiImageDrawingStatusEnum itemEnum : AiImageDrawingStatusEnum.values()) { for (AiImageStatusEnum itemEnum : AiImageStatusEnum.values()) {
if (itemEnum.getStatus().equals(status)) { if (itemEnum.getStatus().equals(status)) {
return itemEnum; return itemEnum;
} }

View File

@ -37,8 +37,8 @@ public class AiImageController {
} }
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
@PostMapping("/dallDrawing") @PostMapping("/dall")
public AiImageDallDrawingRespVO dallDrawing(@Validated @RequestBody AiImageDallDrawingReqVO req) { public AiImageDallRespVO dallDrawing(@Validated @RequestBody AiImageDallReqVO req) {
return aiImageService.dallDrawing(req); return aiImageService.dallDrawing(req);
} }

View File

@ -15,19 +15,18 @@ import lombok.experimental.Accessors;
*/ */
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public class AiImageDallDrawingReqVO { public class AiImageDallReqVO {
@Schema(description = "提示词") @Schema(description = "提示词")
@NotNull(message = "提示词不能为空!") @NotNull(message = "提示词不能为空!")
@Size(max = 1200, message = "提示词最大1200") @Size(max = 1200, message = "提示词最大1200")
private String prompt; private String prompt;
@Schema(description = "模型") @Schema(description = "模型(dall2、dall3)")
@NotNull(message = "模型不能为空") @NotNull(message = "模型不能为空")
private String modal; private String model;
@Schema(description = "图像生成的风格。可为vivid生动或natural自然)") @Schema(description = "图像生成的风格。可为vivid生动或natural自然)")
@NotNull(message = "图像生成的风格,不能为空!")
private String style; private String style;
@Schema(description = "生成图像的尺寸大小。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。") @Schema(description = "生成图像的尺寸大小。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。")

View File

@ -15,8 +15,7 @@ import lombok.experimental.Accessors;
*/ */
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public class AiImageDallDrawingRespVO { public class AiImageDallRespVO {
@Schema(description = "提示词") @Schema(description = "提示词")
@NotNull(message = "提示词不能为空!") @NotNull(message = "提示词不能为空!")
@ -25,7 +24,7 @@ public class AiImageDallDrawingRespVO {
@Schema(description = "模型") @Schema(description = "模型")
@NotNull(message = "模型不能为空") @NotNull(message = "模型不能为空")
private String modal; private String model;
@Schema(description = "风格") @Schema(description = "风格")
private String style; private String style;
@ -33,8 +32,11 @@ public class AiImageDallDrawingRespVO {
@Schema(description = "图片size 1024x1024 ...") @Schema(description = "图片size 1024x1024 ...")
private String size; private String size;
@Schema(description = "图片地址(自己服务器)")
private String picUrl;
@Schema(description = "可以访问图像的URL。") @Schema(description = "可以访问图像的URL。")
private String url; private String originalPicUrl;
@Schema(description = "图片base64。") @Schema(description = "图片base64。")
private String base64; private String base64;

View File

@ -1,8 +1,8 @@
package cn.iocoder.yudao.module.ai.convert; package cn.iocoder.yudao.module.ai.convert;
import org.springframework.ai.models.midjourney.MidjourneyMessage; import org.springframework.ai.models.midjourney.MidjourneyMessage;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListRespVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -29,7 +29,7 @@ public interface AiImageConvert {
* @param req * @param req
* @return * @return
*/ */
AiImageDallDrawingRespVO convertAiImageDallDrawingRespVO(AiImageDallDrawingReqVO req); AiImageDallRespVO convertAiImageDallDrawingRespVO(AiImageDallReqVO req);
/** /**
* 转换 - AiImageListRespVO * 转换 - AiImageListRespVO

View File

@ -30,22 +30,28 @@ public class AiImageDO extends BaseDO {
private String prompt; private String prompt;
@Schema(description = "模型 dall2/dall3、MJ、NIJI") @Schema(description = "模型 dall2/dall3、MJ、NIJI")
private String modal; private String model;
@Schema(description = "生成图像的尺寸大小。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。") @Schema(description = "生成图像的尺寸大小。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。")
private String size; private String size;
@Schema(description = "风格")
private String style;
@Schema(description = "图片地址(自己服务器)") @Schema(description = "图片地址(自己服务器)")
private String imageUrl; private String picUrl;
@Schema(description = "绘画状态:提交、排队、绘画中、绘画完成、绘画失败") @Schema(description = "绘画状态:提交、排队、绘画中、绘画完成、绘画失败")
private String drawingStatus; private String status;
@Schema(description = "绘画图片地址(绘画好的服务器)") @Schema(description = "绘画图片地址(绘画好的服务器)")
private String drawingImageUrl; private String originalPicUrl;
@Schema(description = "绘画错误信息") @Schema(description = "绘画错误信息")
private String drawingErrorMessage; private String errorMessage;
@Schema(description = "是否发布")
private String publicStatus;
// ============ mj 需要字段 // ============ mj 需要字段

View File

@ -3,8 +3,6 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import java.util.List;
/** /**
* ai 作图 * ai 作图
* *
@ -27,7 +25,7 @@ public interface AiImageService {
* *
* @param req * @param req
*/ */
AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req); AiImageDallRespVO dallDrawing(AiImageDallReqVO req);
/** /**
* midjourney 图片生成 * midjourney 图片生成

View File

@ -22,7 +22,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.convert.AiImageConvert; import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum; import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import jakarta.annotation.PostConstruct; import jakarta.annotation.PostConstruct;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -85,12 +85,12 @@ public class AiImageServiceImpl implements AiImageService {
} }
@Override @Override
public AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req) { public AiImageDallRespVO dallDrawing(AiImageDallReqVO req) {
// 获取 model // 获取 model
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal()); OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle()); OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
// 转换 AiImageDallDrawingRespVO // 转换 AiImageDallDrawingRespVO
AiImageDallDrawingRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req); AiImageDallRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req);
try { try {
// 转换openai 参数 // 转换openai 参数
OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions(); OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
@ -101,17 +101,17 @@ public class AiImageServiceImpl implements AiImageService {
// 发送 // 发送
ImageGeneration imageGeneration = imageResponse.getResult(); ImageGeneration imageGeneration = imageResponse.getResult();
// 保存数据库 // 保存数据库
doSave(req.getPrompt(), req.getSize(), req.getModal(), doSave(req.getPrompt(), req.getSize(), req.getModel(),
imageGeneration.getOutput().getUrl(), AiImageDrawingStatusEnum.COMPLETE, null, imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null,
null, null, null); null, null, null);
// 返回 flex // 返回 flex
respVO.setUrl(imageGeneration.getOutput().getUrl()); respVO.setOriginalPicUrl(imageGeneration.getOutput().getUrl());
respVO.setBase64(imageGeneration.getOutput().getB64Json()); respVO.setBase64(imageGeneration.getOutput().getB64Json());
return respVO; return respVO;
} catch (AiException aiException) { } catch (AiException aiException) {
// 保存数据库 // 保存数据库
doSave(req.getPrompt(), req.getSize(), req.getModal(), doSave(req.getPrompt(), req.getSize(), req.getModel(),
null, AiImageDrawingStatusEnum.FAIL, aiException.getMessage(), null, AiImageStatusEnum.FAIL, aiException.getMessage(),
null, null, null); null, null, null);
// 发送错误信息 // 发送错误信息
respVO.setErrorMessage(aiException.getMessage()); respVO.setErrorMessage(aiException.getMessage());
@ -125,7 +125,7 @@ public class AiImageServiceImpl implements AiImageService {
// 保存数据库 // 保存数据库
String messageId = String.valueOf(IdUtil.getSnowflakeNextId()); String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny", AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
null, AiImageDrawingStatusEnum.SUBMIT, null, null, AiImageStatusEnum.SUBMIT, null,
messageId, null, null); messageId, null, null);
// 提交 midjourney 任务 // 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt()); Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
@ -148,8 +148,8 @@ public class AiImageServiceImpl implements AiImageService {
// 获取 mjOperationName // 获取 mjOperationName
String mjOperationName = midjourneyOperationsVO.getLabel(); String mjOperationName = midjourneyOperationsVO.getLabel();
// 保存一个 image 任务记录 // 保存一个 image 任务记录
doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModal(), doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(),
null, AiImageDrawingStatusEnum.SUBMIT, null, null, AiImageStatusEnum.SUBMIT, null,
req.getMessageId(), req.getOperateId(), mjOperationName); req.getMessageId(), req.getOperateId(), mjOperationName);
// 提交操作 // 提交操作
midjourneyInteractionsApi.reRoll( midjourneyInteractionsApi.reRoll(
@ -201,9 +201,9 @@ public class AiImageServiceImpl implements AiImageService {
private AiImageDO doSave(String prompt, private AiImageDO doSave(String prompt,
String size, String size,
String model, String model,
String drawingImageUrl, String originalPicUrl,
AiImageDrawingStatusEnum drawingStatusEnum, AiImageStatusEnum statusEnum,
String drawingErrorMessage, String errorMessage,
String mjMessageId, String mjMessageId,
String mjOperationId, String mjOperationId,
String mjOperationName) { String mjOperationName) {
@ -213,13 +213,13 @@ public class AiImageServiceImpl implements AiImageService {
aiImageDO.setId(null); aiImageDO.setId(null);
aiImageDO.setPrompt(prompt); aiImageDO.setPrompt(prompt);
aiImageDO.setSize(size); aiImageDO.setSize(size);
aiImageDO.setModal(model); aiImageDO.setModel(model);
aiImageDO.setUserId(loginUserId); aiImageDO.setUserId(loginUserId);
// TODO @芋艿 如何上传到自己服务器 // TODO @芋艿 如何上传到自己服务器
aiImageDO.setImageUrl(null); aiImageDO.setPicUrl(null);
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus()); aiImageDO.setStatus(statusEnum.getStatus());
aiImageDO.setDrawingImageUrl(drawingImageUrl); aiImageDO.setOriginalPicUrl(originalPicUrl);
aiImageDO.setDrawingErrorMessage(drawingErrorMessage); aiImageDO.setErrorMessage(errorMessage);
// //
aiImageDO.setMjNonceId(mjMessageId); aiImageDO.setMjNonceId(mjMessageId);
aiImageDO.setMjOperationId(mjOperationId); aiImageDO.setMjOperationId(mjOperationId);

View File

@ -11,7 +11,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOpe
import cn.iocoder.yudao.module.ai.convert.AiImageConvert; import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum; import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -67,8 +67,8 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
String errorMessage = getErrorMessage(midjourneyMessage); String errorMessage = getErrorMessage(midjourneyMessage);
aiImageMapper.updateByMjNonce(nonceId, aiImageMapper.updateByMjNonce(nonceId,
new AiImageDO() new AiImageDO()
.setDrawingErrorMessage(errorMessage) .setErrorMessage(errorMessage)
.setDrawingStatus(AiImageDrawingStatusEnum.FAIL.getStatus()) .setStatus(AiImageStatusEnum.FAIL.getStatus())
); );
} }
@ -90,22 +90,22 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
imageUrl = midjourneyMessage.getAttachments().get(0).getUrl(); imageUrl = midjourneyMessage.getAttachments().get(0).getUrl();
} }
// 转换状态 // 转换状态
AiImageDrawingStatusEnum drawingStatusEnum = null; AiImageStatusEnum drawingStatusEnum = null;
String generateStatus = midjourneyMessage.getGenerateStatus(); String generateStatus = midjourneyMessage.getGenerateStatus();
if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) { if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiImageDrawingStatusEnum.COMPLETE; drawingStatusEnum = AiImageStatusEnum.COMPLETE;
} else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) { } else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiImageDrawingStatusEnum.IN_PROGRESS; drawingStatusEnum = AiImageStatusEnum.IN_PROGRESS;
} else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) { } else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiImageDrawingStatusEnum.WAITING; drawingStatusEnum = AiImageStatusEnum.WAITING;
} }
// 获取 midjourneyOperations // 获取 midjourneyOperations
List<AiImageMidjourneyOperationsVO> midjourneyOperations = getMidjourneyOperationsList(midjourneyMessage); List<AiImageMidjourneyOperationsVO> midjourneyOperations = getMidjourneyOperationsList(midjourneyMessage);
// 更新数据库 // 更新数据库
aiImageMapper.updateByMjNonce(nonceId, aiImageMapper.updateByMjNonce(nonceId,
new AiImageDO() new AiImageDO()
.setDrawingImageUrl(imageUrl) .setOriginalPicUrl(imageUrl)
.setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus()) .setStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
.setMjNonceId(midjourneyMessage.getId()) .setMjNonceId(midjourneyMessage.getId())
.setMjOperations(JsonUtils.toJsonString(midjourneyOperations)) .setMjOperations(JsonUtils.toJsonString(midjourneyOperations))
); );