【代码优化】AI:MJ 生成图片 ACTION 的优化

This commit is contained in:
YunaiV 2024-06-25 20:36:49 +08:00
parent 098483d2be
commit 4c3add508b
6 changed files with 74 additions and 105 deletions

View File

@ -8,7 +8,8 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.service.image.AiImageService; import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
@ -67,30 +68,24 @@ public class AiImageController {
@Operation(summary = "【Midjourney】生成图片") @Operation(summary = "【Midjourney】生成图片")
@PostMapping("/midjourney/imagine") @PostMapping("/midjourney/imagine")
public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO reqVO) { public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiMidjourneyImagineReqVO reqVO) {
if (true) {
imageService.midjourneySync();
return null;
}
Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO); Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO);
return success(imageId); return success(imageId);
} }
@Operation(summary = "Midjourney 生成图片的回调通知", description = "由 Midjourney Proxy 回调") @Operation(summary = "【Midjourney】通知图片进展", description = "由 Midjourney Proxy 回调")
@PostMapping("/midjourney-notify") @PostMapping("/midjourney/notify") // 必须是 POST 方法否则会报错
@PermitAll @PermitAll
public void midjourneyNotify(@RequestBody MidjourneyApi.Notify notify) { public CommonResult<Boolean> midjourneyNotify(@Validated @RequestBody MidjourneyApi.Notify notify) {
imageService.midjourneyNotify(notify); imageService.midjourneyNotify(notify);
}
@Operation(summary = "Midjourney Action", description = "例如说放大、缩小、U1、U2 等")
@GetMapping("/midjourney/action")
@Parameter(name = "id", description = "图片id", example = "1")
@Parameter(name = "customId", description = "操作id", example = "MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862")
public CommonResult<Boolean> midjourneyAction(@RequestParam("id") Long imageId,
@RequestParam("customId") String customId) {
imageService.midjourneyAction(getLoginUserId(), imageId, customId);
return success(true); return success(true);
} }
@Operation(summary = "【Midjourney】Action 操作(二次生成图片)", description = "例如说放大、缩小、U1、U2 等")
@PostMapping("/midjourney/action")
public CommonResult<Long> midjourneyAction(@Validated @RequestBody AiMidjourneyActionReqVO reqVO) {
Long imageId = imageService.midjourneyAction(getLoginUserId(), reqVO);
return success(imageId);
}
} }

View File

@ -1,31 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
/**
* Midjourneyaction 请求
*
* @author fansili
* @time 2024/5/30 14:02
* @since 1.0
*/
@Data
public class MidjourneyActionReqVO {
@Schema(description = "操作按钮id", required = true)
@NotNull(message = "customId 不能为空!")
private String customId;
@Schema(description = "操作按钮id", required = true)
@NotNull(message = "customId 不能为空!")
private String taskId;
@Schema(description = "通知地址", required = false)
@NotNull(message = "回调地址不能为空!")
private String notifyHook;
@Schema(description = "自定义参数", required = false)
private String state;
}

View File

@ -0,0 +1,20 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@Schema(description = "管理后台 - ActionMidjourney Request VO")
@Data
public class AiMidjourneyActionReqVO {
@Schema(description = "图片编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "图片编号不能为空")
private Long id;
@Schema(description = "操作按钮编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "MJ::JOB::variation::4::06aa3e66-0e97-49cc-8201-e0295d883de4")
@NotEmpty(message = "操作按钮编号不能为空")
private String customId;
}

View File

@ -9,7 +9,7 @@ import java.util.List;
@Schema(description = "管理后台 - 绘画生成Midjourney Request VO") @Schema(description = "管理后台 - 绘画生成Midjourney Request VO")
@Data @Data
public class AiImageMidjourneyImagineReqVO { public class AiMidjourneyImagineReqVO {
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "中国神龙") @Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "中国神龙")
@NotEmpty(message = "提示词不能为空!") @NotEmpty(message = "提示词不能为空!")

View File

@ -4,7 +4,8 @@ import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
/** /**
@ -57,7 +58,7 @@ public interface AiImageService {
* @param reqVO 绘制请求 * @param reqVO 绘制请求
* @return 绘画编号 * @return 绘画编号
*/ */
Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO); Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO);
/** /**
* Midjourney同步图片进展 * Midjourney同步图片进展
@ -74,13 +75,12 @@ public interface AiImageService {
void midjourneyNotify(MidjourneyApi.Notify notify); void midjourneyNotify(MidjourneyApi.Notify notify);
/** /**
* midjourney - action(放大缩小U1U2...) * MidjourneyAction 操作(放大缩小U1U2...)
* *
* @param loginUserId * @param userId 用户编号
* @param imageId * @param reqVO 绘制请求
* @param customId * @return 绘画编号
* @return
*/ */
void midjourneyAction(Long loginUserId, Long imageId, String customId); Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO);
} }

View File

@ -14,7 +14,8 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
@ -149,7 +150,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO) { public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
// 1. 保存数据库 // 1. 保存数据库
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()) .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
@ -170,11 +171,8 @@ public class AiImageServiceImpl implements AiImageService {
} }
// 4. 情况二成功更新 taskId 和参数 // 4. 情况二成功更新 taskId 和参数
imageMapper.updateById(new AiImageDO() imageMapper.updateById(new AiImageDO().setId(image.getId())
.setId(image.getId()) .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(reqVO)));
.setTaskId(imagineResponse.result())
.setOptions(BeanUtil.beanToMap(reqVO))
);
return image.getId(); return image.getId();
} }
@ -245,49 +243,36 @@ public class AiImageServiceImpl implements AiImageService {
} }
@Override @Override
public void midjourneyAction(Long loginUserId, Long imageId, String customId) { public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
// 1检查 image // 1.1 检查 image
AiImageDO image = validateImageExists(imageId); AiImageDO image = validateImageExists(reqVO.getId());
// 2检查 customId if (ObjUtil.notEqual(userId, image.getUserId())) {
validateCustomId(customId, image.getButtons()); throw exception(AI_IMAGE_NOT_EXISTS);
// 3调用 midjourney proxy
MidjourneyApi.SubmitResponse submitResponse = midjourneyApi.action(
new MidjourneyApi.ActionRequest(customId, image.getTaskId(), midjourneyNotifyUrl));
// 4检查错误 code (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(submitResponse.code())) {
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitResponse.description());
} }
// 1.2 检查 customId
// 5新增 image 记录(根据 image 新增一个) MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(),
AiImageDO newImage = new AiImageDO(); buttonX -> buttonX.customId().equals(reqVO.getCustomId()));
newImage.setUserId(image.getUserId()); if (button == null) {
newImage.setPrompt(image.getPrompt());
newImage.setPlatform(image.getPlatform());
newImage.setModel(image.getModel());
newImage.setWidth(image.getWidth());
newImage.setHeight(image.getHeight());
newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
newImage.setPublicStatus(image.getPublicStatus());
newImage.setOptions(image.getOptions());
newImage.setTaskId(submitResponse.result());
imageMapper.insert(newImage);
}
private static void validateCustomId(String customId, List<MidjourneyApi.Button> buttons) {
boolean isTrue = false;
for (MidjourneyApi.Button button : buttons) {
if (button.customId().equals(customId)) {
isTrue = true;
break;
}
}
if (!isTrue) {
throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS); throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
} }
// 2. 调用 Midjourney Proxy 提交任务
MidjourneyApi.SubmitResponse actionResponse = midjourneyApi.action(
new MidjourneyApi.ActionRequest(button.customId(), image.getTaskId(), midjourneyNotifyUrl));
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(actionResponse.code())) {
String description = actionResponse.description().contains("quota_not_enough") ?
"账户余额不足" : actionResponse.description();
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
}
// 3. 新增 image 记录
AiImageDO newImage = new AiImageDO().setUserId(image.getUserId()).setPublicStatus(false).setPrompt(image.getPrompt())
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform())
.setModel(image.getModel()).setWidth(image.getWidth()).setHeight(image.getHeight())
.setOptions(image.getOptions()).setTaskId(actionResponse.result());
imageMapper.insert(newImage);
return newImage.getId();
} }
/** /**