【新增】AI:绘图(MJ)接入 API KEY 管理

This commit is contained in:
YunaiV 2024-06-29 09:46:44 +08:00
parent 949d5a1815
commit 6225e18f70
9 changed files with 62 additions and 76 deletions

View File

@ -9,7 +9,7 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -114,11 +114,11 @@ public class AiImageController {
return success(BeanUtils.toBean(pageResult, AiImageRespVO.class)); return success(BeanUtils.toBean(pageResult, AiImageRespVO.class));
} }
@PutMapping("/update-public-status") @PutMapping("/update")
@Operation(summary = "更新绘画发布状态") @Operation(summary = "更新绘画")
@PreAuthorize("@ss.hasPermission('ai:image:update')") @PreAuthorize("@ss.hasPermission('ai:image:update')")
public CommonResult<Boolean> updateImagePublicStatus(@Valid @RequestBody AiImageUpdatePublicStatusReqVO updateReqVO) { public CommonResult<Boolean> updateImage(@Valid @RequestBody AiImageUpdateReqVO updateReqVO) {
imageService.updateImagePublicStatus(updateReqVO); imageService.updateImage(updateReqVO);
return success(true); return success(true);
} }

View File

@ -4,15 +4,15 @@ import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
@Schema(description = "管理后台 - AI 绘画修改发布状态 Request VO") @Schema(description = "管理后台 - AI 绘画修改 Request VO")
@Data @Data
public class AiImageUpdatePublicStatusReqVO { public class AiImageUpdateReqVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "15583") @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "15583")
@NotNull(message = "编号不能为空")
private Long id; private Long id;
@Schema(description = "是否发布", requiredMode = Schema.RequiredMode.REQUIRED, example = "true") @Schema(description = "是否发布", example = "true")
@NotNull(message = "是否发布不能为空")
private Boolean publicStatus; private Boolean publicStatus;
} }

View File

@ -5,7 +5,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -71,11 +71,11 @@ public interface AiImageService {
PageResult<AiImageDO> getImagePage(AiImagePageReqVO pageReqVO); PageResult<AiImageDO> getImagePage(AiImagePageReqVO pageReqVO);
/** /**
* 更新绘画发布状态 * 更新绘画
* *
* @param updateReqVO 更新信息 * @param updateReqVO 更新信息
*/ */
void updateImagePublicStatus(@Valid AiImageUpdatePublicStatusReqVO updateReqVO); void updateImage(@Valid AiImageUpdateReqVO updateReqVO);
/** /**
* 删除绘画 * 删除绘画

View File

@ -15,7 +15,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -62,9 +62,6 @@ public class AiImageServiceImpl implements AiImageService {
@Resource @Resource
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Resource
private MidjourneyApi midjourneyApi;
@Override @Override
public PageResult<AiImageDO> getImagePageMy(Long userId, PageParam pageReqVO) { public PageResult<AiImageDO> getImagePageMy(Long userId, PageParam pageReqVO) {
return imageMapper.selectPage(userId, pageReqVO); return imageMapper.selectPage(userId, pageReqVO);
@ -151,7 +148,7 @@ public class AiImageServiceImpl implements AiImageService {
} }
@Override @Override
public void updateImagePublicStatus(AiImageUpdatePublicStatusReqVO updateReqVO) { public void updateImage(AiImageUpdateReqVO updateReqVO) {
// 1. 校验存在 // 1. 校验存在
validateImageExists(updateReqVO.getId()); validateImageExists(updateReqVO.getId());
// 2. 更新发布状态 // 2. 更新发布状态
@ -179,6 +176,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) { public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
// 1. 保存数据库 // 1. 保存数据库
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false) AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()) .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
@ -206,6 +204,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
public Integer midjourneySync() { public Integer midjourneySync() {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
// 1.1 获取 Midjourney 平台状态在 进行中 image // 1.1 获取 Midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform( List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform()); AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
@ -272,6 +271,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) { public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
// 1.1 检查 image // 1.1 检查 image
AiImageDO image = validateImageExists(reqVO.getId()); AiImageDO image = validateImageExists(reqVO.getId());
if (ObjUtil.notEqual(userId, image.getUserId())) { if (ObjUtil.notEqual(userId, image.getUserId())) {

View File

@ -1,58 +0,0 @@
package cn.iocoder.yudao.module.ai.service.image;
import lombok.Data;
import org.springframework.ai.image.ImageOptions;
/**
* @author fansili
* @time 2024/6/5 10:34
* @since 1.0
*/
@Data
public class MidjourneyImageOptions implements ImageOptions {
/**
* 模型
*/
private String model;
/**
* 宽度
*/
private Integer width;
/**
* 高度
*/
private Integer height;
/**
* 版本
*/
private String version;
/**
* 参数
*/
private String state;
@Override
public Integer getN() {
return 0;
}
@Override
public String getModel() {
return model;
}
@Override
public Integer getWidth() {
return width;
}
@Override
public Integer getHeight() {
return height;
}
@Override
public String getResponseFormat() {
return "";
}
}

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.model; package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
@ -92,6 +93,15 @@ public interface AiApiKeyService {
*/ */
ImageClient getImageClient(AiPlatformEnum platform); ImageClient getImageClient(AiPlatformEnum platform);
/**
* 获得 MidjourneyApi 对象
*
* TODO 可优化点目前默认获取 Midjourney 对应的第一个开启的配置用于绘画后续可以支持配置选择
*
* @return MidjourneyApi 对象
*/
MidjourneyApi getMidjourneyApi();
/** /**
* 获得 SunoApi 对象 * 获得 SunoApi 对象
* *

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
@ -112,6 +113,16 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl()); return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl());
} }
@Override
public MidjourneyApi getMidjourneyApi() {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) {
return null;
}
return clientFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override @Override
public SunoApi getSunoApi() { public SunoApi getSunoApi() {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus( AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.framework.ai.core.factory; package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient; import org.springframework.ai.image.ImageClient;
@ -56,6 +57,17 @@ public interface AiClientFactory {
*/ */
ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url); ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于指定配置获得 MidjourneyApi 对象
*
* 如果不存在则进行创建
*
* @param apiKey API KEY
* @param url API URL
* @return MidjourneyApi 对象
*/
MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url);
/** /**
* 基于指定配置获得 SunoApi 对象 * 基于指定配置获得 SunoApi 对象
* *

View File

@ -9,6 +9,7 @@ import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration; import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties; import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
@ -110,9 +111,19 @@ public class AiClientFactoryImpl implements AiClientFactory {
} }
} }
@Override
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url);
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney();
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
});
}
@Override @Override
public SunoApi getOrCreateSunoApi(String apiKey, String url) { public SunoApi getOrCreateSunoApi(String apiKey, String url) {
return new SunoApi(url); String cacheKey = buildClientCacheKey(SunoApi.class, AiPlatformEnum.SUNO.getPlatform(), apiKey, url);
return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url));
} }
private static String buildClientCacheKey(Class<?> clazz, Object... params) { private static String buildClientCacheKey(Class<?> clazz, Object... params) {