【代码优化】AI:MJ 生成图片的优化

This commit is contained in:
YunaiV 2024-06-25 19:57:37 +08:00
parent 88142ed74c
commit 098483d2be
23 changed files with 333 additions and 653 deletions

View File

@ -21,30 +21,21 @@ public interface ErrorCodeConstants {
// ========== API 聊天模型 1-040-002-000 ========== // ========== API 聊天模型 1-040-002-000 ==========
ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在"); ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在");
ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!"); ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!");
ErrorCode CHAT_ROLE_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认聊天角色");
// ========== API 聊天会话 1-040-003-000 ========== // ========== API 聊天会话 1-040-003-000 ==========
ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!"); ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!");
ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整"); ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整");
ErrorCode CHAT_CONVERSATION_UPDATE_MAX_TOKENS_ERROR = new ErrorCode(1_040_003_002, "更新对话失败,最大 Token 超过上限");
ErrorCode CHAT_CONVERSATION_UPDATE_MAX_CONTEXTS_ERROR = new ErrorCode(1_040_003_002, "更新对话失败,最大 Context 超过上限");
// ========== API 聊天消息 1-040-004-000 ========== // ========== API 聊天消息 1-040-004-000 ==========
ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!"); ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
ErrorCode AI_CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!"); ErrorCode AI_CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!");
// midjourney
ErrorCode AI_MIDJOURNEY_IMAGINE_FAIL = new ErrorCode(1_022_000_040, "midjourney imagine 操作失败!");
ErrorCode AI_MIDJOURNEY_OPERATION_NOT_EXISTS = new ErrorCode(1_022_000_040, "midjourney 操作不存在!");
ErrorCode AI_MIDJOURNEY_MESSAGE_ID_INCORRECT = new ErrorCode(1_022_000_040, "midjourney message id 不正确!");
// ========== API 绘画 1-040-005-000 ========== // ========== API 绘画 1-040-005-000 ==========
ErrorCode AI_IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "image 不存在!"); ErrorCode AI_IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
ErrorCode AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败! {}"); ErrorCode AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
ErrorCode AI_IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}"); ErrorCode AI_IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
ErrorCode AI_IMAGE_SYSTEM_ACCOUNT_INSUFFICIENT_BALANCE = new ErrorCode(1_022_005_003, "Midjourney 系统账户余额不足!");
} }

View File

@ -12,20 +12,20 @@ import lombok.Getter;
@Getter @Getter
public enum AiImageStatusEnum { public enum AiImageStatusEnum {
IN_PROGRESS("10", "进行中"), IN_PROGRESS(10, "进行中"),
SUCCESS("20", "完成"), SUCCESS(20, "完成"),
FAIL("30", "失败"); FAIL(30, "失败");
/** /**
* 状态 * 状态
*/ */
private final String status; private final Integer status;
/** /**
* 状态名 * 状态名
*/ */
private final String name; private final String name;
public static AiImageStatusEnum valueOfStatus(String status) { public static AiImageStatusEnum valueOfStatus(Integer status) {
for (AiImageStatusEnum statusEnum : AiImageStatusEnum.values()) { for (AiImageStatusEnum statusEnum : AiImageStatusEnum.values()) {
if (statusEnum.getStatus().equals(status)) { if (statusEnum.getStatus().equals(status)) {
return statusEnum; return statusEnum;

View File

@ -3,7 +3,6 @@ package cn.iocoder.yudao.module.ai.config;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.core.model.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;

View File

@ -29,12 +29,17 @@ Authorization: {{token}}
"style": "vivid" "style": "vivid"
} }
### chat midjourney ### 生成图片:生成图片
POST {{baseUrl}}/admin-api/ai/image/midjourney POST {{baseUrl}}/ai/image/midjourney/imagine
Content-Type: application/json Content-Type: application/json
Authorization: {{token}} Authorization: {{token}}
{ {
"prompt": "Cute cartoon style mobile game scene, a colorful camping car with an outdoor table and chairs next to it on the road in a spring forest, the simple structure of the camper van, soft lighting, C4D rendering, 3d model in the style of a cartoon, cute shape, a pastel color scheme, closeup view from the side angle, high resolution, bright colors, a happy atmosphere." "prompt": "中国旗袍",
"model": "midjourney",
"width": "1",
"height": "1",
"version": "6.0",
"base64Array": []
} }

View File

@ -1,14 +1,14 @@
package cn.iocoder.yudao.module.ai.controller.admin.image; package cn.iocoder.yudao.module.ai.controller.admin.image;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.service.image.AiImageService; import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
@ -63,19 +63,24 @@ public class AiImageController {
return success(true); return success(true);
} }
// ================ midjourney 接口 ================ // ================ midjourney 专属 ================
@Operation(summary = "Midjourney imagine绘画") @Operation(summary = "【Midjourney】生成图片")
@PostMapping("/midjourney/imagine") @PostMapping("/midjourney/imagine")
public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) { public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO reqVO) {
return success(imageService.midjourneyImagine(getLoginUserId(), req)); if (true) {
imageService.midjourneySync();
return null;
}
Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO);
return success(imageId);
} }
@Operation(summary = "Midjourney 回调通知", description = "由 Midjourney Proxy 回调") @Operation(summary = "Midjourney 生成图片的回调通知", description = "由 Midjourney Proxy 回调")
@PostMapping("/midjourney-notify") @PostMapping("/midjourney-notify")
@PermitAll @PermitAll
public void midjourneyNotify(@RequestBody MidjourneyNotifyReqVO notifyReqVO) { public void midjourneyNotify(@RequestBody MidjourneyApi.Notify notify) {
imageService.midjourneyNotify(notifyReqVO); imageService.midjourneyNotify(notify);
} }
@Operation(summary = "Midjourney Action", description = "例如说放大、缩小、U1、U2 等") @Operation(summary = "Midjourney Action", description = "例如说放大、缩小、U1、U2 等")

View File

@ -1,40 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
// TODO @fan待定
/**
* midjourney req
*
* @author fansili
* @time 2024/4/28 17:42
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyImagineReqVO {
@Schema(description = "提示词")
@NotNull(message = "提示词不能为空!")
private String prompt;
@Schema(description = "模型(midjourney、niji)")
private String model;
@Schema(description = "图片宽度 --ar 设置")
private Integer width;
@Schema(description = "图片高度 --ar 设置")
private Integer height;
@Schema(description = "版本号 --v 设置")
private String version;
@Schema(description = "垫图(参考图)base64数组")
private List<String> base64Array;
}

View File

@ -1,30 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* mj 保存 components 记录
*
* "components": [
* {
* "custom_id": "MJ::JOB::upsample::1::5d32f4e8-8d2f-4bef-82d8-bf517e3c3660",
* "style": 2,
* "label": "U1",
* "type": 2
* },
* ]
*
* @author fansili
* @time 2024/5/8 14:44
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyOperationsVO {
private String custom_id;
private String style;
private String label;
private String type;
}

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo; package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
@ -46,14 +47,11 @@ public class AiImageRespVO {
@Schema(description = "绘制参数") @Schema(description = "绘制参数")
private Map<String, String> options; private Map<String, String> options;
@Schema(description = "绘画 response")
private MidjourneyNotifyReqVO response;
// TODO @fan进度是百分比还是一个数字哈感觉这个可以统一成通用字段 // TODO @fan进度是百分比还是一个数字哈感觉这个可以统一成通用字段
@Schema(description = "mj 进度") @Schema(description = "mj 进度")
private String progress; private String progress;
@Schema(description = "mj buttons 按钮") @Schema(description = "mj buttons 按钮")
private List<MidjourneyNotifyReqVO.Button> buttons; private List<MidjourneyApi.Button> buttons;
} }

View File

@ -1,33 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.util.List;
// TODO @fan待定
/**
* MidjourneyImagine 请求
*
* @author fansili
* @time 2024/5/30 14:02
* @since 1.0
*/
@Data
public class MidjourneyImagineReqVO {
@Schema(description = "垫图(参考图)base64数组", required = false)
private List<String> base64Array;
@Schema(description = "通知地址", required = false)
@NotNull(message = "回调地址不能为空!")
private String notifyHook;
@Schema(description = "提示词", required = true)
@NotNull(message = "提示词不能为空!")
private String prompt;
@Schema(description = "自定义参数", required = false)
private String state;
}

View File

@ -1,75 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.List;
/**
* Midjourney Proxy 通知回调
*
* - Midjourney Proxy通知回调 bean com.github.novicezk.midjourney.support.Task
* - 毫秒 api 通知回调文档地址https://gpt-best.apifox.cn/doc-3530863
*
* @author fansili
* @time 2024/5/31 10:37
* @since 1.0
*/
@Data
public class MidjourneyNotifyReqVO {
@Schema(description = "job id")
private String id;
@Schema(description = "任务类型 MidjourneyTaskActionEnum")
private String action;
@Schema(description = "任务状态 MidjourneyTaskStatusEnum")
private String status;
@Schema(description = "提示词")
private String prompt;
@Schema(description = "提示词-英文")
private String promptEn;
@Schema(description = "任务描述")
private String description;
@Schema(description = "自定义参数")
private String state;
@Schema(description = "提交时间")
private Long submitTime;
@Schema(description = "开始执行时间")
private Long startTime;
@Schema(description = "结束时间")
private Long finishTime;
@Schema(description = "图片url")
private String imageUrl;
@Schema(description = "任务进度")
private String progress;
@Schema(description = "失败原因")
private String failReason;
@Schema(description = "任务完成后的可执行按钮")
private List<Button> buttons;
@Data
public static class Button {
@Schema(description = "MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862 动作标识")
private String customId;
@Schema(description = "图标 emoji")
private String emoji;
@Schema(description = "Make Variations 文本")
private String label;
@Schema(description = "类型,系统内部使用")
private String type;
@Schema(description = "样式: 2Primary、3Green")
private String style;
}
}

View File

@ -1,30 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.Map;
// TODO @fan待定
/**
* MidjourneyImagine 请求
*
* @author fansili
* @time 2024/5/30 14:02
* @since 1.0
*/
@Data
public class MidjourneySubmitRespVO {
@Schema(description = "状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)")
private String code;
@Schema(description = "描述")
private String description;
@Schema(description = "扩展字段")
private Map<String, Object> properties;
@Schema(description = "任务ID")
private String result;
}

View File

@ -0,0 +1,38 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.util.List;
@Schema(description = "管理后台 - 绘画生成Midjourney Request VO")
@Data
public class AiImageMidjourneyImagineReqVO {
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "中国神龙")
@NotEmpty(message = "提示词不能为空!")
private String prompt;
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "midjourney")
@NotEmpty(message = "模型不能为空")
private String model; // 参考 MidjourneyApi.ModelEnum
@Schema(description = "图片宽度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "图片宽度不能为空")
private Integer width;
@Schema(description = "图片高度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "图片高度不能为空")
private Integer height;
@Schema(description = "版本号", requiredMode = Schema.RequiredMode.REQUIRED, example = "6.0")
@NotEmpty(message = "版本号不能为空")
private String version;
// TODO @fan参考图建议用 referImageUrl
@Schema(description = "垫图(参考图)base64数组")
private List<String> base64Array;
}

View File

@ -1,8 +1,8 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.image; package cn.iocoder.yudao.module.ai.dal.dataobject.image;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils; import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO; import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
@ -28,8 +28,6 @@ import java.util.Map;
@Data @Data
public class AiImageDO extends BaseDO { public class AiImageDO extends BaseDO {
// TODO @fan1使用 java 注释哈不要注解2关联枚举字段要关联到对应类参考 AiChatMessageDO 的注释
/** /**
* 编号 * 编号
*/ */
@ -76,7 +74,7 @@ public class AiImageDO extends BaseDO {
* *
* 枚举 {@link AiImageStatusEnum} * 枚举 {@link AiImageStatusEnum}
*/ */
private String status; private Integer status;
/** /**
* 图片地址 * 图片地址
@ -96,23 +94,11 @@ public class AiImageDO extends BaseDO {
@TableField(typeHandler = JacksonTypeHandler.class) @TableField(typeHandler = JacksonTypeHandler.class)
private Map<String, Object> options; private Map<String, Object> options;
/**
* 绘画 response
*/
@TableField(typeHandler = JacksonTypeHandler.class)
private MidjourneyNotifyReqVO response;
// TODO @fan这个建议 Double
/**
* mj 进度(10%50%100%)
*/
private String progress;
/** /**
* mj buttons 按钮 * mj buttons 按钮
*/ */
@TableField(typeHandler = ButtonTypeHandler.class) @TableField(typeHandler = ButtonTypeHandler.class)
private List<MidjourneyNotifyReqVO.Button> buttons; private List<MidjourneyApi.Button> buttons;
/** /**
* midjourney proxy 关联的 task id * midjourney proxy 关联的 task id
@ -124,12 +110,11 @@ public class AiImageDO extends BaseDO {
*/ */
private String errorMessage; private String errorMessage;
// TODO @芋艿看看是不是 MidjourneyNotifyReqVO.Button 搞到 MJ API
public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> { public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> {
@Override @Override
protected Object parse(String json) { protected Object parse(String json) {
return JsonUtils.parseArray(json, MidjourneyNotifyReqVO.Button.class); return JsonUtils.parseArray(json, MidjourneyApi.Button.class);
} }
@Override @Override

View File

@ -1,13 +1,10 @@
package cn.iocoder.yudao.module.ai.dal.mysql.image; package cn.iocoder.yudao.module.ai.dal.mysql.image;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List; import java.util.List;
@ -20,52 +17,19 @@ import java.util.List;
@Mapper @Mapper
public interface AiImageMapper extends BaseMapperX<AiImageDO> { public interface AiImageMapper extends BaseMapperX<AiImageDO> {
// TODO @fan这个建议直接使用 updateservice 拼接要改的状态哈 default AiImageDO selectByTaskId(String taskId) {
return this.selectOne(AiImageDO::getTaskId, taskId);
/**
* 更新 - 根据 messageId
*
* @param mjNonceId
* @param aiImageDO
*/
default void updateByMjNonce(Long mjNonceId, AiImageDO aiImageDO) {
// this.update(aiImageDO, new LambdaQueryWrapperX<AiImageDO>().eq(AiImageDO::getMjNonceId, mjNonceId));
return;
} }
/**
* 查询 - 根据 job id
*
* @param id
* @return
*/
default AiImageDO selectByJobId(String id) {
return this.selectOne(new LambdaQueryWrapperX<AiImageDO>().eq(AiImageDO::getTaskId, id));
}
/**
* 查询 - page
*
* @param userId
* @param pageReqVO
* @return
*/
default PageResult<AiImageDO> selectPage(Long userId, PageParam pageReqVO) { default PageResult<AiImageDO> selectPage(Long userId, PageParam pageReqVO) {
return selectPage(pageReqVO, new LambdaQueryWrapperX<AiImageDO>() return selectPage(pageReqVO, new LambdaQueryWrapperX<AiImageDO>()
.eq(AiImageDO::getUserId, userId) .eq(AiImageDO::getUserId, userId)
.orderByDesc(AiImageDO::getId)); .orderByDesc(AiImageDO::getId));
} }
/** default List<AiImageDO> selectListByStatusAndPlatform(Integer status, String platform) {
* 查询 - 根据 status platform return selectList(AiImageDO::getStatus, status,
* AiImageDO::getPlatform, platform);
* @return
*/
default List<AiImageDO> selectByStatusAndPlatform(AiImageStatusEnum statusEnum, AiPlatformEnum platformEnum) {
return this.selectList(new LambdaUpdateWrapper<AiImageDO>()
.eq(AiImageDO::getStatus, statusEnum.getStatus())
.eq(AiImageDO::getPlatform, platformEnum.getPlatform())
);
} }
} }

View File

@ -1,79 +0,0 @@
package cn.iocoder.yudao.module.ai.job;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.quartz.core.handler.JobHandler;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* midjourney job 定时拉去 midjourney 绘制状态
*
* @author fansili
* @time 2024/6/5 14:55
* @since 1.0
*/
@Component
@Slf4j
public class MidjourneyJob implements JobHandler {
// TODO @fan@Resource
@Autowired(required = false)
private MidjourneyApi midjourneyApi;
@Autowired
private AiImageMapper imageMapper;
@Autowired
private AiImageService imageService;
// TODO @fan这个方法建议实现到 AiImageService例如说 midjourneySync返回 int 同步数量
@Override
public String execute(String param) {
// 1获取 midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectByStatusAndPlatform(AiImageStatusEnum.IN_PROGRESS, AiPlatformEnum.MIDJOURNEY);
log.info("Midjourney 同步 - 任务数量 {}!", imageList.size());
if (CollUtil.isEmpty(imageList)) {
return "Midjourney 同步 - 数量为空!";
}
log.info("Midjourney 同步 - 开始...");
// 2批量拉去 task 信息
List<MidjourneyApi.NotifyRequest> taskList = midjourneyApi
.listByCondition(CollectionUtils.convertSet(imageList, AiImageDO::getTaskId));
Map<String, MidjourneyApi.NotifyRequest> taskIdMap
= CollectionUtils.convertMap(taskList, MidjourneyApi.NotifyRequest::id);
// 3更新 image 状态
List<AiImageDO> updateImageList = new ArrayList<>();
for (AiImageDO aiImageDO : imageList) {
// 3.1 排除掉空的情况
if (!taskIdMap.containsKey(aiImageDO.getTaskId())) {
log.warn("Midjourney 同步 - {} 任务为空!", aiImageDO.getTaskId());
continue;
}
// TODO @ 3.1 3.2 是不是融合下get然后判空continue
// 3.2 获取通知对象
MidjourneyApi.NotifyRequest notifyRequest = taskIdMap.get(aiImageDO.getTaskId());
// 3.2 构建更新对象
// TODO @fan建议 List<MidjourneyNotifyReqVO> 作为 imageService 去更新
// TODO @芋艿 BeanUtils.toBean 转换为 null
updateImageList.add(imageService.buildUpdateImage(aiImageDO.getId(),
JsonUtils.parseObject(JsonUtils.toJsonString(notifyRequest), MidjourneyNotifyReqVO.class)));
}
// 4批了更新 updateImageList
imageMapper.updateBatch(updateImageList);
return "Midjourney 同步 - ".concat(String.valueOf(updateImageList.size())).concat(" 任务!");
}
}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.module.ai.job.image;
import cn.iocoder.yudao.framework.quartz.core.handler.JobHandler;
import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
/**
* Midjourney 同步 Job定时拉去 midjourney 绘制状态
*
* @author fansili
*/
@Component
@Slf4j
public class MidjourneySyncJob implements JobHandler {
@Resource
private AiImageService imageService;
@Override
public String execute(String param) {
Integer count = imageService.midjourneySync();
log.info("[execute][同步 Midjourney ({}) 个]", count);
return String.format("同步 Midjourney %s 个", count);
}
}

View File

@ -1,10 +1,10 @@
package cn.iocoder.yudao.module.ai.service.image; package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
/** /**
@ -40,15 +40,6 @@ public interface AiImageService {
*/ */
Long drawImage(Long userId, AiImageDrawReqVO drawReqVO); Long drawImage(Long userId, AiImageDrawReqVO drawReqVO);
/**
* Midjourney imagine绘画
*
* @param userId 用户编号
* @param imagineReqVO 绘制请求
* @return 绘画编号
*/
Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO imagineReqVO);
/** /**
* 删除我的绘画记录 * 删除我的绘画记录
* *
@ -57,22 +48,30 @@ public interface AiImageService {
*/ */
void deleteImageMy(Long id, Long userId); void deleteImageMy(Long id, Long userId);
/** // ================ midjourney 专属 ================
* midjourney proxy - 回调通知
*
* @param notifyReqVO
* @return
*/
void midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO);
/** /**
* 构建 midjourney - 更新对象 * Midjourney生成图片
* *
* @param imageId * @param userId 用户编号
* @param notifyReqVO * @param reqVO 绘制请求
* @return * @return 绘画编号
*/ */
AiImageDO buildUpdateImage(Long imageId, MidjourneyNotifyReqVO notifyReqVO); Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO);
/**
* Midjourney同步图片进展
*
* @return 同步成功数量
*/
Integer midjourneySync();
/**
* Midjourney通知图片进展
*
* @param notify 通知
*/
void midjourneyNotify(MidjourneyApi.Notify notify);
/** /**
* midjourney - action(放大缩小U1U2...) * midjourney - action(放大缩小U1U2...)
@ -83,4 +82,5 @@ public interface AiImageService {
* @return * @return
*/ */
void midjourneyAction(Long loginUserId, Long imageId, String customId); void midjourneyAction(Long loginUserId, Long imageId, String customId);
} }

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.codec.Base64; import cn.hutool.core.codec.Base64;
import cn.hutool.core.exceptions.ExceptionUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil; import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
@ -14,8 +14,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
@ -29,15 +28,17 @@ import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.List; import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/** /**
@ -51,12 +52,16 @@ public class AiImageServiceImpl implements AiImageService {
@Resource @Resource
private AiImageMapper imageMapper; private AiImageMapper imageMapper;
@Resource @Resource
private FileApi fileApi; private FileApi fileApi;
@Resource @Resource
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Autowired(required = false)
@Resource
private MidjourneyApi midjourneyApi; private MidjourneyApi midjourneyApi;
@Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}") @Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
private String midjourneyNotifyUrl; private String midjourneyNotifyUrl;
@ -74,7 +79,7 @@ public class AiImageServiceImpl implements AiImageService {
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) { public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
// 1. 保存数据库 // 1. 保存数据库
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(image); imageMapper.insert(image);
// 2. 异步绘制后续前端通过返回的 id 进行轮询结果 // 2. 异步绘制后续前端通过返回的 id 进行轮询结果
getSelf().executeDrawImage(image, drawReqVO); getSelf().executeDrawImage(image, drawReqVO);
@ -121,47 +126,6 @@ public class AiImageServiceImpl implements AiImageService {
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
} }
@Override
@Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO req) {
// 1构建 AiImageDO 保存
AiImageDO image = new AiImageDO()
.setUserId(userId)
.setPrompt(req.getPrompt())
.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform())
.setModel(req.getModel())
.setWidth(req.getWidth())
.setHeight(req.getHeight())
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(image);
// 3调用 MidjourneyProxy 提交任务
// 3.1设置 midjourney 扩展参数
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(null, midjourneyNotifyUrl, req.getPrompt(),
buildParams(req.getWidth(), req.getHeight(), req.getVersion(),
MidjourneyApi.ModelEnum.valueOfModel(req.getModel())));
// 3.2提交绘画请求
// TODO @fan5 这里失败的情况到底抛出异常还是 RespVO可以参考 OpenAI API 封装
MidjourneyApi.SubmitResponse submitResponse = midjourneyApi.imagine(imagineRequest);
// 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(submitResponse.code())) {
if (submitResponse.description().contains("quota_not_enough")) {
throw exception(AI_IMAGE_SYSTEM_ACCOUNT_INSUFFICIENT_BALANCE, submitResponse.description());
}
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitResponse.description());
}
// 4.1更新 taskId 和参数
imageMapper.updateById(new AiImageDO()
.setId(image.getId())
.setTaskId(submitResponse.result())
.setOptions(BeanUtil.beanToMap(req))
);
return image.getId();
}
@Override @Override
public void deleteImageMy(Long id, Long userId) { public void deleteImageMy(Long id, Long userId) {
// 1. 校验是否存在 // 1. 校验是否存在
@ -173,50 +137,111 @@ public class AiImageServiceImpl implements AiImageService {
imageMapper.deleteById(id); imageMapper.deleteById(id);
} }
@Override private AiImageDO validateImageExists(Long id) {
public void midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) { AiImageDO image = imageMapper.selectById(id);
// 1根据 job id 查询关联的 image
AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
if (image == null) { if (image == null) {
log.warn("midjourneyNotify 回调的 jobId 不存在! jobId: {}", notifyReqVO.getId()); throw exception(AI_IMAGE_NOT_EXISTS);
} }
// 2转换状态 return image;
AiImageDO updateImage = buildUpdateImage(image.getId(), notifyReqVO);
// 3更新 image 状态
imageMapper.updateById(updateImage);
} }
public AiImageDO buildUpdateImage(Long imageId, MidjourneyNotifyReqVO notifyReqVO) { // ================ midjourney 专属 ================
// 1转换状态
String imageStatus = null; @Override
if (StrUtil.isNotBlank(notifyReqVO.getStatus())) { @Transactional(rollbackFor = Exception.class)
MidjourneyApi.TaskStatusEnum taskStatusEnum = MidjourneyApi.TaskStatusEnum.valueOf(notifyReqVO.getStatus()); public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO) {
// 1. 保存数据库
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
imageMapper.insert(image);
// 2. 调用 Midjourney Proxy 提交任务
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
null, midjourneyNotifyUrl, reqVO.getPrompt(),
MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
// 3. 情况一失败抛出业务异常
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) {
String description = imagineResponse.description().contains("quota_not_enough") ?
"账户余额不足" : imagineResponse.description();
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
}
// 4. 情况二成功更新 taskId 和参数
imageMapper.updateById(new AiImageDO()
.setId(image.getId())
.setTaskId(imagineResponse.result())
.setOptions(BeanUtil.beanToMap(reqVO))
);
return image.getId();
}
@Override
public Integer midjourneySync() {
// 1.1 获取 Midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
if (CollUtil.isEmpty(imageList)) {
return 0;
}
// 1.2 调用 Midjourney Proxy 获取任务进展
List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(imageList, AiImageDO::getTaskId));
Map<String, MidjourneyApi.Notify> taskMap = convertMap(taskList, MidjourneyApi.Notify::id);
// 2. 逐个处理更新进展
int count = 0;
for (AiImageDO image : imageList) {
MidjourneyApi.Notify notify = taskMap.get(image.getTaskId());
if (notify == null) {
log.error("[midjourneySync][image({}) 查询不到进展]", image);
continue;
}
count++;
updateMidjourneyStatus(image, notify);
}
return count;
}
@Override
public void midjourneyNotify(MidjourneyApi.Notify notify) {
// 1. 校验 image 存在
AiImageDO image = imageMapper.selectByTaskId(notify.id());
if (image == null) {
log.warn("[midjourneyNotify][回调任务({}) 不存在]", notify.id());
return;
}
// 2. 更新状态
updateMidjourneyStatus(image, notify);
}
private void updateMidjourneyStatus(AiImageDO image, MidjourneyApi.Notify notify) {
// 1. 转换状态
Integer status = null;
if (StrUtil.isNotBlank(notify.status())) {
MidjourneyApi.TaskStatusEnum taskStatusEnum = MidjourneyApi.TaskStatusEnum.valueOf(notify.status());
if (MidjourneyApi.TaskStatusEnum.SUCCESS == taskStatusEnum) { if (MidjourneyApi.TaskStatusEnum.SUCCESS == taskStatusEnum) {
imageStatus = AiImageStatusEnum.SUCCESS.getStatus(); status = AiImageStatusEnum.SUCCESS.getStatus();
} else if (MidjourneyApi.TaskStatusEnum.FAILURE == taskStatusEnum) { } else if (MidjourneyApi.TaskStatusEnum.FAILURE == taskStatusEnum) {
imageStatus = AiImageStatusEnum.FAIL.getStatus(); status = AiImageStatusEnum.FAIL.getStatus();
} }
} }
// 2上传图片 // 2. 上传图片
String filePath = null; String picUrl = null;
if (!StrUtil.isBlank(notifyReqVO.getImageUrl())) { if (StrUtil.isNotBlank(notify.imageUrl())) {
try { try {
filePath = fileApi.createFile(HttpUtil.downloadBytes(notifyReqVO.getImageUrl())); picUrl = fileApi.createFile(HttpUtil.downloadBytes(notify.imageUrl()));
} catch (Exception e) { } catch (Exception e) {
log.warn("midjourneyNotify 图片上传失败! {} 异常:{}", notifyReqVO.getImageUrl(), ExceptionUtil.getMessage(e)); picUrl = notify.imageUrl();
log.warn("[updateMidjourneyStatus][图片({}) 地址({}) 上传失败]", image.getId(), notify.imageUrl(), e);
} }
} }
// 3更新 image 状态 // 3. 更新 image 状态
return new AiImageDO() imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(status)
.setId(imageId) .setPicUrl(picUrl).setButtons(notify.buttons()).setErrorMessage(notify.failReason()));
.setStatus(imageStatus)
.setPicUrl(filePath)
.setProgress(notifyReqVO.getProgress())
.setResponse(notifyReqVO)
.setButtons(notifyReqVO.getButtons())
.setErrorMessage(notifyReqVO.getFailReason());
} }
@Override @Override
@ -236,7 +261,6 @@ public class AiImageServiceImpl implements AiImageService {
// 5新增 image 记录(根据 image 新增一个) // 5新增 image 记录(根据 image 新增一个)
AiImageDO newImage = new AiImageDO(); AiImageDO newImage = new AiImageDO();
newImage.setId(null);
newImage.setUserId(image.getUserId()); newImage.setUserId(image.getUserId());
newImage.setPrompt(image.getPrompt()); newImage.setPrompt(image.getPrompt());
@ -248,20 +272,15 @@ public class AiImageServiceImpl implements AiImageService {
newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
newImage.setPublicStatus(image.getPublicStatus()); newImage.setPublicStatus(image.getPublicStatus());
newImage.setPicUrl(null);
newImage.setProgress(null);
newImage.setButtons(null);
newImage.setOptions(image.getOptions()); newImage.setOptions(image.getOptions());
newImage.setResponse(image.getResponse());
newImage.setTaskId(submitResponse.result()); newImage.setTaskId(submitResponse.result());
newImage.setErrorMessage(null);
imageMapper.insert(newImage); imageMapper.insert(newImage);
} }
private static void validateCustomId(String customId, List<MidjourneyNotifyReqVO.Button> buttons) { private static void validateCustomId(String customId, List<MidjourneyApi.Button> buttons) {
boolean isTrue = false; boolean isTrue = false;
for (MidjourneyNotifyReqVO.Button button : buttons) { for (MidjourneyApi.Button button : buttons) {
if (button.getCustomId().equals(customId)) { if (button.customId().equals(customId)) {
isTrue = true; isTrue = true;
break; break;
} }
@ -271,14 +290,6 @@ public class AiImageServiceImpl implements AiImageService {
} }
} }
private AiImageDO validateImageExists(Long id) {
AiImageDO image = imageMapper.selectById(id);
if (image == null) {
throw exception(AI_IMAGE_NOT_EXISTS);
}
return image;
}
/** /**
* 获得自身的代理对象解决 AOP 生效问题 * 获得自身的代理对象解决 AOP 生效问题
* *
@ -288,28 +299,4 @@ public class AiImageServiceImpl implements AiImageService {
return SpringUtil.getBean(getClass()); return SpringUtil.getBean(getClass());
} }
// TODO @fan这个是不是应该放在 MJ API 的封装里面搞哈
/**
* 构建 Midjourney 自定义参数
*
* @param width
* @param height
* @param version
* @param model
* @return
*/
private String buildParams(Integer width, Integer height, String version, MidjourneyApi.ModelEnum model) {
StringBuilder params = new StringBuilder();
// --ar 来设置尺寸
params.append(String.format(" --ar %s:%s ", width, height));
// --niji 模型
if (MidjourneyApi.ModelEnum.NIJI == model) {
params.append(String.format(" --niji %s ", version));
} else {
// --v 版本
params.append(String.format(" --v %s ", version));
}
return params.toString();
}
} }

View File

@ -50,34 +50,6 @@
<artifactId>junit</artifactId> <artifactId>junit</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<!-- TODO fan这里包要进一步减少 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<!-- https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp -->
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
</dependency>
<dependency>
<groupId>io.projectreactor.netty</groupId>
<artifactId>reactor-netty</artifactId>
</dependency>
<dependency>
<groupId>net.dv8tion</groupId>
<artifactId>JDA</artifactId>
<version>5.0.0-beta.21</version>
<!-- <exclusions>-->
<!-- <exclusion>-->
<!-- <groupId>club.minnced</groupId>-->
<!-- <artifactId>opus-java</artifactId>-->
<!-- </exclusion>-->
<!-- </exclusions>-->
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -16,66 +16,54 @@ import java.util.List;
import java.util.Map; import java.util.Map;
/** /**
* Midjourney api * Midjourney API
* *
* @author fansili * @author fansili
* @time 2024/6/11 15:46
* @since 1.0 * @since 1.0
*/ */
@Slf4j @Slf4j
public class MidjourneyApi { public class MidjourneyApi {
private static final String URI_IMAGINE = "/submit/imagine";
private static final String URI_ACTON = "/submit/action";
private static final String URI_LIST_BY_CONDITION = "/task/list-by-condition";
private final WebClient webClient; private final WebClient webClient;
private final MidjourneyConfig midjourneyConfig;
public MidjourneyApi(MidjourneyConfig midjourneyConfig) { public MidjourneyApi(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
this.webClient = WebClient.builder() this.webClient = WebClient.builder()
.baseUrl(midjourneyConfig.getUrl()) .baseUrl(midjourneyConfig.getUrl())
.defaultHeaders(ApiUtils.getJsonContentHeaders(midjourneyConfig.getKey())) .defaultHeaders(ApiUtils.getJsonContentHeaders(midjourneyConfig.getKey()))
.build(); .build();
} }
/** /**
* imagine - 根据提示词提交绘画任务 * imagine - 根据提示词提交绘画任务
* *
* @param imagineReqVO * @param request 请求
* @return * @return 提交结果
*/ */
public SubmitResponse imagine(ImagineRequest imagineReqVO) { public SubmitResponse imagine(ImagineRequest request) {
// 1发送 post 请求 String response = post("/submit/imagine", request);
String res = post(URI_IMAGINE, imagineReqVO); return JsonUtils.parseObject(response, SubmitResponse.class);
// 2转换 resp
return JsonUtils.parseObject(res, SubmitResponse.class);
} }
/** /**
* action - 放大缩小U1U2... * action - 放大缩小U1U2...
* *
* @param actionReqVO * @param request 请求
* @return 提交结果
*/ */
public SubmitResponse action(ActionRequest actionReqVO) { public SubmitResponse action(ActionRequest request) {
// 1发送 post 请求 String res = post("/submit/action", request);
String res = post(URI_ACTON, actionReqVO);
// 2转换 resp
return JsonUtils.parseObject(res, SubmitResponse.class); return JsonUtils.parseObject(res, SubmitResponse.class);
} }
/** /**
* 批量查询 task 任务 * 批量查询 task 任务
* *
* @param taskIds * @param ids 任务编号数组
* @return * @return task 任务
*/ */
public List<NotifyRequest> listByCondition(Collection<String> taskIds) { public List<Notify> getTaskList(Collection<String> ids) {
// 1发送 post 请求 String res = post("/task/list-by-condition", ImmutableMap.of("ids", ids));
String res = post(URI_LIST_BY_CONDITION, ImmutableMap.of("ids", taskIds)); return JsonUtils.parseArray(res, Notify.class);
// 2转换 对象
return JsonUtils.parseArray(res, NotifyRequest.class);
} }
private String post(String uri, Object body) { private String post(String uri, Object body) {
@ -94,10 +82,10 @@ public class MidjourneyApi {
.block(); .block();
} }
// ====== record 结构 // ========== record 结构 ==========
/** /**
* Midjourney - Imagine 请求 * Imagine 请求生成图片
* *
* @param base64Array 垫图(参考图) base64数 * @param base64Array 垫图(参考图) base64数
* @param notifyHook 通知地址 * @param notifyHook 通知地址
@ -108,10 +96,24 @@ public class MidjourneyApi {
String notifyHook, String notifyHook,
String prompt, String prompt,
String state) { String state) {
public static String buildState(Integer width, Integer height, String version, String model) {
StringBuilder params = new StringBuilder();
// --ar 来设置尺寸
params.append(String.format(" --ar %s:%s ", width, height));
// --niji 模型
if (MidjourneyApi.ModelEnum.NIJI.getModel().equals(model)) {
params.append(String.format(" --niji %s ", version));
} else {
params.append(String.format(" --v %s ", version));
}
return params.toString();
}
} }
/** /**
* Midjourney - Action 请求 * Action 请求
* *
* @param customId 操作按钮id * @param customId 操作按钮id
* @param taskId 操作按钮id * @param taskId 操作按钮id
@ -124,7 +126,7 @@ public class MidjourneyApi {
} }
/** /**
* Midjourney - Submit 返回 * Submit 统一返回
* *
* @param code 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) * @param code 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
* @param description 描述 * @param description 描述
@ -138,7 +140,7 @@ public class MidjourneyApi {
} }
/** /**
* Midjourney - 通知 request * 通知 request
* *
* @param id job id * @param id job id
* @param action 任务类型 {@link TaskActionEnum} * @param action 任务类型 {@link TaskActionEnum}
@ -155,7 +157,7 @@ public class MidjourneyApi {
* @param failReason 失败原因 * @param failReason 失败原因
* @param buttons 任务完成后的可执行按钮 * @param buttons 任务完成后的可执行按钮
*/ */
public record NotifyRequest(String id, public record Notify(String id,
String action, String action,
String status, String status,
@ -174,6 +176,8 @@ public class MidjourneyApi {
String failReason, String failReason,
List<Button> buttons) { List<Button> buttons) {
}
/** /**
* button * button
* *
@ -188,14 +192,12 @@ public class MidjourneyApi {
String label, String label,
String type, String type,
String style) { String style) {
}
} }
// ====== enums // ============ enums ============
/** /**
* Midjourney - 模型 * 模型枚举
*/ */
@AllArgsConstructor @AllArgsConstructor
@Getter @Getter
@ -203,24 +205,15 @@ public class MidjourneyApi {
MIDJOURNEY("midjourney", "midjourney"), MIDJOURNEY("midjourney", "midjourney"),
NIJI("niji", "niji"), NIJI("niji", "niji"),
; ;
private String model; private final String model;
private String name; private final String name;
public static ModelEnum valueOfModel(String model) {
for (ModelEnum itemEnum : ModelEnum.values()) {
if (itemEnum.getModel().equals(model)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
} }
/** /**
* Midjourney - 提交返回的状态码 * 提交返回的状态码的枚举
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
@ -239,64 +232,68 @@ public class MidjourneyApi {
private final String code; private final String code;
private final String name; private final String name;
} }
/** /**
* Midjourney - action * Action 枚举
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public enum TaskActionEnum { public enum TaskActionEnum {
/** /**
* 生成图片. * 生成图片
*/ */
IMAGINE, IMAGINE,
/** /**
* 选中放大. * 选中放大
*/ */
UPSCALE, UPSCALE,
/** /**
* 选中其中的一张图生成四张相似的. * 选中其中的一张图生成四张相似的
*/ */
VARIATION, VARIATION,
/** /**
* 重新执行. * 重新执行
*/ */
REROLL, REROLL,
/** /**
* 图转prompt. * 图转 prompt
*/ */
DESCRIBE, DESCRIBE,
/** /**
* 多图混合. * 多图混合
*/ */
BLEND BLEND
} }
/** /**
* Midjourney - 任务状态 * 任务状态枚举
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public enum TaskStatusEnum { public enum TaskStatusEnum {
/** /**
* 未启动. * 未启动
*/ */
NOT_START(0), NOT_START(0),
/** /**
* 已提交. * 已提交
*/ */
SUBMITTED(1), SUBMITTED(1),
/** /**
* 执行中. * 执行中
*/ */
IN_PROGRESS(3), IN_PROGRESS(3),
/** /**
* 失败. * 失败
*/ */
FAILURE(4), FAILURE(4),
/** /**
* 成功. * 成功
*/ */
SUCCESS(4); SUCCESS(4);

View File

@ -2,11 +2,9 @@
* model 接入各种大模型对标 https://github.com/spring-projects/spring-ai/tree/main/models * model 接入各种大模型对标 https://github.com/spring-projects/spring-ai/tree/main/models
* *
* 1. yiyan 百度文心一言 * 1. yiyan 百度文心一言
* 2. TODO 芋艿 * 2. tongyi 阿里通义千问对标 spring-cloud-alibaba 提供的 ai TODO 芋艿未来直接使用它
* tongyi 阿里通义千问对标 spring-cloud-alibaba 提供的 ai * 3. xinghuo 讯飞星火自己实现
* 2.2 * 4. midjourney Midjourney接入 https://github.com/novicezk/midjourney-proxy 实现
* 2.3 xinghuo 讯飞星火自己实现 * 5. suno TODO 芋艿
* 2.4 openai OpenAIChatGPT拷贝 spring-ai 提供的 models/openai
* 2.5 midjourney Midjourney参考 https://github.com/novicezk/midjourney-proxy 实现
*/ */
package cn.iocoder.yudao.framework.ai.core.model; package cn.iocoder.yudao.framework.ai.core.model;

View File

@ -76,13 +76,6 @@ server:
enabled: true enabled: true
charset: UTF-8 charset: UTF-8
force: true force: true
# ai TODO @fan这个融合到 yudao.ai 那好点哈
ai:
midjourney-proxy:
enable: true
url: https://api.holdai.top/mj
notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
key: sk-c3qxUCVKsPfdQiYU8440E3Fc8dE5424d9cB124A4Ee2489E3
--- #################### 定时任务相关配置 #################### --- #################### 定时任务相关配置 ####################

View File

@ -201,6 +201,13 @@ yudao.ai:
enable: true enable: true
base-url: https://suno-imrqwwui8-status2xxs-projects.vercel.app base-url: https://suno-imrqwwui8-status2xxs-projects.vercel.app
ai:
midjourney-proxy:
enable: true
url: https://api.holdai.top/mj
notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
key: sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf
--- #################### 芋道相关配置 #################### --- #################### 芋道相关配置 ####################
yudao: yudao: