diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index b265f6814..fa69ee56a 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -1,5 +1,6 @@ package cn.iocoder.yudao.framework.ai.config; +import cn.hutool.core.io.IoUtil; import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient; import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal; import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions; @@ -13,10 +14,21 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageApi; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; +import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; +import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi; +import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter; +import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener; +import org.jetbrains.annotations.NotNull; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.core.io.Resource; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; /** * ai 自动配置 @@ -103,4 +115,45 @@ public class YudaoAiAutoConfiguration { .setStyle(openAiImageProperties.getStyle()) ); } + + @Bean + @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true") + public MidjourneyWebSocketStarter midjourneyWebSocketStarter(ApplicationContext applicationContext, YudaoAiProperties yudaoAiProperties) { + // 获取 midjourneyProperties + YudaoAiProperties.MidjourneyProperties midjourneyProperties = yudaoAiProperties.getMidjourney(); + // 获取 midjourneyConfig + MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, midjourneyProperties); + // 创建 socket messageListener + MidjourneyMessageListener messageListener = new MidjourneyMessageListener(midjourneyConfig); + // 创建 MidjourneyWebSocketStarter + return new MidjourneyWebSocketStarter(midjourneyProperties.getWssUrl(), null, midjourneyConfig, messageListener); + } + + @Bean + @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true") + public MidjourneyInteractionsApi midjourneyInteractionsApi(ApplicationContext applicationContext, YudaoAiProperties yudaoAiProperties) { + // 获取 midjourneyConfig + MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, yudaoAiProperties.getMidjourney()); + // 创建 MidjourneyInteractionsApi + return new MidjourneyInteractionsApi(midjourneyConfig); + } + + + private static @NotNull MidjourneyConfig getMidjourneyConfig(ApplicationContext applicationContext, + YudaoAiProperties.MidjourneyProperties midjourneyProperties) { + Map requestTemplates = new HashMap<>(); + try { + Resource[] resources = applicationContext.getResources("classpath:http-body/*.json"); + for (var resource : resources) { + String filename = resource.getFilename(); + String params = IoUtil.readUtf8(resource.getInputStream()); + requestTemplates.put(filename.substring(0, filename.length() - 5), params); + } + } catch (IOException e) { + throw new IllegalArgumentException("Midjourney json模板初始化出错! " + e.getMessage()); + } + // 创建 midjourneyConfig + return new MidjourneyConfig(midjourneyProperties.getToken(), + midjourneyProperties.getGuildId(), midjourneyProperties.getChannelId(), requestTemplates); + } } \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java index 79c1ff229..f9dcae04a 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java @@ -26,6 +26,7 @@ public class YudaoAiProperties { private XingHuoProperties xinghuo; private YiYanProperties yiyan; private OpenAiImageProperties openAiImage; + private MidjourneyProperties midjourney; @Data @Accessors(chain = true) @@ -94,6 +95,8 @@ public class YudaoAiProperties { @Data @Accessors(chain = true) public static class OpenAiImageProperties { + private boolean enable = false; + /** * api key */ @@ -107,4 +110,27 @@ public class YudaoAiProperties { */ private OpenAiImageStyleEnum style = OpenAiImageStyleEnum.VIVID; } + + @Data + @Accessors(chain = true) + public static class MidjourneyProperties { + private boolean enable = false; + + /** + * socket 链接地址 + */ + private String wssUrl = "wss://gateway.discord.gg"; + /** + * token + */ + private String token; + /** + * 服务id + */ + private String guildId; + /** + * 频道id + */ + private String channelId; + } } diff --git a/yudao-server/src/main/resources/application-local.yaml b/yudao-server/src/main/resources/application-local.yaml index 32ab21c73..ffc9970c9 100644 --- a/yudao-server/src/main/resources/application-local.yaml +++ b/yudao-server/src/main/resources/application-local.yaml @@ -260,7 +260,11 @@ yudao: api-key: ${OPEN_AI_KEY} model: dall_e_2 style: vivid - + midjourney: + enable: true + token: OTcyNzIxMzA0ODkxNDUzNDUw.G_vMOz.BO_Q0sXAD80u5ZKIHPNYDTRX_FgeKL3cKFc53I + guild-id: 1225608134878302329 + channel-id: 1225608134878302332 captcha: enable: false # 本地环境,暂时关闭图片验证码,方便登录等接口的测试; security: