【增加】niji 模型参数设置

This commit is contained in:
cherishsince 2024-05-31 14:36:07 +08:00
parent 47ff5bf814
commit 8f3076b2ea
2 changed files with 41 additions and 4 deletions

View File

@ -0,0 +1,30 @@
package cn.iocoder.yudao.module.ai.client.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 来源于 midjourney-proxy
*/
@Getter
@AllArgsConstructor
public enum MidjourneyModelEnum {
MIDJOURNEY("midjourney", "midjourney"),
NIJI("Niji", "Niji"),
;
private String model;
private String name;
public static MidjourneyModelEnum valueOfModel(String model) {
for (MidjourneyModelEnum itemEnum : MidjourneyModelEnum.values()) {
if (itemEnum.getModel().equals(model)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
}

View File

@ -12,6 +12,7 @@ import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.AiCommonConstants;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
@ -157,10 +158,16 @@ public class AiImageServiceImpl implements AiImageService {
// 3调用 MidjourneyProxy 提交任务
MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
// 设置 midjourney 扩展参数通过 --ar 来设置尺寸
String midjourneySizeParam = String.format("--ar %s:%s", req.getWidth(), req.getHeight());
String midjourneyVersionParam = String.format("--v %s", req.getVersion());
imagineReqVO.setState(midjourneySizeParam.concat(" ").concat(midjourneyVersionParam));
// 设置 midjourney 扩展参数
// --ar 来设置尺寸
String midjourneySizeParam = String.format(" --ar %s:%s ", req.getWidth(), req.getHeight());
// --v 版本
String midjourneyVersionParam = String.format(" --v %s ", req.getVersion());
// --niji 模型
MidjourneyModelEnum midjourneyModelEnum = MidjourneyModelEnum.valueOfModel(req.getModel());
String midjourneyNijiParam = MidjourneyModelEnum.NIJI == midjourneyModelEnum ? " --niji " : "";
// 设置参数
imagineReqVO.setState(midjourneySizeParam.concat(midjourneyVersionParam).concat(midjourneyNijiParam));
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
// 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))