【代码优化】AI:MJ 配置类的简化

This commit is contained in:
YunaiV 2024-06-25 21:27:59 +08:00
parent 4c3add508b
commit b4eed07d61
7 changed files with 81 additions and 104 deletions

View File

@ -1,24 +0,0 @@
package cn.iocoder.yudao.module.ai.config;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* 配置
*
* @author fansili
* @time 2024/6/13 09:50
*/
@Configuration
public class YudaoMidjourneyConfiguration {
@Bean
@ConditionalOnProperty(value = "ai.midjourney-proxy.enable", havingValue = "true")
public MidjourneyApi midjourneyApi(YudaoMidjourneyProperties midjourneyProperties) {
return new MidjourneyApi(BeanUtils.toBean(midjourneyProperties, MidjourneyConfig.class));
}
}

View File

@ -1,22 +0,0 @@
package cn.iocoder.yudao.module.ai.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
/**
* Midjourney 属性
*
* @author fansili
* @time 2024/6/5 15:02
* @since 1.0
*/
@Configuration
@ConfigurationProperties(prefix = "ai.midjourney-proxy")
@Data
public class YudaoMidjourneyProperties {
private String enable;
private String key;
private String url;
}

View File

@ -29,7 +29,6 @@ import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@ -63,9 +62,6 @@ public class AiImageServiceImpl implements AiImageService {
@Resource @Resource
private MidjourneyApi midjourneyApi; private MidjourneyApi midjourneyApi;
@Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
private String midjourneyNotifyUrl;
@Override @Override
public PageResult<AiImageDO> getImagePageMy(Long userId, PageParam pageReqVO) { public PageResult<AiImageDO> getImagePageMy(Long userId, PageParam pageReqVO) {
return imageMapper.selectPage(userId, pageReqVO); return imageMapper.selectPage(userId, pageReqVO);
@ -159,7 +155,7 @@ public class AiImageServiceImpl implements AiImageService {
// 2. 调用 Midjourney Proxy 提交任务 // 2. 调用 Midjourney Proxy 提交任务
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest( MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
null, midjourneyNotifyUrl, reqVO.getPrompt(), null, reqVO.getPrompt(),null,
MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel())); MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest); MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
@ -258,7 +254,7 @@ public class AiImageServiceImpl implements AiImageService {
// 2. 调用 Midjourney Proxy 提交任务 // 2. 调用 Midjourney Proxy 提交任务
MidjourneyApi.SubmitResponse actionResponse = midjourneyApi.action( MidjourneyApi.SubmitResponse actionResponse = midjourneyApi.action(
new MidjourneyApi.ActionRequest(button.customId(), image.getTaskId(), midjourneyNotifyUrl)); new MidjourneyApi.ActionRequest(button.customId(), image.getTaskId(), null));
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(actionResponse.code())) { if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(actionResponse.code())) {
String description = actionResponse.description().contains("quota_not_enough") ? String description = actionResponse.description().contains("quota_not_enough") ?
"账户余额不足" : actionResponse.description(); "账户余额不足" : actionResponse.description();

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl; import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
@ -96,6 +97,13 @@ public class YudaoAiAutoConfiguration {
); );
} }
@Bean
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.MidjourneyProperties config = yudaoAiProperties.getMidjourney();
return new MidjourneyApi(config.getBaseUrl(), config.getApiKey(), config.getNotifyUrl());
}
@Bean @Bean
@ConditionalOnProperty(value = "yudao.ai.suno.enable", havingValue = "true") @ConditionalOnProperty(value = "yudao.ai.suno.enable", havingValue = "true")
public SunoApi sunoApi(YudaoAiProperties yudaoAiProperties) { public SunoApi sunoApi(YudaoAiProperties yudaoAiProperties) {

View File

@ -64,15 +64,18 @@ public class YudaoAiProperties {
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public static class XingHuoProperties extends ChatProperties { public static class XingHuoProperties extends ChatProperties {
private String appId; private String appId;
private String appKey; private String appKey;
private String secretKey; private String secretKey;
private XingHuoChatModel model; private XingHuoChatModel model;
} }
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public static class YiYanProperties extends ChatProperties { public static class YiYanProperties extends ChatProperties {
/** /**
* appKey * appKey
*/ */
@ -92,26 +95,13 @@ public class YudaoAiProperties {
} }
@Data @Data
@Accessors(chain = true)
public static class MidjourneyProperties { public static class MidjourneyProperties {
private boolean enable = false;
/** private String enable;
* socket 链接地址 private String apiKey;
*/ private String baseUrl;
private String wssUrl = "wss://gateway.discord.gg"; private String notifyUrl;
/**
* token
*/
private String token;
/**
* 服务id
*/
private String guildId;
/**
* 频道id
*/
private String channelId;
} }
@Data @Data

View File

@ -1,10 +1,11 @@
package cn.iocoder.yudao.framework.ai.core.model.midjourney.api; package cn.iocoder.yudao.framework.ai.core.model.midjourney.api;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.MidjourneyConfig; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils; import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.openai.api.ApiUtils; import org.springframework.ai.openai.api.ApiUtils;
@ -26,11 +27,17 @@ public class MidjourneyApi {
private final WebClient webClient; private final WebClient webClient;
public MidjourneyApi(MidjourneyConfig midjourneyConfig) { /**
* 回调地址
*/
private final String notifyUrl;
public MidjourneyApi(String baseUrl, String apiKey, String notifyUrl) {
this.webClient = WebClient.builder() this.webClient = WebClient.builder()
.baseUrl(midjourneyConfig.getUrl()) .baseUrl(baseUrl)
.defaultHeaders(ApiUtils.getJsonContentHeaders(midjourneyConfig.getKey())) .defaultHeaders(ApiUtils.getJsonContentHeaders(apiKey))
.build(); .build();
this.notifyUrl = notifyUrl;
} }
/** /**
@ -40,6 +47,9 @@ public class MidjourneyApi {
* @return 提交结果 * @return 提交结果
*/ */
public SubmitResponse imagine(ImagineRequest request) { public SubmitResponse imagine(ImagineRequest request) {
if (StrUtil.isEmpty(request.getNotifyHook())) {
request.setNotifyHook(notifyUrl);
}
String response = post("/submit/imagine", request); String response = post("/submit/imagine", request);
return JsonUtils.parseObject(response, SubmitResponse.class); return JsonUtils.parseObject(response, SubmitResponse.class);
} }
@ -51,8 +61,11 @@ public class MidjourneyApi {
* @return 提交结果 * @return 提交结果
*/ */
public SubmitResponse action(ActionRequest request) { public SubmitResponse action(ActionRequest request) {
String res = post("/submit/action", request); if (StrUtil.isEmpty(request.getNotifyHook())) {
return JsonUtils.parseObject(res, SubmitResponse.class); request.setNotifyHook(notifyUrl);
}
String response = post("/submit/action", request);
return JsonUtils.parseObject(response, SubmitResponse.class);
} }
/** /**
@ -86,23 +99,40 @@ public class MidjourneyApi {
/** /**
* Imagine 请求生成图片 * Imagine 请求生成图片
*
* @param base64Array 垫图(参考图) base64数
* @param notifyHook 通知地址
* @param prompt 提示词
* @param state 自定义参数
*/ */
public record ImagineRequest(List<String> base64Array, @Data
String notifyHook, public static final class ImagineRequest {
String prompt,
String state) { /**
* 垫图(参考图) base64 数组
*/
private List<String> base64Array;
/**
* 提示词
*/
private String prompt;
/**
* 通知地址
*/
private String notifyHook;
/**
* 自定义参数
*/
private String state;
public ImagineRequest(List<String> base64Array, String prompt, String notifyHook, String state) {
this.base64Array = base64Array;
this.prompt = prompt;
this.notifyHook = notifyHook;
this.state = state;
}
public static String buildState(Integer width, Integer height, String version, String model) { public static String buildState(Integer width, Integer height, String version, String model) {
StringBuilder params = new StringBuilder(); StringBuilder params = new StringBuilder();
// --ar 来设置尺寸 // --ar 来设置尺寸
params.append(String.format(" --ar %s:%s ", width, height)); params.append(String.format(" --ar %s:%s ", width, height));
// --niji 模型 // --niji 模型
if (MidjourneyApi.ModelEnum.NIJI.getModel().equals(model)) { if (ModelEnum.NIJI.getModel().equals(model)) {
params.append(String.format(" --niji %s ", version)); params.append(String.format(" --niji %s ", version));
} else { } else {
params.append(String.format(" --v %s ", version)); params.append(String.format(" --v %s ", version));
@ -114,15 +144,20 @@ public class MidjourneyApi {
/** /**
* Action 请求 * Action 请求
*
* @param customId 操作按钮id
* @param taskId 操作按钮id
* @param notifyHook 通知地址
*/ */
public record ActionRequest(String customId, @Data
String taskId, public static final class ActionRequest {
String notifyHook
) { private String customId;
private String taskId;
private String notifyHook;
public ActionRequest(String taskId, String customId, String notifyHook) {
this.customId = customId;
this.taskId = taskId;
this.notifyHook = notifyHook;
}
} }
/** /**

View File

@ -194,20 +194,14 @@ yudao.ai:
api-key: sk-Zsd81gZYg7 api-key: sk-Zsd81gZYg7
midjourney: midjourney:
enable: true enable: true
token: MTE4MjE3MjY2MjkxNTY3ODIzOA.GEV1SG.c49F8lZoGCUHwsj8O0UdodmM6nyQHvuD2fXflw # base-url: https://api.holdai.top/mj-relax/mj
guild-id: 1237948819677904956 base-url: https://api.holdai.top/mj
channel-id: 1237948819677904960 api-key: sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf
notify-url: http://java.nat300.top/admin-api/ai/image/midjourney/notify
suno: suno:
enable: true enable: true
base-url: https://suno-imrqwwui8-status2xxs-projects.vercel.app base-url: https://suno-imrqwwui8-status2xxs-projects.vercel.app
ai:
midjourney-proxy:
enable: true
url: https://api.holdai.top/mj
notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
key: sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf
--- #################### 芋道相关配置 #################### --- #################### 芋道相关配置 ####################
yudao: yudao: