【增加】AI Image mj 增加 action 操作

This commit is contained in:
cherishsince 2024-06-05 10:27:31 +08:00
parent 79a094be02
commit 776d6e4e1e
5 changed files with 82 additions and 22 deletions

View File

@ -44,5 +44,6 @@ public interface ErrorCodeConstants {
ErrorCode AI_IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "image 不存在!"); ErrorCode AI_IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "image 不存在!");
ErrorCode AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败! {}"); ErrorCode AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败! {}");
ErrorCode AI_IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
} }

View File

@ -74,8 +74,14 @@ public class AiImageController {
@Operation(summary = "midjourney proxy - 回调通知") @Operation(summary = "midjourney proxy - 回调通知")
@PostMapping("/midjourney-notify") @PostMapping("/midjourney-notify")
@PermitAll @PermitAll
public CommonResult<Boolean> midjourneyNotify( @RequestBody MidjourneyNotifyReqVO notifyReqVO) { public CommonResult<Boolean> midjourneyNotify(@RequestBody MidjourneyNotifyReqVO notifyReqVO) {
return success(imageService.midjourneyNotify(getLoginUserId(), notifyReqVO)); return success(imageService.midjourneyNotify(notifyReqVO));
} }
@Operation(summary = "midjourney - action(放大、缩小、U1、U2...)")
@PostMapping("/midjourney/action")
public CommonResult<Boolean> midjourneyAction(@RequestParam("id") Long imageId,
@RequestParam("customId") String customId) {
return success(imageService.midjourneyAction(getLoginUserId(), imageId, customId));
}
} }

View File

@ -59,10 +59,18 @@ public interface AiImageService {
/** /**
* midjourney proxy - 回调通知 * midjourney proxy - 回调通知
* *
* @param loginUserId
* @param notifyReqVO * @param notifyReqVO
* @return * @return
*/ */
Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO); Boolean midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO);
/**
* midjourney - action(放大缩小U1U2...)
*
* @param loginUserId
* @param imageId
* @param customId
* @return
*/
Boolean midjourneyAction(Long loginUserId, Long imageId, String customId);
} }

View File

@ -15,6 +15,7 @@ import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum; import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum; import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum; import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO; import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO; import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
@ -39,9 +40,10 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_IMAGE_NOT_EXISTS;
/** /**
* AI 绘画 Service 实现类 * AI 绘画 Service 实现类
@ -136,30 +138,21 @@ public class AiImageServiceImpl implements AiImageService {
aiImageDO.setUserId(loginUserId); aiImageDO.setUserId(loginUserId);
aiImageDO.setPrompt(req.getPrompt()); aiImageDO.setPrompt(req.getPrompt());
aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()); aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
// todo @范 平台需要转换(mj 模型一般分版本)
aiImageDO.setModel(null); aiImageDO.setModel(null);
aiImageDO.setWidth(null); aiImageDO.setWidth(null);
aiImageDO.setHeight(null); aiImageDO.setHeight(null);
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
// 2保存 image // 2保存 image
imageMapper.insert(aiImageDO); imageMapper.insert(aiImageDO);
// 3调用 MidjourneyProxy 提交任务 // 3调用 MidjourneyProxy 提交任务
MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class); MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
imagineReqVO.setNotifyHook(midjourneyNotifyUrl); imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
// 设置 midjourney 扩展参数 // 4设置 midjourney 扩展参数
// --ar 来设置尺寸 imagineReqVO.setState(buildParams(req.getWidth(),
String midjourneySizeParam = String.format(" --ar %s:%s ", req.getWidth(), req.getHeight()); req.getHeight(), req.getVersion(), MidjourneyModelEnum.valueOfModel(req.getModel())));
// --v 版本 // 5提交绘画请求
String midjourneyVersionParam = String.format(" --v %s ", req.getVersion());
// --niji 模型
MidjourneyModelEnum midjourneyModelEnum = MidjourneyModelEnum.valueOfModel(req.getModel());
String midjourneyNijiParam = MidjourneyModelEnum.NIJI == midjourneyModelEnum ? " --niji " : "";
// 设置参数
imagineReqVO.setState(midjourneySizeParam.concat(midjourneyVersionParam).concat(midjourneyNijiParam));
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO); MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
// 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)) // 6保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) { if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription()); throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
} }
@ -170,6 +163,8 @@ public class AiImageServiceImpl implements AiImageService {
return aiImageDO.getId(); return aiImageDO.getId();
} }
@Override @Override
public void deleteImageMy(Long id, Long userId) { public void deleteImageMy(Long id, Long userId) {
// 1. 校验是否存在 // 1. 校验是否存在
@ -182,7 +177,7 @@ public class AiImageServiceImpl implements AiImageService {
} }
@Override @Override
public Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO) { public Boolean midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
// 1根据 job id 查询关联的 image // 1根据 job id 查询关联的 image
AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId()); AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
if (image == null) { if (image == null) {
@ -220,6 +215,34 @@ public class AiImageServiceImpl implements AiImageService {
return true; return true;
} }
@Override
@Transactional(rollbackFor = Exception.class)
public Boolean midjourneyAction(Long loginUserId, Long imageId, String customId) {
// 1检查 image
AiImageDO aiImageDO = validateImageExists(imageId);
// 2检查 customId
if (!validateCustomId(customId, aiImageDO.getButtons())) {
throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
}
// 3调用 midjourney proxy
midjourneyProxyClient.action(
new MidjourneyActionReqVO()
.setCustomId(customId)
.setTaskId(aiImageDO.getJobId())
.setNotifyHook(midjourneyNotifyUrl)
);
return Boolean.TRUE;
}
private static boolean validateCustomId(String customId, List<MidjourneyNotifyReqVO.Button> buttons) {
for (MidjourneyNotifyReqVO.Button button : buttons) {
if (button.getCustomId().equals(customId)) {
return true;
}
}
return false;
}
private AiImageDO validateImageExists(Long id) { private AiImageDO validateImageExists(Long id) {
AiImageDO image = imageMapper.selectById(id); AiImageDO image = imageMapper.selectById(id);
if (image == null) { if (image == null) {
@ -237,4 +260,25 @@ public class AiImageServiceImpl implements AiImageService {
return SpringUtil.getBean(getClass()); return SpringUtil.getBean(getClass());
} }
/**
* 构建 Midjourney 自定义参数
*
* @param width
* @param height
* @param version
* @param model
* @return
*/
private String buildParams(String width, String height, String version, MidjourneyModelEnum model) {
StringBuilder params = new StringBuilder();
// --ar 来设置尺寸
params.append(String.format(" --ar %s:%s ", width, height));
// --v 版本
params.append(String.format(" --v %s ", version));
// --niji 模型
if (MidjourneyModelEnum.NIJI == model) {
params.append(" --niji ");
}
return params.toString();
}
} }

View File

@ -80,7 +80,8 @@ server:
ai: ai:
midjourney-proxy: midjourney-proxy:
url: https://api.holdai.top/mj url: https://api.holdai.top/mj
notifyUrl: http://7b1aada4.r26.cpolar.top/admin-api/ai/image/midjourney-notify notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
key: sk-c3qxUCVKsPfdQiYU8440E3Fc8dE5424d9cB124A4Ee2489E3
--- #################### 定时任务相关配置 #################### --- #################### 定时任务相关配置 ####################