【新增】AI:进一步统一 DALL、SD 的绘制实现

This commit is contained in:
YunaiV 2024-06-01 14:16:37 +08:00
parent 0563503102
commit 6856f5f192
34 changed files with 220 additions and 587 deletions

View File

@ -1,21 +0,0 @@
package cn.iocoder.yudao.module.ai;
/**
* ai 常用的常量
*
* @author fansili
* @time 2024/5/7 09:29
* @since 1.0
*/
public class AiCommonConstants {
/**
* 绘画 request - style
*/
public static final String DRAW_REQ_KEY_STYLE = "style";
/**
* dall size - 模板(1024x1024)
*/
public static final String DALL_SIZE_TEMPLATE = "%sx%s";
}

View File

@ -0,0 +1,4 @@
/**
* 占位没有特别的作用
*/
package cn.iocoder.yudao.module.ai.api;

View File

@ -1,37 +0,0 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 对话类型
* 创建对话继续对话
*
* @author fansili
* @time 2024/4/14 18:15
* @since 1.0
*/
@AllArgsConstructor
@Getter
public enum AiChatConversationTypeEnum {
// roleChatuserChat
ROLE_CHAT("roleChat", "角色对话"),
USER_CHAT("userChat", "用户对话"),
;
private String type;
private String name;
public static AiChatConversationTypeEnum valueOfType(String type) {
for (AiChatConversationTypeEnum itemEnum : AiChatConversationTypeEnum.values()) {
if (itemEnum.getType().equals(type)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + type);
}
}

View File

@ -1,37 +0,0 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 聊天role 分类
*
* @author fansili
* @time 2024/4/24 16:41
* @since 1.0
*/
@AllArgsConstructor
@Getter
public enum AiChatRoleCategoryEnum {
WRITING("writing", "写作"),
ENTERTAINMENT("entertainment", "娱乐"),
;
private String category;
private String name;
public static AiChatRoleCategoryEnum valueOfCategory(String category) {
for (AiChatRoleCategoryEnum itemEnum : AiChatRoleCategoryEnum.values()) {
if (itemEnum.getCategory().equals(category)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + category);
}
}

View File

@ -1,36 +0,0 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* ai绘画 public 状态
*
* @author fansili
* @time 2024/4/28 17:05
* @since 1.0
*/
@AllArgsConstructor
@Getter
public enum AiImagePublicStatusEnum {
PRIVATE("private", "私有"),
PUBLIC("public", "公开"),
;
// TODO @fanfinal 一下
private final String status;
private final String name;
public static AiImagePublicStatusEnum valueOfStatus(String status) {
for (AiImagePublicStatusEnum itemEnum : AiImagePublicStatusEnum.values()) {
if (itemEnum.getStatus().equals(status)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + status);
}
}

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai;
package cn.iocoder.yudao.module.ai.enums;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.enums;
package cn.iocoder.yudao.module.ai.enums.image;
import lombok.AllArgsConstructor;
import lombok.Getter;
@ -13,7 +13,7 @@ import lombok.Getter;
public enum AiImageStatusEnum {
IN_PROGRESS("10", "进行中"),
COMPLETE("20", "完成"),
SUCCESS("20", "完成"),
FAIL("30", "失败");
/**

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.enums;
package cn.iocoder.yudao.module.ai.enums.model;
import lombok.AllArgsConstructor;
import lombok.Getter;

View File

@ -1,17 +1,33 @@
### chat dallDrawing
POST {{baseUrl}}/admin-api/ai/image/dallDrawing
### 生成图片OpenAIDALL
POST {{baseUrl}}/ai/image/draw
Content-Type: application/json
Authorization: {{token}}
{
"modal": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"prompt": "中国长城"
"platform": "OpenAI",
"prompt": "可爱的小喵星人",
"model": "dall-e-3",
"height": "1024",
"width": "1024",
"options": {
"style": "vivid"
}
}
### 生成图片StableDiffusion
POST {{baseUrl}}/ai/image/draw
Content-Type: application/json
Authorization: {{token}}
{
"platform": "StableDiffusion",
"prompt": "中国长城",
"model": "stable-diffusion-v1-6",
"height": "1024",
"width": "1024",
"style": "vivid"
}
### chat midjourney

View File

@ -6,7 +6,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -48,13 +48,10 @@ public class AiImageController {
return success(BeanUtils.toBean(image, AiImageRespVO.class));
}
// TODO @fan建议把 dallDrawingmidjourney 融合成一个 draw 接口异步绘制然后返回一个 id 给前端前端通过 get 接口轮询直到获取到生成成功
// TODO @芋艿: 参数差异较大
// TODO @fan直接参数平铺写好注释要么
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
@PostMapping("/dall")
public CommonResult<Long> dall(@Validated @RequestBody AiImageDallReqVO req) {
return success(imageService.dall(getLoginUserId(), req));
@Operation(summary = "生成图片")
@PostMapping("/draw")
public CommonResult<Long> drawImage(@Validated @RequestBody AiImageDrawReqVO drawReqVO) {
return success(imageService.drawImage(getLoginUserId(), drawReqVO));
}
@Operation(summary = "删除【我的】绘画记录")
@ -73,11 +70,11 @@ public class AiImageController {
return success(imageService.midjourneyImagine(getLoginUserId(), req));
}
// TODO @fan可以考虑复用 AiImageDallRespVO统一成 AIImageRespVO
// TODO @芋艿不拦截
@Operation(summary = "midjourney proxy - 回调通知")
@RequestMapping("/midjourney-notify")
public CommonResult<Boolean> midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
return success(imageService.midjourneyNotify(getLoginUserId(), notifyReqVO));
}
}
}

View File

@ -1,41 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Data;
/**
* dall2/dall2 绘画
*
* @author fansili
* @time 2024/4/25 16:24
* @since 1.0
*/
@Data
public class AiImageDallReqVO {
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private String platform; // 参见 AiPlatformEnum 枚举
@Schema(description = "提示词")
@NotNull(message = "提示词不能为空!")
@Size(max = 1200, message = "提示词最大1200")
private String prompt;
@Schema(description = "模型(dall2、dall3)")
@NotNull(message = "模型不能为空")
private String model;
@Schema(description = "图像生成的风格。可为vivid生动或natural自然)")
private String style;
@Schema(description = "图片高度。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。")
@NotNull(message = "图片高度不能为空!")
private Integer height;
@Schema(description = "图片宽度。对于dall-e-2模型尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 或 1024x1792。")
@NotNull(message = "图片宽度不能为空!")
private Integer width;
}

View File

@ -0,0 +1,52 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Data;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import java.util.Map;
@Schema(description = "管理后台 - 绘画 Request VO")
@Data
public class AiImageDrawReqVO {
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
private String platform; // 参见 AiPlatformEnum 枚举
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "画一个长城")
@NotEmpty(message = "提示词不能为空")
@Size(max = 1200, message = "提示词最大 1200")
private String prompt;
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "stable-diffusion-v1-6")
@NotEmpty(message = "模型不能为空")
private String model;
/**
* 1. dall-e-2 模型256x256512x5121024x1024
* 2. dall-e-3 模型1024x1024, 1792x1024, 1024x1792
*/
@Schema(description = "图片高度")
@NotNull(message = "图片高度不能为空")
private Integer height;
@Schema(description = "图片宽度")
@NotNull(message = "图片宽度不能为空")
private Integer width;
// ========== 各平台绘画的拓展参数 ==========
/**
* 绘制参数不同 platform 的不同参数
*
* 1. {@link OpenAiImageOptions}
* 2. {@link StabilityAiImageOptions}
*/
@Schema(description = "绘制参数")
private Map<String, String> options;
}

View File

@ -1,31 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* midjourney req
*
* @author fansili
* @time 2024/4/28 17:42
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyOperateReqVO {
@NotNull(message = "图片编号不能为空")
@Schema(description = "编号")
private Long id;
@NotNull(message = "消息编号不能为空")
@Schema(description = "消息编号")
private String messageId;
@NotNull(message = "操作编号不能为空")
@Schema(description = "操作编号")
private String operateId;
}

View File

@ -5,63 +5,44 @@ import lombok.Data;
import java.util.Map;
// TODO @芋艿完善 swagger 注解
@Schema(description = "管理后台 - 绘画 Response VO")
@Data
public class AiImageRespVO {
@Schema(description = "id编号", example = "1")
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Long id;
@Schema(description = "用户id", example = "1")
@Schema(description = "用户编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Long userId;
@Schema(description = "提示词", example = "南极的小企鹅")
private String prompt;
@Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
private String platform; // 参见 AiPlatformEnum 枚举
@Schema(description = "平台", example = "openai")
private String platform;
@Schema(description = "模型", example = "dall2")
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "stable-diffusion-v1-6")
private String model;
@Schema(description = "图片宽度", example = "1024")
private String width;
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "南极的小企鹅")
private String prompt;
@Schema(description = "图片高度", example = "1024")
private String height;
@Schema(description = "图片宽度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private Integer width;
@Schema(description = "绘画状态10 进行中、20 绘画完成、30 绘画失败", example = "10")
private String status;
@Schema(description = "图片高度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private Integer height;
@Schema(description = "是否发布", example = "public")
private String publicStatus;
@Schema(description = "绘画状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "10")
private Integer status;
@Schema(description = "图片地址(自己服务器)", example = "http://")
@Schema(description = "是否发布", requiredMode = Schema.RequiredMode.REQUIRED, example = "public")
private Boolean publicStatus;
@Schema(description = "图片地址", example = "https://www.iocoder.cn/1.png")
private String picUrl;
@Schema(description = "绘画图片地址(绘画好的服务器)", example = "http://")
private String originalPicUrl;
@Schema(description = "绘画错误信息", example = "图片错误信息")
private String errorMessage;
// ============ 绘画请求参数
// todo @fan下面的 stylemjNonceId 直接就不用注释啦直接去看 DO 完事哈
/**
* - style
*/
@Schema(description = "绘画请求参数")
private Map<String, Object> drawRequest;
/**
* - mjNonceId
* - mjOperationId
* - mjOperationName
* - mjOperations
*/
@Schema(description = "绘画请求响应参数")
private Map<String, Object> drawResponse;
@Schema(description = "绘制参数")
private Map<String, String> options;
}

View File

@ -3,14 +3,13 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.AiModelEnum;
import cn.iocoder.yudao.module.ai.enums.model.AiModelEnum;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*;
import java.time.LocalDateTime;
import java.util.Date;
/**
* AI Chat 对话 DO

View File

@ -5,7 +5,7 @@ import org.springframework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.AiModelEnum;
import cn.iocoder.yudao.module.ai.enums.model.AiModelEnum;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*;

View File

@ -1,13 +1,17 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.image;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import java.util.Map;
@ -16,70 +20,86 @@ import java.util.Map;
*
* @author fansili
*/
@TableName("ai_image")
@TableName(value = "ai_image", autoResultMap = true)
@Data
public class AiImageDO extends BaseDO {
// TODO @fan1使用 java 注释哈不要注解2关联枚举字段要关联到对应类参考 AiChatMessageDO 的注释
/**
* 编号
*/
@TableId(type = IdType.AUTO)
private Long id;
@Schema(description = "用户编号")
/**
* 用户编号
*
* 关联 {@link AdminUserRespDTO#getId()}
*/
private Long userId;
@Schema(description = "midjourney proxy 关联的 job id")
private String jobId;
@Schema(description = "提示词")
/**
* 提示词
*/
private String prompt;
@Schema(description = "平台")
/**
* 平台
*
* 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
*/
private String platform;
@Schema(description = "模型")
/**
* 模型
*
* 冗余 {@link AiChatModelDO#getModel()}
*/
private String model;
@Schema(description = "图片宽度")
/**
* 图片宽度
*/
private Integer width;
@Schema(description = "图片高度")
/**
* 图片高度
*/
private Integer height;
// TODO @fan这种就注释绘画状态然后枚举类关联下就好啦
@Schema(description = "绘画状态:提交、排队、绘画中、绘画完成、绘画失败")
/**
* 生成状态
*
* 枚举 {@link AiImageStatusEnum}
*/
private String status;
@Schema(description = "是否发布")
private String publicStatus;
@Schema(description = "图片地址(自己服务器)")
/**
* 图片地址
*/
private String picUrl;
// TODO @芋艿可能要删除掉
@Schema(description = "绘画图片地址(绘画好的服务器)")
private String originalPicUrl;
// ============ 绘画请求参数 ============
/**
* 是否公开
*/
private Boolean publicStatus;
/**
* - style
* 绘制参数不同 platform 的不同参数
*
* 1. {@link OpenAiImageOptions}
* 2. {@link StabilityAiImageOptions}
*/
@Schema(description = "绘画请求参数")
@TableField(typeHandler = JacksonTypeHandler.class)
private Map<String, Object> drawRequest;
private Map<String, String> options;
// TODO @芋艿再瞅瞅
/**
* - mjNonceId
* - mjOperationId
* - mjOperationName
* - mjOperations
* midjourney proxy 关联的 job id
*/
private String jobId;
/**
* 绘画错误信息
*/
@Schema(description = "绘画请求响应参数")
@TableField(typeHandler = JacksonTypeHandler.class)
private Map<String, Object> drawResponse;
@Schema(description = "绘画错误信息")
private String errorMessage;
}

View File

@ -25,8 +25,8 @@ import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_MODEL_ERROR;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_MODEL_ERROR;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
/**
* AI 聊天对话 Service 实现类

View File

@ -8,7 +8,7 @@ import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
@ -39,8 +39,8 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
/**
* AI 聊天消息 Service 实现类

View File

@ -3,7 +3,7 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -32,12 +32,12 @@ public interface AiImageService {
AiImageDO getImage(Long id);
/**
* ai绘画 - dall2/dall3 绘画
* 绘制图片
*
* @param loginUserId
* @param req
* @param userId 用户编号
* @param drawReqVO 绘制请求
*/
Long dall(Long loginUserId, AiImageDallReqVO req);
Long drawImage(Long userId, AiImageDrawReqVO drawReqVO);
/**
* midjourney 图片生成

View File

@ -1,19 +1,16 @@
package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.codec.Base64;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.AiCommonConstants;
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
@ -21,17 +18,18 @@ import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import com.google.common.collect.ImmutableMap;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.image.*;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.beans.factory.annotation.Autowired;
@ -41,7 +39,7 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.AI_IMAGE_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_IMAGE_NOT_EXISTS;
/**
* AI 绘画 Service 实现类
@ -78,22 +76,18 @@ public class AiImageServiceImpl implements AiImageService {
}
@Override
public Long dall(Long userId, AiImageDallReqVO req) {
req.setPlatform("dall"); // TODO 芋艿临时写死
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
// 1. 保存数据库
AiImageDO image = BeanUtils.toBean(req, AiImageDO.class)
.setUserId(userId).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
.setWidth(req.getWidth()).setHeight(req.getHeight())
.setDrawRequest(ImmutableMap.of(AiCommonConstants.DRAW_REQ_KEY_STYLE, req.getStyle()))
.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus());
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(image);
// 2. 异步绘制后续前端通过返回的 id 进行伦旭
getSelf().doDall(image, req);
// 2. 异步绘制后续前端通过返回的 id 进行轮询结果
getSelf().doDall(image, drawReqVO);
return image.getId();
}
@Async
public void doDall(AiImageDO image, AiImageDallReqVO req) {
public void doDall(AiImageDO image, AiImageDrawReqVO req) {
try {
// 1.1 构建请求
ImageOptions request = buildImageOptions(req);
@ -106,7 +100,7 @@ public class AiImageServiceImpl implements AiImageService {
String filePath = fileApi.createFile(fileContent);
// 3. 更新数据库
imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.COMPLETE.getStatus())
imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus())
.setPicUrl(filePath));
} catch (Exception ex) {
log.error("[doDall][image({}) 生成异常]", image, ex);
@ -115,30 +109,28 @@ public class AiImageServiceImpl implements AiImageService {
}
}
private static ImageOptions buildImageOptions(AiImageDallReqVO draw) {
if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPEN_AI_DALL.getPlatform())) {
OpenAiImageOptions request = new OpenAiImageOptions();
request.setModel(OpenAiImageModelEnum.valueOfModel(draw.getModel()).getModel());
request.setStyle(OpenAiImageStyleEnum.valueOfStyle(draw.getStyle()).getStyle());
request.setSize(String.format(AiCommonConstants.DALL_SIZE_TEMPLATE, draw.getWidth(), draw.getHeight()));
request.setResponseFormat("b64_json");
return request;
} else {
// https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
return StabilityAiImageOptions.builder().withModel(draw.getModel())
private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) {
if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
// https://platform.openai.com/docs/api-reference/images/create
return OpenAiImageOptions.builder().withModel(draw.getModel())
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
.withResponseFormat("b64_json")
.build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
return StabilityAiImageOptions.builder().withModel(draw.getModel())
.withHeight(draw.getHeight()).withWidth(draw.getWidth()) // TODO @芋艿各种参数
.build();
}
// return null;
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
}
@Override
@Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) {
// 1构建 AiImageDO
AiImageDO aiImageDO = new AiImageDO();
aiImageDO.setId(null);
aiImageDO.setUserId(loginUserId);
aiImageDO.setPrompt(req.getPrompt());
aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
@ -147,12 +139,6 @@ public class AiImageServiceImpl implements AiImageService {
aiImageDO.setWidth(null);
aiImageDO.setHeight(null);
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
aiImageDO.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus());
aiImageDO.setPicUrl(null);
aiImageDO.setOriginalPicUrl(null);
aiImageDO.setDrawRequest(null);
aiImageDO.setDrawResponse(null);
aiImageDO.setErrorMessage(null);
// 2保存 image
imageMapper.insert(aiImageDO);
@ -211,7 +197,7 @@ public class AiImageServiceImpl implements AiImageService {
//
String imageStatus = null;
if (MidjourneyTaskStatusEnum.SUCCESS == notifyReqVO.getStatus()) {
imageStatus = AiImageStatusEnum.COMPLETE.getStatus();
imageStatus = AiImageStatusEnum.SUCCESS.getStatus();
} else if (MidjourneyTaskStatusEnum.FAILURE == notifyReqVO.getStatus()) {
imageStatus = AiImageStatusEnum.FAIL.getStatus();
}
@ -226,8 +212,7 @@ public class AiImageServiceImpl implements AiImageService {
.setId(image.getId())
.setStatus(imageStatus)
.setPicUrl(filePath)
.setOriginalPicUrl(notifyReqVO.getImageUrl())
.setDrawResponse(BeanUtil.beanToMap(notifyReqVO))
// .setOriginalPicUrl(notifyReqVO.getImageUrl()) TODO @fan就不存原始的图片地址啦
);
return true;
}

View File

@ -5,7 +5,7 @@ import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import com.alibaba.fastjson.JSON;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@ -89,7 +89,7 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
AiImageStatusEnum drawingStatusEnum = null;
String generateStatus = midjourneyMessage.getGenerateStatus();
if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiImageStatusEnum.COMPLETE;
drawingStatusEnum = AiImageStatusEnum.SUCCESS;
} else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiImageStatusEnum.IN_PROGRESS;
}
@ -101,7 +101,6 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
// 更新数据库
aiImageMapper.updateByMjNonce(nonceId,
new AiImageDO()
.setOriginalPicUrl(imageUrl)
.setStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
// .setMjNonceId(midjourneyMessage.getId())
// .setMjOperations(JsonUtils.toJsonString(midjourneyOperations))

View File

@ -1,35 +0,0 @@
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* chat config
*
* @author fansili
* @time 2024/5/6 15:06
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiChatModalChatConfigVO extends AiChatModalConfigVO {
@NotNull
@Schema(description = "在生成消息时采用的Top-K采样大小")
private Double topK;
@NotNull
@Schema(description = "Top-P核采样方法的概率阈值")
private Double topP;
@NotNull
@Schema(description = "温度参数,用于调整生成回复的随机性和多样性程度")
private Double temperature;
@NotNull
@Schema(description = "最大 tokens")
private Integer maxTokens;
}

View File

@ -1,31 +0,0 @@
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler.vo;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import lombok.Data;
import lombok.experimental.Accessors;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
/**
* modal config
*
* @author fansili
* @time 2024/5/6 15:06
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiChatModalConfigVO {
/**
* 模型平台 (冗余方便类型转换)
* 参考{@link AiPlatformEnum}
*/
private String platform;
/**
* 模型类型(冗余方便类型转换)
* {@link YiYanChatModel}
* {@link XingHuoChatModel}
*/
private String type;
}

View File

@ -1,41 +0,0 @@
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler.vo;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* dall
*
* @author fansili
* @time 2024/5/6 15:06
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiChatModalDallConfigVO extends AiChatModalConfigVO {
// 可选字段默认为1
// 生成图像的数量必须在1到10之间对于dall-e-3模型目前仅支持n=1
private Integer n = 1;
// 可选字段默认为standard
// 设置生成图像的质量hd质量将创建细节更丰富图像整体一致性更高的图片该参数仅对dall-e-3模型有效
private String quality = "standard";
// 可选字段默认为url
// 返回生成图像的格式必须是url或b64_json中的一种URL链接的有效期是从生成图像后开始计算的60分钟内有效
private String responseFormat = "url";
// 可选字段默认为1024x1024
// 生成图像的尺寸大小对于dall-e-2模型尺寸可为256x256, 512x512, 1024x1024对于dall-e-3模型尺寸可为1024x1024, 1792x1024, 1024x1792
private String size = "1024x1024";
// 可选字段默认为vivid
// 图像生成的风格可为vivid生动或natural自然vivid会使模型偏向生成超现实和戏剧性的图像而natural则会让模型产出更自然不那么超现实的图像该参数仅对dall-e-3模型有效
private OpenAiImageStyleEnum style = OpenAiImageStyleEnum.VIVID;
// 可选字段
// 代表您的终端用户的唯一标识符有助于OpenAI监控并检测滥用行为了解更多信息请参考官方文档
private String endUserId = "UID123456";
}

View File

@ -1,16 +0,0 @@
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler.vo;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* Midjourney Config
*
* @author fansili
* @time 2024/5/6 15:07
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiChatModalMidjourneyConfigVO extends AiChatModalConfigVO {
}

View File

@ -17,7 +17,7 @@ import org.springframework.validation.annotation.Validated;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/**
* AI API 密钥 Service 实现类

View File

@ -16,7 +16,7 @@ import java.util.Collection;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/**
* AI 聊天模型 Service 实现类

View File

@ -21,7 +21,7 @@ import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/**
* AI 聊天角色 Service 实现类

View File

@ -3,10 +3,9 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.ai.autoconfigure.openai.OpenAiImageProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
/**
@ -92,27 +91,6 @@ public class YudaoAiProperties {
private int refreshTokenSecondTime = 86400;
}
@Data
@Accessors(chain = true)
public static class OpenAiImageProperties {
private boolean enable = false;
/**
* api key
*/
private String apiKey;
/**
* 模型
*/
private OpenAiImageModelEnum model = OpenAiImageModelEnum.DALL_E_2;
/**
* 风格
*/
private OpenAiImageStyleEnum style = OpenAiImageStyleEnum.VIVID;
}
@Data
@Accessors(chain = true)
public static class MidjourneyProperties {

View File

@ -20,7 +20,6 @@ public enum AiPlatformEnum {
QIAN_WEN("QianWen", "千问"), // 阿里
GEMIR ("gemir ", "gemir "), // 谷歌
OPEN_AI_DALL("dall", "dall"), // TODO OpenAI 提供的绘图接入中TODO 要不要统一下
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
MIDJOURNEY("midjourney", "midjourney"), // TODO MJ 提供的绘图接入中
;

View File

@ -1,37 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
// TODO 芋艿待梳理
/**
* open ai
*
* @author fansili
* @time 2024/4/28 14:21
* @since 1.0
*/
@AllArgsConstructor
@Getter
@Deprecated
public enum OpenAiImageModelEnum {
DALL_E_2("dall-e-2", "dall-e-2"),
DALL_E_3("dall-e-3", "dall-e-3")
;
private String model;
private String name;
public static OpenAiImageModelEnum valueOfModel(String model) {
for (OpenAiImageModelEnum itemEnum : OpenAiImageModelEnum.values()) {
if (itemEnum.getModel().equals(model)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
}

View File

@ -1,38 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
// TODO 芋艿待梳理
/**
* open ai image style
*
* @author fansili
* @time 2024/4/28 16:15
* @since 1.0
*/
@AllArgsConstructor
@Getter
@Deprecated
public enum OpenAiImageStyleEnum {
// 图像生成的风格可为vivid生动 natural自然vivid会使模型偏向生成超现实和戏剧性的图像而natural则会让模型产出更自然不那么超现实的图像该参数仅对dall-e-3模型有效
VIVID("vivid", "生动"),
NATURAL("natural", "自然"),
;
private String style;
private String name;
public static OpenAiImageStyleEnum valueOfStyle(String style) {
for (OpenAiImageStyleEnum itemEnum : OpenAiImageStyleEnum.values()) {
if (itemEnum.getStyle().equals(style)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + style);
}
}

View File

@ -33,6 +33,7 @@ import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.stabilityai.StabilityAiImageClient;
import java.util.List;
@ -88,12 +89,15 @@ public class AiClientFactoryImpl implements AiClientFactory {
@Override
public ImageClient getDefaultImageClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPEN_AI_DALL:
case OPENAI:
return SpringUtil.getBean(OpenAiImageClient.class);
case STABLE_DIFFUSION:
return SpringUtil.getBean(StabilityAiImageClient.class);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
return null;
}
private static String buildClientCacheKey(Class<?> clazz, Object... params) {