【代码优化】AI:MJ 生成图片的优化

This commit is contained in:
YunaiV 2024-06-25 19:57:37 +08:00
parent 88142ed74c
commit 098483d2be
23 changed files with 333 additions and 653 deletions

View File

@ -21,30 +21,21 @@ public interface ErrorCodeConstants {
// ========== API 聊天模型 1-040-002-000 ========== // ========== API 聊天模型 1-040-002-000 ==========
ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在"); ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在");
ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!"); ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!");
ErrorCode CHAT_ROLE_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认聊天角色");
// ========== API 聊天会话 1-040-003-000 ========== // ========== API 聊天会话 1-040-003-000 ==========
ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!"); ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!");
ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整"); ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整");
ErrorCode CHAT_CONVERSATION_UPDATE_MAX_TOKENS_ERROR = new ErrorCode(1_040_003_002, "更新对话失败,最大 Token 超过上限");
ErrorCode CHAT_CONVERSATION_UPDATE_MAX_CONTEXTS_ERROR = new ErrorCode(1_040_003_002, "更新对话失败,最大 Context 超过上限");
// ========== API 聊天消息 1-040-004-000 ========== // ========== API 聊天消息 1-040-004-000 ==========
ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!"); ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
ErrorCode AI_CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!"); ErrorCode AI_CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!");
// midjourney
ErrorCode AI_MIDJOURNEY_IMAGINE_FAIL = new ErrorCode(1_022_000_040, "midjourney imagine 操作失败!");
ErrorCode AI_MIDJOURNEY_OPERATION_NOT_EXISTS = new ErrorCode(1_022_000_040, "midjourney 操作不存在!");
ErrorCode AI_MIDJOURNEY_MESSAGE_ID_INCORRECT = new ErrorCode(1_022_000_040, "midjourney message id 不正确!");
// ========== API 绘画 1-040-005-000 ========== // ========== API 绘画 1-040-005-000 ==========
ErrorCode AI_IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "image 不存在!"); ErrorCode AI_IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
ErrorCode AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败! {}"); ErrorCode AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
ErrorCode AI_IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}"); ErrorCode AI_IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
ErrorCode AI_IMAGE_SYSTEM_ACCOUNT_INSUFFICIENT_BALANCE = new ErrorCode(1_022_005_003, "Midjourney 系统账户余额不足!");
} }

View File

@ -12,20 +12,20 @@ import lombok.Getter;
@Getter @Getter
public enum AiImageStatusEnum { public enum AiImageStatusEnum {
IN_PROGRESS("10", "进行中"), IN_PROGRESS(10, "进行中"),
SUCCESS("20", "完成"), SUCCESS(20, "完成"),
FAIL("30", "失败"); FAIL(30, "失败");
/** /**
* 状态 * 状态
*/ */
private final String status; private final Integer status;
/** /**
* 状态名 * 状态名
*/ */
private final String name; private final String name;
public static AiImageStatusEnum valueOfStatus(String status) { public static AiImageStatusEnum valueOfStatus(Integer status) {
for (AiImageStatusEnum statusEnum : AiImageStatusEnum.values()) { for (AiImageStatusEnum statusEnum : AiImageStatusEnum.values()) {
if (statusEnum.getStatus().equals(status)) { if (statusEnum.getStatus().equals(status)) {
return statusEnum; return statusEnum;

View File

@ -3,7 +3,6 @@ package cn.iocoder.yudao.module.ai.config;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.core.model.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;

View File

@ -29,12 +29,17 @@ Authorization: {{token}}
"style": "vivid" "style": "vivid"
} }
### chat midjourney ### 生成图片:生成图片
POST {{baseUrl}}/admin-api/ai/image/midjourney POST {{baseUrl}}/ai/image/midjourney/imagine
Content-Type: application/json Content-Type: application/json
Authorization: {{token}} Authorization: {{token}}
{ {
"prompt": "Cute cartoon style mobile game scene, a colorful camping car with an outdoor table and chairs next to it on the road in a spring forest, the simple structure of the camper van, soft lighting, C4D rendering, 3d model in the style of a cartoon, cute shape, a pastel color scheme, closeup view from the side angle, high resolution, bright colors, a happy atmosphere." "prompt": "中国旗袍",
"model": "midjourney",
"width": "1",
"height": "1",
"version": "6.0",
"base64Array": []
} }

View File

@ -1,14 +1,14 @@
package cn.iocoder.yudao.module.ai.controller.admin.image; package cn.iocoder.yudao.module.ai.controller.admin.image;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.service.image.AiImageService; import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
@ -63,19 +63,24 @@ public class AiImageController {
return success(true); return success(true);
} }
// ================ midjourney 接口 ================ // ================ midjourney 专属 ================
@Operation(summary = "Midjourney imagine绘画") @Operation(summary = "【Midjourney】生成图片")
@PostMapping("/midjourney/imagine") @PostMapping("/midjourney/imagine")
public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) { public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO reqVO) {
return success(imageService.midjourneyImagine(getLoginUserId(), req)); if (true) {
imageService.midjourneySync();
return null;
}
Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO);
return success(imageId);
} }
@Operation(summary = "Midjourney 回调通知", description = "由 Midjourney Proxy 回调") @Operation(summary = "Midjourney 生成图片的回调通知", description = "由 Midjourney Proxy 回调")
@PostMapping("/midjourney-notify") @PostMapping("/midjourney-notify")
@PermitAll @PermitAll
public void midjourneyNotify(@RequestBody MidjourneyNotifyReqVO notifyReqVO) { public void midjourneyNotify(@RequestBody MidjourneyApi.Notify notify) {
imageService.midjourneyNotify(notifyReqVO); imageService.midjourneyNotify(notify);
} }
@Operation(summary = "Midjourney Action", description = "例如说放大、缩小、U1、U2 等") @Operation(summary = "Midjourney Action", description = "例如说放大、缩小、U1、U2 等")

View File

@ -1,40 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
// TODO @fan待定
/**
* midjourney req
*
* @author fansili
* @time 2024/4/28 17:42
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyImagineReqVO {
@Schema(description = "提示词")
@NotNull(message = "提示词不能为空!")
private String prompt;
@Schema(description = "模型(midjourney、niji)")
private String model;
@Schema(description = "图片宽度 --ar 设置")
private Integer width;
@Schema(description = "图片高度 --ar 设置")
private Integer height;
@Schema(description = "版本号 --v 设置")
private String version;
@Schema(description = "垫图(参考图)base64数组")
private List<String> base64Array;
}

View File

@ -1,30 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* mj 保存 components 记录
*
* "components": [
* {
* "custom_id": "MJ::JOB::upsample::1::5d32f4e8-8d2f-4bef-82d8-bf517e3c3660",
* "style": 2,
* "label": "U1",
* "type": 2
* },
* ]
*
* @author fansili
* @time 2024/5/8 14:44
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyOperationsVO {
private String custom_id;
private String style;
private String label;
private String type;
}

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo; package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
@ -46,14 +47,11 @@ public class AiImageRespVO {
@Schema(description = "绘制参数") @Schema(description = "绘制参数")
private Map<String, String> options; private Map<String, String> options;
@Schema(description = "绘画 response")
private MidjourneyNotifyReqVO response;
// TODO @fan进度是百分比还是一个数字哈感觉这个可以统一成通用字段 // TODO @fan进度是百分比还是一个数字哈感觉这个可以统一成通用字段
@Schema(description = "mj 进度") @Schema(description = "mj 进度")
private String progress; private String progress;
@Schema(description = "mj buttons 按钮") @Schema(description = "mj buttons 按钮")
private List<MidjourneyNotifyReqVO.Button> buttons; private List<MidjourneyApi.Button> buttons;
} }

View File

@ -1,33 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.util.List;
// TODO @fan待定
/**
* MidjourneyImagine 请求
*
* @author fansili
* @time 2024/5/30 14:02
* @since 1.0
*/
@Data
public class MidjourneyImagineReqVO {
@Schema(description = "垫图(参考图)base64数组", required = false)
private List<String> base64Array;
@Schema(description = "通知地址", required = false)
@NotNull(message = "回调地址不能为空!")
private String notifyHook;
@Schema(description = "提示词", required = true)
@NotNull(message = "提示词不能为空!")
private String prompt;
@Schema(description = "自定义参数", required = false)
private String state;
}

View File

@ -1,75 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.List;
/**
* Midjourney Proxy 通知回调
*
* - Midjourney Proxy通知回调 bean com.github.novicezk.midjourney.support.Task
* - 毫秒 api 通知回调文档地址https://gpt-best.apifox.cn/doc-3530863
*
* @author fansili
* @time 2024/5/31 10:37
* @since 1.0
*/
@Data
public class MidjourneyNotifyReqVO {
@Schema(description = "job id")
private String id;
@Schema(description = "任务类型 MidjourneyTaskActionEnum")
private String action;
@Schema(description = "任务状态 MidjourneyTaskStatusEnum")
private String status;
@Schema(description = "提示词")
private String prompt;
@Schema(description = "提示词-英文")
private String promptEn;
@Schema(description = "任务描述")
private String description;
@Schema(description = "自定义参数")
private String state;
@Schema(description = "提交时间")
private Long submitTime;
@Schema(description = "开始执行时间")
private Long startTime;
@Schema(description = "结束时间")
private Long finishTime;
@Schema(description = "图片url")
private String imageUrl;
@Schema(description = "任务进度")
private String progress;
@Schema(description = "失败原因")
private String failReason;
@Schema(description = "任务完成后的可执行按钮")
private List<Button> buttons;
@Data
public static class Button {
@Schema(description = "MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862 动作标识")
private String customId;
@Schema(description = "图标 emoji")
private String emoji;
@Schema(description = "Make Variations 文本")
private String label;
@Schema(description = "类型,系统内部使用")
private String type;
@Schema(description = "样式: 2Primary、3Green")
private String style;
}
}

View File

@ -1,30 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.Map;
// TODO @fan待定
/**
* MidjourneyImagine 请求
*
* @author fansili
* @time 2024/5/30 14:02
* @since 1.0
*/
@Data
public class MidjourneySubmitRespVO {
@Schema(description = "状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)")
private String code;
@Schema(description = "描述")
private String description;
@Schema(description = "扩展字段")
private Map<String, Object> properties;
@Schema(description = "任务ID")
private String result;
}

View File

@ -0,0 +1,38 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.util.List;
@Schema(description = "管理后台 - 绘画生成Midjourney Request VO")
@Data
public class AiImageMidjourneyImagineReqVO {
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "中国神龙")
@NotEmpty(message = "提示词不能为空!")
private String prompt;
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "midjourney")
@NotEmpty(message = "模型不能为空")
private String model; // 参考 MidjourneyApi.ModelEnum
@Schema(description = "图片宽度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "图片宽度不能为空")
private Integer width;
@Schema(description = "图片高度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "图片高度不能为空")
private Integer height;
@Schema(description = "版本号", requiredMode = Schema.RequiredMode.REQUIRED, example = "6.0")
@NotEmpty(message = "版本号不能为空")
private String version;
// TODO @fan参考图建议用 referImageUrl
@Schema(description = "垫图(参考图)base64数组")
private List<String> base64Array;
}

View File

@ -1,8 +1,8 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.image; package cn.iocoder.yudao.module.ai.dal.dataobject.image;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils; import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO; import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
@ -28,8 +28,6 @@ import java.util.Map;
@Data @Data
public class AiImageDO extends BaseDO { public class AiImageDO extends BaseDO {
// TODO @fan1使用 java 注释哈不要注解2关联枚举字段要关联到对应类参考 AiChatMessageDO 的注释
/** /**
* 编号 * 编号
*/ */
@ -76,7 +74,7 @@ public class AiImageDO extends BaseDO {
* *
* 枚举 {@link AiImageStatusEnum} * 枚举 {@link AiImageStatusEnum}
*/ */
private String status; private Integer status;
/** /**
* 图片地址 * 图片地址
@ -96,23 +94,11 @@ public class AiImageDO extends BaseDO {
@TableField(typeHandler = JacksonTypeHandler.class) @TableField(typeHandler = JacksonTypeHandler.class)
private Map<String, Object> options; private Map<String, Object> options;
/**
* 绘画 response
*/
@TableField(typeHandler = JacksonTypeHandler.class)
private MidjourneyNotifyReqVO response;
// TODO @fan这个建议 Double
/**
* mj 进度(10%50%100%)
*/
private String progress;
/** /**
* mj buttons 按钮 * mj buttons 按钮
*/ */
@TableField(typeHandler = ButtonTypeHandler.class) @TableField(typeHandler = ButtonTypeHandler.class)
private List<MidjourneyNotifyReqVO.Button> buttons; private List<MidjourneyApi.Button> buttons;
/** /**
* midjourney proxy 关联的 task id * midjourney proxy 关联的 task id
@ -124,12 +110,11 @@ public class AiImageDO extends BaseDO {
*/ */
private String errorMessage; private String errorMessage;
// TODO @芋艿看看是不是 MidjourneyNotifyReqVO.Button 搞到 MJ API
public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> { public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> {
@Override @Override
protected Object parse(String json) { protected Object parse(String json) {
return JsonUtils.parseArray(json, MidjourneyNotifyReqVO.Button.class); return JsonUtils.parseArray(json, MidjourneyApi.Button.class);
} }
@Override @Override

View File

@ -1,13 +1,10 @@
package cn.iocoder.yudao.module.ai.dal.mysql.image; package cn.iocoder.yudao.module.ai.dal.mysql.image;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List; import java.util.List;
@ -20,52 +17,19 @@ import java.util.List;
@Mapper @Mapper
public interface AiImageMapper extends BaseMapperX<AiImageDO> { public interface AiImageMapper extends BaseMapperX<AiImageDO> {
// TODO @fan这个建议直接使用 updateservice 拼接要改的状态哈 default AiImageDO selectByTaskId(String taskId) {
return this.selectOne(AiImageDO::getTaskId, taskId);
/**
* 更新 - 根据 messageId
*
* @param mjNonceId
* @param aiImageDO
*/
default void updateByMjNonce(Long mjNonceId, AiImageDO aiImageDO) {
// this.update(aiImageDO, new LambdaQueryWrapperX<AiImageDO>().eq(AiImageDO::getMjNonceId, mjNonceId));
return;
} }
/**
* 查询 - 根据 job id
*
* @param id
* @return
*/
default AiImageDO selectByJobId(String id) {
return this.selectOne(new LambdaQueryWrapperX<AiImageDO>().eq(AiImageDO::getTaskId, id));
}
/**
* 查询 - page
*
* @param userId
* @param pageReqVO
* @return
*/
default PageResult<AiImageDO> selectPage(Long userId, PageParam pageReqVO) { default PageResult<AiImageDO> selectPage(Long userId, PageParam pageReqVO) {
return selectPage(pageReqVO, new LambdaQueryWrapperX<AiImageDO>() return selectPage(pageReqVO, new LambdaQueryWrapperX<AiImageDO>()
.eq(AiImageDO::getUserId, userId) .eq(AiImageDO::getUserId, userId)
.orderByDesc(AiImageDO::getId)); .orderByDesc(AiImageDO::getId));
} }
/** default List<AiImageDO> selectListByStatusAndPlatform(Integer status, String platform) {
* 查询 - 根据 status platform return selectList(AiImageDO::getStatus, status,
* AiImageDO::getPlatform, platform);
* @return
*/
default List<AiImageDO> selectByStatusAndPlatform(AiImageStatusEnum statusEnum, AiPlatformEnum platformEnum) {
return this.selectList(new LambdaUpdateWrapper<AiImageDO>()
.eq(AiImageDO::getStatus, statusEnum.getStatus())
.eq(AiImageDO::getPlatform, platformEnum.getPlatform())
);
} }
} }

View File

@ -1,79 +0,0 @@
package cn.iocoder.yudao.module.ai.job;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.quartz.core.handler.JobHandler;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* midjourney job 定时拉去 midjourney 绘制状态
*
* @author fansili
* @time 2024/6/5 14:55
* @since 1.0
*/
@Component
@Slf4j
public class MidjourneyJob implements JobHandler {
// TODO @fan@Resource
@Autowired(required = false)
private MidjourneyApi midjourneyApi;
@Autowired
private AiImageMapper imageMapper;
@Autowired
private AiImageService imageService;
// TODO @fan这个方法建议实现到 AiImageService例如说 midjourneySync返回 int 同步数量
@Override
public String execute(String param) {
// 1获取 midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectByStatusAndPlatform(AiImageStatusEnum.IN_PROGRESS, AiPlatformEnum.MIDJOURNEY);
log.info("Midjourney 同步 - 任务数量 {}!", imageList.size());
if (CollUtil.isEmpty(imageList)) {
return "Midjourney 同步 - 数量为空!";
}
log.info("Midjourney 同步 - 开始...");
// 2批量拉去 task 信息
List<MidjourneyApi.NotifyRequest> taskList = midjourneyApi
.listByCondition(CollectionUtils.convertSet(imageList, AiImageDO::getTaskId));
Map<String, MidjourneyApi.NotifyRequest> taskIdMap
= CollectionUtils.convertMap(taskList, MidjourneyApi.NotifyRequest::id);
// 3更新 image 状态
List<AiImageDO> updateImageList = new ArrayList<>();
for (AiImageDO aiImageDO : imageList) {
// 3.1 排除掉空的情况
if (!taskIdMap.containsKey(aiImageDO.getTaskId())) {
log.warn("Midjourney 同步 - {} 任务为空!", aiImageDO.getTaskId());
continue;
}
// TODO @ 3.1 3.2 是不是融合下get然后判空continue
// 3.2 获取通知对象
MidjourneyApi.NotifyRequest notifyRequest = taskIdMap.get(aiImageDO.getTaskId());
// 3.2 构建更新对象
// TODO @fan建议 List<MidjourneyNotifyReqVO> 作为 imageService 去更新
// TODO @芋艿 BeanUtils.toBean 转换为 null
updateImageList.add(imageService.buildUpdateImage(aiImageDO.getId(),
JsonUtils.parseObject(JsonUtils.toJsonString(notifyRequest), MidjourneyNotifyReqVO.class)));
}
// 4批了更新 updateImageList
imageMapper.updateBatch(updateImageList);
return "Midjourney 同步 - ".concat(String.valueOf(updateImageList.size())).concat(" 任务!");
}
}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.module.ai.job.image;
import cn.iocoder.yudao.framework.quartz.core.handler.JobHandler;
import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
/**
* Midjourney 同步 Job定时拉去 midjourney 绘制状态
*
* @author fansili
*/
@Component
@Slf4j
public class MidjourneySyncJob implements JobHandler {
@Resource
private AiImageService imageService;
@Override
public String execute(String param) {
Integer count = imageService.midjourneySync();
log.info("[execute][同步 Midjourney ({}) 个]", count);
return String.format("同步 Midjourney %s 个", count);
}
}

View File

@ -1,10 +1,10 @@
package cn.iocoder.yudao.module.ai.service.image; package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
/** /**
@ -40,15 +40,6 @@ public interface AiImageService {
*/ */
Long drawImage(Long userId, AiImageDrawReqVO drawReqVO); Long drawImage(Long userId, AiImageDrawReqVO drawReqVO);
/**
* Midjourney imagine绘画
*
* @param userId 用户编号
* @param imagineReqVO 绘制请求
* @return 绘画编号
*/
Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO imagineReqVO);
/** /**
* 删除我的绘画记录 * 删除我的绘画记录
* *
@ -57,22 +48,30 @@ public interface AiImageService {
*/ */
void deleteImageMy(Long id, Long userId); void deleteImageMy(Long id, Long userId);
/** // ================ midjourney 专属 ================
* midjourney proxy - 回调通知
*
* @param notifyReqVO
* @return
*/
void midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO);
/** /**
* 构建 midjourney - 更新对象 * Midjourney生成图片
* *
* @param imageId * @param userId 用户编号
* @param notifyReqVO * @param reqVO 绘制请求
* @return * @return 绘画编号
*/ */
AiImageDO buildUpdateImage(Long imageId, MidjourneyNotifyReqVO notifyReqVO); Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO);
/**
* Midjourney同步图片进展
*
* @return 同步成功数量
*/
Integer midjourneySync();
/**
* Midjourney通知图片进展
*
* @param notify 通知
*/
void midjourneyNotify(MidjourneyApi.Notify notify);
/** /**
* midjourney - action(放大缩小U1U2...) * midjourney - action(放大缩小U1U2...)
@ -83,4 +82,5 @@ public interface AiImageService {
* @return * @return
*/ */
void midjourneyAction(Long loginUserId, Long imageId, String customId); void midjourneyAction(Long loginUserId, Long imageId, String customId);
} }

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.codec.Base64; import cn.hutool.core.codec.Base64;
import cn.hutool.core.exceptions.ExceptionUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil; import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
@ -14,8 +14,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
@ -29,15 +28,17 @@ import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.List; import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/** /**
@ -51,12 +52,16 @@ public class AiImageServiceImpl implements AiImageService {
@Resource @Resource
private AiImageMapper imageMapper; private AiImageMapper imageMapper;
@Resource @Resource
private FileApi fileApi; private FileApi fileApi;
@Resource @Resource
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Autowired(required = false)
@Resource
private MidjourneyApi midjourneyApi; private MidjourneyApi midjourneyApi;
@Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}") @Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
private String midjourneyNotifyUrl; private String midjourneyNotifyUrl;
@ -74,7 +79,7 @@ public class AiImageServiceImpl implements AiImageService {
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) { public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
// 1. 保存数据库 // 1. 保存数据库
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(image); imageMapper.insert(image);
// 2. 异步绘制后续前端通过返回的 id 进行轮询结果 // 2. 异步绘制后续前端通过返回的 id 进行轮询结果
getSelf().executeDrawImage(image, drawReqVO); getSelf().executeDrawImage(image, drawReqVO);
@ -121,47 +126,6 @@ public class AiImageServiceImpl implements AiImageService {
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
} }
@Override
@Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO req) {
// 1构建 AiImageDO 保存
AiImageDO image = new AiImageDO()
.setUserId(userId)
.setPrompt(req.getPrompt())
.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform())
.setModel(req.getModel())
.setWidth(req.getWidth())
.setHeight(req.getHeight())
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(image);
// 3调用 MidjourneyProxy 提交任务
// 3.1设置 midjourney 扩展参数
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(null, midjourneyNotifyUrl, req.getPrompt(),
buildParams(req.getWidth(), req.getHeight(), req.getVersion(),
MidjourneyApi.ModelEnum.valueOfModel(req.getModel())));
// 3.2提交绘画请求
// TODO @fan5 这里失败的情况到底抛出异常还是 RespVO可以参考 OpenAI API 封装
MidjourneyApi.SubmitResponse submitResponse = midjourneyApi.imagine(imagineRequest);
// 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(submitResponse.code())) {
if (submitResponse.description().contains("quota_not_enough")) {
throw exception(AI_IMAGE_SYSTEM_ACCOUNT_INSUFFICIENT_BALANCE, submitResponse.description());
}
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitResponse.description());
}
// 4.1更新 taskId 和参数
imageMapper.updateById(new AiImageDO()
.setId(image.getId())
.setTaskId(submitResponse.result())
.setOptions(BeanUtil.beanToMap(req))
);
return image.getId();
}
@Override @Override
public void deleteImageMy(Long id, Long userId) { public void deleteImageMy(Long id, Long userId) {
// 1. 校验是否存在 // 1. 校验是否存在
@ -173,50 +137,111 @@ public class AiImageServiceImpl implements AiImageService {
imageMapper.deleteById(id); imageMapper.deleteById(id);
} }
@Override private AiImageDO validateImageExists(Long id) {
public void midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) { AiImageDO image = imageMapper.selectById(id);
// 1根据 job id 查询关联的 image
AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
if (image == null) { if (image == null) {
log.warn("midjourneyNotify 回调的 jobId 不存在! jobId: {}", notifyReqVO.getId()); throw exception(AI_IMAGE_NOT_EXISTS);
} }
// 2转换状态 return image;
AiImageDO updateImage = buildUpdateImage(image.getId(), notifyReqVO);
// 3更新 image 状态
imageMapper.updateById(updateImage);
} }
public AiImageDO buildUpdateImage(Long imageId, MidjourneyNotifyReqVO notifyReqVO) { // ================ midjourney 专属 ================
// 1转换状态
String imageStatus = null; @Override
if (StrUtil.isNotBlank(notifyReqVO.getStatus())) { @Transactional(rollbackFor = Exception.class)
MidjourneyApi.TaskStatusEnum taskStatusEnum = MidjourneyApi.TaskStatusEnum.valueOf(notifyReqVO.getStatus()); public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO) {
// 1. 保存数据库
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
imageMapper.insert(image);
// 2. 调用 Midjourney Proxy 提交任务
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
null, midjourneyNotifyUrl, reqVO.getPrompt(),
MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
// 3. 情况一失败抛出业务异常
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) {
String description = imagineResponse.description().contains("quota_not_enough") ?
"账户余额不足" : imagineResponse.description();
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
}
// 4. 情况二成功更新 taskId 和参数
imageMapper.updateById(new AiImageDO()
.setId(image.getId())
.setTaskId(imagineResponse.result())
.setOptions(BeanUtil.beanToMap(reqVO))
);
return image.getId();
}
@Override
public Integer midjourneySync() {
// 1.1 获取 Midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
if (CollUtil.isEmpty(imageList)) {
return 0;
}
// 1.2 调用 Midjourney Proxy 获取任务进展
List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(imageList, AiImageDO::getTaskId));
Map<String, MidjourneyApi.Notify> taskMap = convertMap(taskList, MidjourneyApi.Notify::id);
// 2. 逐个处理更新进展
int count = 0;
for (AiImageDO image : imageList) {
MidjourneyApi.Notify notify = taskMap.get(image.getTaskId());
if (notify == null) {
log.error("[midjourneySync][image({}) 查询不到进展]", image);
continue;
}
count++;
updateMidjourneyStatus(image, notify);
}
return count;
}
@Override
public void midjourneyNotify(MidjourneyApi.Notify notify) {
// 1. 校验 image 存在
AiImageDO image = imageMapper.selectByTaskId(notify.id());
if (image == null) {
log.warn("[midjourneyNotify][回调任务({}) 不存在]", notify.id());
return;
}
// 2. 更新状态
updateMidjourneyStatus(image, notify);
}
private void updateMidjourneyStatus(AiImageDO image, MidjourneyApi.Notify notify) {
// 1. 转换状态
Integer status = null;
if (StrUtil.isNotBlank(notify.status())) {
MidjourneyApi.TaskStatusEnum taskStatusEnum = MidjourneyApi.TaskStatusEnum.valueOf(notify.status());
if (MidjourneyApi.TaskStatusEnum.SUCCESS == taskStatusEnum) { if (MidjourneyApi.TaskStatusEnum.SUCCESS == taskStatusEnum) {
imageStatus = AiImageStatusEnum.SUCCESS.getStatus(); status = AiImageStatusEnum.SUCCESS.getStatus();
} else if (MidjourneyApi.TaskStatusEnum.FAILURE == taskStatusEnum) { } else if (MidjourneyApi.TaskStatusEnum.FAILURE == taskStatusEnum) {
imageStatus = AiImageStatusEnum.FAIL.getStatus(); status = AiImageStatusEnum.FAIL.getStatus();
} }
} }
// 2上传图片 // 2. 上传图片
String filePath = null; String picUrl = null;
if (!StrUtil.isBlank(notifyReqVO.getImageUrl())) { if (StrUtil.isNotBlank(notify.imageUrl())) {
try { try {
filePath = fileApi.createFile(HttpUtil.downloadBytes(notifyReqVO.getImageUrl())); picUrl = fileApi.createFile(HttpUtil.downloadBytes(notify.imageUrl()));
} catch (Exception e) { } catch (Exception e) {
log.warn("midjourneyNotify 图片上传失败! {} 异常:{}", notifyReqVO.getImageUrl(), ExceptionUtil.getMessage(e)); picUrl = notify.imageUrl();
log.warn("[updateMidjourneyStatus][图片({}) 地址({}) 上传失败]", image.getId(), notify.imageUrl(), e);
} }
} }
// 3更新 image 状态 // 3. 更新 image 状态
return new AiImageDO() imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(status)
.setId(imageId) .setPicUrl(picUrl).setButtons(notify.buttons()).setErrorMessage(notify.failReason()));
.setStatus(imageStatus)
.setPicUrl(filePath)
.setProgress(notifyReqVO.getProgress())
.setResponse(notifyReqVO)
.setButtons(notifyReqVO.getButtons())
.setErrorMessage(notifyReqVO.getFailReason());
} }
@Override @Override
@ -236,7 +261,6 @@ public class AiImageServiceImpl implements AiImageService {
// 5新增 image 记录(根据 image 新增一个) // 5新增 image 记录(根据 image 新增一个)
AiImageDO newImage = new AiImageDO(); AiImageDO newImage = new AiImageDO();
newImage.setId(null);
newImage.setUserId(image.getUserId()); newImage.setUserId(image.getUserId());
newImage.setPrompt(image.getPrompt()); newImage.setPrompt(image.getPrompt());
@ -248,20 +272,15 @@ public class AiImageServiceImpl implements AiImageService {
newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
newImage.setPublicStatus(image.getPublicStatus()); newImage.setPublicStatus(image.getPublicStatus());
newImage.setPicUrl(null);
newImage.setProgress(null);
newImage.setButtons(null);
newImage.setOptions(image.getOptions()); newImage.setOptions(image.getOptions());
newImage.setResponse(image.getResponse());
newImage.setTaskId(submitResponse.result()); newImage.setTaskId(submitResponse.result());
newImage.setErrorMessage(null);
imageMapper.insert(newImage); imageMapper.insert(newImage);
} }
private static void validateCustomId(String customId, List<MidjourneyNotifyReqVO.Button> buttons) { private static void validateCustomId(String customId, List<MidjourneyApi.Button> buttons) {
boolean isTrue = false; boolean isTrue = false;
for (MidjourneyNotifyReqVO.Button button : buttons) { for (MidjourneyApi.Button button : buttons) {
if (button.getCustomId().equals(customId)) { if (button.customId().equals(customId)) {
isTrue = true; isTrue = true;
break; break;
} }
@ -271,14 +290,6 @@ public class AiImageServiceImpl implements AiImageService {
} }
} }
private AiImageDO validateImageExists(Long id) {
AiImageDO image = imageMapper.selectById(id);
if (image == null) {
throw exception(AI_IMAGE_NOT_EXISTS);
}
return image;
}
/** /**
* 获得自身的代理对象解决 AOP 生效问题 * 获得自身的代理对象解决 AOP 生效问题
* *
@ -288,28 +299,4 @@ public class AiImageServiceImpl implements AiImageService {
return SpringUtil.getBean(getClass()); return SpringUtil.getBean(getClass());
} }
// TODO @fan这个是不是应该放在 MJ API 的封装里面搞哈
/**
* 构建 Midjourney 自定义参数
*
* @param width
* @param height
* @param version
* @param model
* @return
*/
private String buildParams(Integer width, Integer height, String version, MidjourneyApi.ModelEnum model) {
StringBuilder params = new StringBuilder();
// --ar 来设置尺寸
params.append(String.format(" --ar %s:%s ", width, height));
// --niji 模型
if (MidjourneyApi.ModelEnum.NIJI == model) {
params.append(String.format(" --niji %s ", version));
} else {
// --v 版本
params.append(String.format(" --v %s ", version));
}
return params.toString();
}
} }

View File

@ -50,34 +50,6 @@
<artifactId>junit</artifactId> <artifactId>junit</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<!-- TODO fan这里包要进一步减少 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<!-- https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp -->
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
</dependency>
<dependency>
<groupId>io.projectreactor.netty</groupId>
<artifactId>reactor-netty</artifactId>
</dependency>
<dependency>
<groupId>net.dv8tion</groupId>
<artifactId>JDA</artifactId>
<version>5.0.0-beta.21</version>
<!-- <exclusions>-->
<!-- <exclusion>-->
<!-- <groupId>club.minnced</groupId>-->
<!-- <artifactId>opus-java</artifactId>-->
<!-- </exclusion>-->
<!-- </exclusions>-->
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -16,66 +16,54 @@ import java.util.List;
import java.util.Map; import java.util.Map;
/** /**
* Midjourney api * Midjourney API
* *
* @author fansili * @author fansili
* @time 2024/6/11 15:46
* @since 1.0 * @since 1.0
*/ */
@Slf4j @Slf4j
public class MidjourneyApi { public class MidjourneyApi {
private static final String URI_IMAGINE = "/submit/imagine";
private static final String URI_ACTON = "/submit/action";
private static final String URI_LIST_BY_CONDITION = "/task/list-by-condition";
private final WebClient webClient; private final WebClient webClient;
private final MidjourneyConfig midjourneyConfig;
public MidjourneyApi(MidjourneyConfig midjourneyConfig) { public MidjourneyApi(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
this.webClient = WebClient.builder() this.webClient = WebClient.builder()
.baseUrl(midjourneyConfig.getUrl()) .baseUrl(midjourneyConfig.getUrl())
.defaultHeaders(ApiUtils.getJsonContentHeaders(midjourneyConfig.getKey())) .defaultHeaders(ApiUtils.getJsonContentHeaders(midjourneyConfig.getKey()))
.build(); .build();
} }
/** /**
* imagine - 根据提示词提交绘画任务 * imagine - 根据提示词提交绘画任务
* *
* @param imagineReqVO * @param request 请求
* @return * @return 提交结果
*/ */
public SubmitResponse imagine(ImagineRequest imagineReqVO) { public SubmitResponse imagine(ImagineRequest request) {
// 1发送 post 请求 String response = post("/submit/imagine", request);
String res = post(URI_IMAGINE, imagineReqVO); return JsonUtils.parseObject(response, SubmitResponse.class);
// 2转换 resp
return JsonUtils.parseObject(res, SubmitResponse.class);
} }
/** /**
* action - 放大缩小U1U2... * action - 放大缩小U1U2...
* *
* @param actionReqVO * @param request 请求
* @return 提交结果
*/ */
public SubmitResponse action(ActionRequest actionReqVO) { public SubmitResponse action(ActionRequest request) {
// 1发送 post 请求 String res = post("/submit/action", request);
String res = post(URI_ACTON, actionReqVO);
// 2转换 resp
return JsonUtils.parseObject(res, SubmitResponse.class); return JsonUtils.parseObject(res, SubmitResponse.class);
} }
/** /**
* 批量查询 task 任务 * 批量查询 task 任务
* *
* @param taskIds * @param ids 任务编号数组
* @return * @return task 任务
*/ */
public List<NotifyRequest> listByCondition(Collection<String> taskIds) { public List<Notify> getTaskList(Collection<String> ids) {
// 1发送 post 请求 String res = post("/task/list-by-condition", ImmutableMap.of("ids", ids));
String res = post(URI_LIST_BY_CONDITION, ImmutableMap.of("ids", taskIds)); return JsonUtils.parseArray(res, Notify.class);
// 2转换 对象
return JsonUtils.parseArray(res, NotifyRequest.class);
} }
private String post(String uri, Object body) { private String post(String uri, Object body) {
@ -94,12 +82,12 @@ public class MidjourneyApi {
.block(); .block();
} }
// ====== record 结构 // ========== record 结构 ==========
/** /**
* Midjourney - Imagine 请求 * Imagine 请求生成图片
* *
* @param base64Array 垫图(参考图)base64数 * @param base64Array 垫图(参考图) base64数
* @param notifyHook 通知地址 * @param notifyHook 通知地址
* @param prompt 提示词 * @param prompt 提示词
* @param state 自定义参数 * @param state 自定义参数
@ -108,10 +96,24 @@ public class MidjourneyApi {
String notifyHook, String notifyHook,
String prompt, String prompt,
String state) { String state) {
public static String buildState(Integer width, Integer height, String version, String model) {
StringBuilder params = new StringBuilder();
// --ar 来设置尺寸
params.append(String.format(" --ar %s:%s ", width, height));
// --niji 模型
if (MidjourneyApi.ModelEnum.NIJI.getModel().equals(model)) {
params.append(String.format(" --niji %s ", version));
} else {
params.append(String.format(" --v %s ", version));
}
return params.toString();
}
} }
/** /**
* Midjourney - Action 请求 * Action 请求
* *
* @param customId 操作按钮id * @param customId 操作按钮id
* @param taskId 操作按钮id * @param taskId 操作按钮id
@ -124,7 +126,7 @@ public class MidjourneyApi {
} }
/** /**
* Midjourney - Submit 返回 * Submit 统一返回
* *
* @param code 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) * @param code 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
* @param description 描述 * @param description 描述
@ -138,7 +140,7 @@ public class MidjourneyApi {
} }
/** /**
* Midjourney - 通知 request * 通知 request
* *
* @param id job id * @param id job id
* @param action 任务类型 {@link TaskActionEnum} * @param action 任务类型 {@link TaskActionEnum}
@ -155,47 +157,47 @@ public class MidjourneyApi {
* @param failReason 失败原因 * @param failReason 失败原因
* @param buttons 任务完成后的可执行按钮 * @param buttons 任务完成后的可执行按钮
*/ */
public record NotifyRequest(String id, public record Notify(String id,
String action, String action,
String status, String status,
String prompt, String prompt,
String promptEn, String promptEn,
String description, String description,
String state, String state,
Long submitTime, Long submitTime,
Long startTime, Long startTime,
Long finishTime, Long finishTime,
String imageUrl, String imageUrl,
String progress, String progress,
String failReason, String failReason,
List<Button> buttons) { List<Button> buttons) {
/**
* button
*
* @param customId MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862 动作标识
* @param emoji 图标 emoji
* @param label Make Variations 文本
* @param type 类型系统内部使用
* @param style 样式: 2Primary3Green
*/
public record Button(String customId,
String emoji,
String label,
String type,
String style) {
}
} }
// ====== enums /**
* button
*
* @param customId MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862 动作标识
* @param emoji 图标 emoji
* @param label Make Variations 文本
* @param type 类型系统内部使用
* @param style 样式: 2Primary3Green
*/
public record Button(String customId,
String emoji,
String label,
String type,
String style) {
}
// ============ enums ============
/** /**
* Midjourney - 模型 * 模型枚举
*/ */
@AllArgsConstructor @AllArgsConstructor
@Getter @Getter
@ -203,24 +205,15 @@ public class MidjourneyApi {
MIDJOURNEY("midjourney", "midjourney"), MIDJOURNEY("midjourney", "midjourney"),
NIJI("niji", "niji"), NIJI("niji", "niji"),
; ;
private String model; private final String model;
private String name; private final String name;
public static ModelEnum valueOfModel(String model) {
for (ModelEnum itemEnum : ModelEnum.values()) {
if (itemEnum.getModel().equals(model)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
} }
/** /**
* Midjourney - 提交返回的状态码 * 提交返回的状态码的枚举
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
@ -239,64 +232,68 @@ public class MidjourneyApi {
private final String code; private final String code;
private final String name; private final String name;
} }
/** /**
* Midjourney - action * Action 枚举
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public enum TaskActionEnum { public enum TaskActionEnum {
/** /**
* 生成图片. * 生成图片
*/ */
IMAGINE, IMAGINE,
/** /**
* 选中放大. * 选中放大
*/ */
UPSCALE, UPSCALE,
/** /**
* 选中其中的一张图生成四张相似的. * 选中其中的一张图生成四张相似的
*/ */
VARIATION, VARIATION,
/** /**
* 重新执行. * 重新执行
*/ */
REROLL, REROLL,
/** /**
* 图转prompt. * 图转 prompt
*/ */
DESCRIBE, DESCRIBE,
/** /**
* 多图混合. * 多图混合
*/ */
BLEND BLEND
} }
/** /**
* Midjourney - 任务状态 * 任务状态枚举
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public enum TaskStatusEnum { public enum TaskStatusEnum {
/** /**
* 未启动. * 未启动
*/ */
NOT_START(0), NOT_START(0),
/** /**
* 已提交. * 已提交
*/ */
SUBMITTED(1), SUBMITTED(1),
/** /**
* 执行中. * 执行中
*/ */
IN_PROGRESS(3), IN_PROGRESS(3),
/** /**
* 失败. * 失败
*/ */
FAILURE(4), FAILURE(4),
/** /**
* 成功. * 成功
*/ */
SUCCESS(4); SUCCESS(4);

View File

@ -2,11 +2,9 @@
* model 接入各种大模型对标 https://github.com/spring-projects/spring-ai/tree/main/models * model 接入各种大模型对标 https://github.com/spring-projects/spring-ai/tree/main/models
* *
* 1. yiyan 百度文心一言 * 1. yiyan 百度文心一言
* 2. TODO 芋艿 * 2. tongyi 阿里通义千问对标 spring-cloud-alibaba 提供的 ai TODO 芋艿未来直接使用它
* tongyi 阿里通义千问对标 spring-cloud-alibaba 提供的 ai * 3. xinghuo 讯飞星火自己实现
* 2.2 * 4. midjourney Midjourney接入 https://github.com/novicezk/midjourney-proxy 实现
* 2.3 xinghuo 讯飞星火自己实现 * 5. suno TODO 芋艿
* 2.4 openai OpenAIChatGPT拷贝 spring-ai 提供的 models/openai
* 2.5 midjourney Midjourney参考 https://github.com/novicezk/midjourney-proxy 实现
*/ */
package cn.iocoder.yudao.framework.ai.core.model; package cn.iocoder.yudao.framework.ai.core.model;

View File

@ -76,13 +76,6 @@ server:
enabled: true enabled: true
charset: UTF-8 charset: UTF-8
force: true force: true
# ai TODO @fan这个融合到 yudao.ai 那好点哈
ai:
midjourney-proxy:
enable: true
url: https://api.holdai.top/mj
notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
key: sk-c3qxUCVKsPfdQiYU8440E3Fc8dE5424d9cB124A4Ee2489E3
--- #################### 定时任务相关配置 #################### --- #################### 定时任务相关配置 ####################

View File

@ -201,6 +201,13 @@ yudao.ai:
enable: true enable: true
base-url: https://suno-imrqwwui8-status2xxs-projects.vercel.app base-url: https://suno-imrqwwui8-status2xxs-projects.vercel.app
ai:
midjourney-proxy:
enable: true
url: https://api.holdai.top/mj
notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
key: sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf
--- #################### 芋道相关配置 #################### --- #################### 芋道相关配置 ####################
yudao: yudao: