【增加】1、mj components 相关操作任务提交 2、增加ai image 列表查询

This commit is contained in:
cherishsince 2024-05-08 15:07:49 +08:00
parent 16a47bde5b
commit e8cf40a82a
12 changed files with 240 additions and 57 deletions

View File

@ -14,29 +14,26 @@ public interface ErrorCodeConstants {
// chat
ErrorCode AI_MODULE_NOT_SUPPORTED = new ErrorCode(1_022_000_000, "AI 模型暂不支持!");
ErrorCode AI_CHAT_ROLE_NOT_EXISTENT = new ErrorCode(1_022_000_001, "AI Role 不存在!");;
// conversation
ErrorCode AI_CONVERSATION_NOT_EXISTS = new ErrorCode(1_022_000_002, "AI 对话不存在!");;
ErrorCode AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL = new ErrorCode(1_022_000_002, "chat 继续对话,对话 id 不能为空!");;
ErrorCode AI_CHAT_CONTINUE_NOT_EXIST = new ErrorCode(1_022_000_020, "chat 对话不存在!");
ErrorCode AI_CHAT_CONVERSATION_NOT_YOURS = new ErrorCode(1_022_000_021, "这条 chat 对话不是你的!");
// midjourney
ErrorCode AI_MIDJOURNEY_IMAGINE_FAIL = new ErrorCode(1_022_000_040, "midjourney imagine 操作失败!");
ErrorCode AI_MIDJOURNEY_OPERATION_NOT_EXISTS = new ErrorCode(1_022_000_040, "midjourney 操作不存在!");
ErrorCode AI_MIDJOURNEY_MESSAGE_ID_INCORRECT = new ErrorCode(1_022_000_040, "midjourney message id 不正确!");
// role
ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "chatRole 不存在!");
ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "AI 角色不存在!");
ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
// modal
ErrorCode AI_MODAL_NOT_EXIST = new ErrorCode(1_022_000_080, "AI 模型不存在!");
ErrorCode AI_MODAL_CONFIG_PARAMS_INCORRECT = new ErrorCode(1_022_000_081, "AI 模型 config 参数不正确! {} ");
ErrorCode AI_MODAL_NOT_SUPPORTED_MODAL = new ErrorCode(1_022_000_082, "AI 模型不支持的 modal! {} ");
ErrorCode AI_MODAL_PLATFORM_PARAMS_INCORRECT = new ErrorCode(1_022_000_083, "AI 平台参数不正确! {} ");
ErrorCode AI_MODAL_DISABLE_NOT_USED = new ErrorCode(1_022_000_084, "AI 模型禁用不能使用!");

View File

@ -12,7 +12,7 @@ import lombok.Getter;
*/
@AllArgsConstructor
@Getter
public enum AiChatDrawingStatusEnum {
public enum AiImageDrawingStatusEnum {
SUBMIT("submit", "提交任务"),
WAITING("waiting", "等待"),
@ -27,8 +27,8 @@ public enum AiChatDrawingStatusEnum {
private String name;
public static AiChatDrawingStatusEnum valueOfStatus(String status) {
for (AiChatDrawingStatusEnum itemEnum : AiChatDrawingStatusEnum.values()) {
public static AiImageDrawingStatusEnum valueOfStatus(String status) {
for (AiImageDrawingStatusEnum itemEnum : AiImageDrawingStatusEnum.values()) {
if (itemEnum.getStatus().equals(status)) {
return itemEnum;
}

View File

@ -1,9 +1,8 @@
package cn.iocoder.yudao.module.ai.controller.admin.image;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.service.AiImageService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
@ -31,6 +30,12 @@ public class AiImageController {
private final AiImageService aiImageService;
@Operation(summary = "获取image列表", description = "dall3、midjourney")
@GetMapping("/list")
public PageResult<AiImageListRespVO> list(@Validated @RequestBody AiImageListReqVO req) {
return aiImageService.list(req);
}
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
@PostMapping("/dallDrawing")
public AiImageDallDrawingRespVO dallDrawing(@Validated @RequestBody AiImageDallDrawingReqVO req) {
@ -46,13 +51,8 @@ public class AiImageController {
@Operation(summary = "midjourney绘画操作", description = "一般有选择图片、放大、换一批...")
@PostMapping("/midjourney-operate")
public CommonResult<Void> midjourneyOperate(@Validated @RequestBody AiImageMidjourneyReqVO req) {
return success(null);
}
@Operation(summary = "获取midjourney绘画列表", description = "获取 Midjourney 绘画列表")
@GetMapping("/get-midjourney-list")
public CommonResult<Void> getMidjourneyList(@Validated @RequestBody AiImageMidjourneyReqVO req) {
public CommonResult<Void> midjourneyOperate(@Validated @RequestBody AiImageMidjourneyOperateReqVO req) {
aiImageService.midjourneyOperate(req);
return success(null);
}

View File

@ -13,6 +13,6 @@ import lombok.experimental.Accessors;
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyListRespVO extends PageParam {
public class AiImageListReqVO extends PageParam {
}

View File

@ -1,7 +1,6 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.experimental.Accessors;
@ -14,6 +13,6 @@ import lombok.experimental.Accessors;
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyListReqVO extends PageParam {
public class AiImageListRespVO extends PageParam {
}

View File

@ -0,0 +1,31 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* midjourney req
*
* @author fansili
* @time 2024/4/28 17:42
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyOperateReqVO {
@NotNull(message = "图片编号不能为空")
@Schema(description = "编号")
private String id;
@NotNull(message = "消息编号不能为空")
@Schema(description = "消息编号")
private String messageId;
@NotNull(message = "操作编号不能为空")
@Schema(description = "操作编号")
private String operateId;
}

View File

@ -0,0 +1,30 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* mj 保存 components 记录
*
* "components": [
* {
* "custom_id": "MJ::JOB::upsample::1::5d32f4e8-8d2f-4bef-82d8-bf517e3c3660",
* "style": 2,
* "label": "U1",
* "type": 2
* },
* ]
*
* @author fansili
* @time 2024/5/8 14:44
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyOperationsVO {
private String custom_id;
private String style;
private String label;
private String type;
}

View File

@ -2,9 +2,13 @@ package cn.iocoder.yudao.module.ai.convert;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import org.mapstruct.Mapper;
import org.mapstruct.factory.Mappers;
import java.util.List;
/**
* ai image convert
*
@ -24,4 +28,12 @@ public interface AiImageConvert {
* @return
*/
AiImageDallDrawingRespVO convertAiImageDallDrawingRespVO(AiImageDallDrawingReqVO req);
/**
* 转换 - AiImageListRespVO
*
* @param list
* @return
*/
List<AiImageListRespVO> convertAiImageListRespVO(List<AiImageDO> list);
}

View File

@ -29,20 +29,37 @@ public class AiImageDO extends BaseDO {
@Schema(description = "提示词")
private String prompt;
@Schema(description = "模型")
@Schema(description = "模型 dall2/dall3、MJ、NIJI")
private String modal;
@Schema(description = "生成图像的尺寸大小。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。")
private String size;
@Schema(description = "图片地址(自己服务器)")
private String imageUrl;
@Schema(description = "绘画状态:提交、排队、绘画中、绘画完成、绘画失败")
private String drawingStatus;
@Schema(description = "绘画图片地址")
@Schema(description = "绘画图片地址(绘画好的服务器)")
private String drawingImageUrl;
@Schema(description = "绘画错误信息")
private String drawingError;
private String drawingErrorMessage;
// ============ mj 需要字段
@Schema(description = "用户操作的消息编号(MJ返回)")
private String mjMessageId;
@Schema(description = "用户操作的操作编号(MJ返回)")
private String mjOperationId;
@Schema(description = "用户操作的操作名字(MJ返回)")
private String mjOperationName;
@Schema(description = "mj图片生产成功保存的 components json 数组")
private String mjOperations;
}

View File

@ -1,8 +1,9 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import java.util.List;
/**
* ai 作图
@ -13,6 +14,14 @@ import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq
*/
public interface AiImageService {
/**
* ai绘画 - 列表
*
* @param req
* @return
*/
PageResult<AiImageListRespVO> list(AiImageListReqVO req);
/**
* ai绘画 - dall2/dall3 绘画
*
@ -27,4 +36,12 @@ public interface AiImageService {
* @return
*/
void midjourney(AiImageMidjourneyReqVO req);
/**
* midjourney 操作(u1u2放大换一批...)
*
* @param req
*/
void midjourneyOperate(AiImageMidjourneyOperateReqVO req);
}

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.exception.AiException;
import cn.iocoder.yudao.framework.ai.image.ImageGeneration;
import cn.iocoder.yudao.framework.ai.image.ImagePrompt;
@ -9,18 +10,20 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum;
import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum;
import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.service.AiImageService;
import jakarta.annotation.PostConstruct;
import lombok.AllArgsConstructor;
@ -28,6 +31,9 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.Collections;
import java.util.List;
/**
* ai 作图
*
@ -61,6 +67,23 @@ public class AiImageServiceImpl implements AiImageService {
});
}
@Override
public PageResult<AiImageListRespVO> list(AiImageListReqVO req) {
// 获取登录用户
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询当前用户下所有的绘画记录
PageResult<AiImageDO> pageResult = aiImageMapper.selectPage(req,
new LambdaQueryWrapperX<AiImageDO>()
.eq(AiImageDO::getUserId, loginUserId)
.orderByDesc(AiImageDO::getId)
);
// 转换 PageResult<AiImageListRespVO> 返回
PageResult<AiImageListRespVO> result = new PageResult<>();
result.setTotal(pageResult.getTotal());
result.setList(AiImageConvert.INSTANCE.convertAiImageListRespVO(pageResult.getList()));
return result;
}
@Override
public AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req) {
// 获取 model
@ -79,7 +102,8 @@ public class AiImageServiceImpl implements AiImageService {
ImageGeneration imageGeneration = imageResponse.getResult();
// 保存数据库
doSave(req.getPrompt(), req.getSize(), req.getModal(),
imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
imageGeneration.getOutput().getUrl(), AiImageDrawingStatusEnum.COMPLETE, null,
null, null, null);
// 返回 flex
respVO.setUrl(imageGeneration.getOutput().getUrl());
respVO.setBase64(imageGeneration.getOutput().getB64Json());
@ -87,7 +111,8 @@ public class AiImageServiceImpl implements AiImageService {
} catch (AiException aiException) {
// 保存数据库
doSave(req.getPrompt(), req.getSize(), req.getModal(),
null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
null, AiImageDrawingStatusEnum.FAIL, aiException.getMessage(),
null, null, null);
// 发送错误信息
respVO.setErrorMessage(aiException.getMessage());
return respVO;
@ -99,7 +124,8 @@ public class AiImageServiceImpl implements AiImageService {
public void midjourney(AiImageMidjourneyReqVO req) {
// 保存数据库
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
null, AiChatDrawingStatusEnum.SUBMIT, null);
null, AiImageDrawingStatusEnum.SUBMIT, null,
null, null, null);
// 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt());
if (!imagine) {
@ -107,23 +133,71 @@ public class AiImageServiceImpl implements AiImageService {
}
}
// private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
// try {
// sseEmitter.send(object, MediaType.APPLICATION_JSON);
// } catch (IOException e) {
// throw new RuntimeException(e);
// } finally {
// // 发送 complete
// sseEmitter.complete();
// }
// }
@Transactional(rollbackFor = Exception.class)
@Override
public void midjourneyOperate(AiImageMidjourneyOperateReqVO req) {
// 校验是否存在
AiImageDO aiImageDO = validateExists(req);
// 获取 midjourneyOperations
List<AiImageMidjourneyOperationsVO> midjourneyOperations = getMidjourneyOperations(aiImageDO);
// 校验 OperateId 是否存在
AiImageMidjourneyOperationsVO midjourneyOperationsVO = validateMidjourneyOperationsExists(midjourneyOperations, req.getOperateId());
// 校验 messageId
validateMessageId(aiImageDO.getMjMessageId(), req.getMessageId());
// 获取 mjOperationName
String mjOperationName = midjourneyOperationsVO.getLabel();
// 保存一个 image 任务记录
doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModal(),
null, AiImageDrawingStatusEnum.SUBMIT, null,
req.getMessageId(), req.getOperateId(), mjOperationName);
// 提交操作
midjourneyInteractionsApi.reRoll(
new ReRollReq()
.setCustomId(req.getOperateId())
.setMessageId(req.getMessageId())
);
}
private void validateMessageId(String mjMessageId, String messageId) {
if (!mjMessageId.equals(messageId)) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_MESSAGE_ID_INCORRECT);
}
}
private AiImageMidjourneyOperationsVO validateMidjourneyOperationsExists(List<AiImageMidjourneyOperationsVO> midjourneyOperations, String operateId) {
for (AiImageMidjourneyOperationsVO midjourneyOperation : midjourneyOperations) {
if (midjourneyOperation.getCustom_id().equals(operateId)) {
return midjourneyOperation;
}
}
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_OPERATION_NOT_EXISTS);
}
private List<AiImageMidjourneyOperationsVO> getMidjourneyOperations(AiImageDO aiImageDO) {
if (StrUtil.isBlank(aiImageDO.getMjOperations())) {
return Collections.emptyList();
}
return JsonUtils.parseArray(aiImageDO.getMjOperations(), AiImageMidjourneyOperationsVO.class);
}
private AiImageDO validateExists(AiImageMidjourneyOperateReqVO req) {
AiImageDO aiImageDO = aiImageMapper.selectById(req.getId());
if (aiImageDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
}
return aiImageDO;
}
private AiImageDO doSave(String prompt,
String size,
String model,
String imageUrl,
AiChatDrawingStatusEnum drawingStatusEnum,
String drawingError) {
String size,
String model,
String drawingImageUrl,
AiImageDrawingStatusEnum drawingStatusEnum,
String drawingErrorMessage,
String mjMessageId,
String mjOperationId,
String mjOperationName) {
// 保存数据库
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
AiImageDO aiImageDO = new AiImageDO();
@ -132,9 +206,15 @@ public class AiImageServiceImpl implements AiImageService {
aiImageDO.setSize(size);
aiImageDO.setModal(model);
aiImageDO.setUserId(loginUserId);
aiImageDO.setDrawingImageUrl(imageUrl);
// TODO @芋艿 如何上传到自己服务器
aiImageDO.setImageUrl(null);
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
aiImageDO.setDrawingError(drawingError);
aiImageDO.setDrawingImageUrl(drawingImageUrl);
aiImageDO.setDrawingErrorMessage(drawingErrorMessage);
//
aiImageDO.setMjMessageId(mjMessageId);
aiImageDO.setMjOperationId(mjOperationId);
aiImageDO.setMjOperationName(mjOperationName);
aiImageMapper.insert(aiImageDO);
return aiImageDO;
}

View File

@ -6,7 +6,7 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
import com.alibaba.fastjson2.JSON;
import lombok.AllArgsConstructor;
@ -53,14 +53,14 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
imageUrl = midjourneyMessage.getAttachments().get(0).getUrl();
}
// 转换状态
AiChatDrawingStatusEnum drawingStatusEnum = null;
AiImageDrawingStatusEnum drawingStatusEnum = null;
String generateStatus = midjourneyMessage.getGenerateStatus();
if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.COMPLETE;
drawingStatusEnum = AiImageDrawingStatusEnum.COMPLETE;
} else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.IN_PROGRESS;
drawingStatusEnum = AiImageDrawingStatusEnum.IN_PROGRESS;
} else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.WAITING;
drawingStatusEnum = AiImageDrawingStatusEnum.WAITING;
}
aiImageMapper.updateById(
new AiImageDO()