【优化】dall保存

This commit is contained in:
cherishsince 2024-05-28 14:47:56 +08:00
parent 9878abb03c
commit 7268f002d8
3 changed files with 31 additions and 41 deletions

View File

@ -1,12 +1,16 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import com.baomidou.mybatisplus.annotation.FieldFill;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.experimental.Accessors;
import java.time.LocalDateTime;
/**
* midjourney req
*
@ -50,6 +54,12 @@ public class AiImageListRespVO extends PageParam {
@Schema(description = "是否发布")
private String publicStatus;
@Schema(description = "创建时间")
private LocalDateTime createTime;
@Schema(description = "更新时间")
private LocalDateTime updateTime;
// ============ mj 需要字段
@Schema(description = "用户操作的Nonce编号(MJ返回)")

View File

@ -62,4 +62,12 @@ public interface AiImageConvert {
* @return
*/
AiImageMidjourneyOperationsVO convertAiImageMidjourneyOperationsVO(MidjourneyMessage.Component component);
/**
* 转换 - AiImageDO
*
* @param req
* @return
*/
AiImageDO convertAiImageDO(AiImageDallReqVO req);
}

View File

@ -100,10 +100,12 @@ public class AiImageServiceImpl implements AiImageService {
@Override
public AiImageDallRespVO dallDrawing(AiImageDallReqVO req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 保存数据库
AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
null, null, AiImageStatusEnum.IN_PROGRESS, null,
null, null, null);
AiImageDO aiImageDO = AiImageConvert.INSTANCE.convertAiImageDO(req);
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
aiImageDO.setUserId(loginUserId);
aiImageMapper.insert(aiImageDO);
// 异步执行
EXECUTOR.execute(() -> {
try {
@ -149,9 +151,10 @@ public class AiImageServiceImpl implements AiImageService {
public void midjourney(AiImageMidjourneyReqVO req) {
// 保存数据库
String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
null, null, AiImageStatusEnum.SUBMIT, null,
messageId, null, null);
// todo
// AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
// null, null, AiImageStatusEnum.SUBMIT, null,
// messageId, null, null);
// 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
if (!imagine) {
@ -173,9 +176,10 @@ public class AiImageServiceImpl implements AiImageService {
// 获取 mjOperationName
String mjOperationName = midjourneyOperationsVO.getLabel();
// 保存一个 image 任务记录
doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(),
null, null, AiImageStatusEnum.SUBMIT, null,
req.getMessageId(), req.getOperateId(), mjOperationName);
// todo
// doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(),
// null, null, AiImageStatusEnum.SUBMIT, null,
// req.getMessageId(), req.getOperateId(), mjOperationName);
// 提交操作
midjourneyInteractionsApi.reRoll(
new ReRollReq()
@ -222,36 +226,4 @@ public class AiImageServiceImpl implements AiImageService {
}
return aiImageDO;
}
private AiImageDO doSave(String prompt,
String size,
String model,
String picUrl,
String originalPicUrl,
AiImageStatusEnum statusEnum,
String errorMessage,
String mjMessageId,
String mjOperationId,
String mjOperationName) {
// 保存数据库
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
AiImageDO aiImageDO = new AiImageDO();
aiImageDO.setId(null);
aiImageDO.setPrompt(prompt);
aiImageDO.setSize(size);
aiImageDO.setModel(model);
aiImageDO.setUserId(loginUserId);
// TODO @芋艿 如何上传到自己服务器
aiImageDO.setPicUrl(null);
aiImageDO.setStatus(statusEnum.getStatus());
aiImageDO.setPicUrl(picUrl);
aiImageDO.setOriginalPicUrl(originalPicUrl);
aiImageDO.setErrorMessage(errorMessage);
//
aiImageDO.setMjNonceId(mjMessageId);
aiImageDO.setMjOperationId(mjOperationId);
aiImageDO.setMjOperationName(mjOperationName);
aiImageMapper.insert(aiImageDO);
return aiImageDO;
}
}