diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index fa69ee56a..6eedb5997 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -15,11 +15,15 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageApi; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; +import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage; import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi; +import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener; +import lombok.extern.slf4j.Slf4j; import org.jetbrains.annotations.NotNull; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.ApplicationContext; @@ -37,6 +41,7 @@ import java.util.Map; * @time 2024/4/12 16:29 * @since 1.0 */ +@Slf4j @AutoConfiguration @EnableConfigurationProperties(YudaoAiProperties.class) public class YudaoAiAutoConfiguration { @@ -116,15 +121,29 @@ public class YudaoAiAutoConfiguration { ); } + @Bean + @ConditionalOnMissingBean(value = MidjourneyMessageHandler.class) + public MidjourneyMessageHandler defaultMidjourneyMessageHandler() { + // 如果没有实现 MidjourneyMessageHandler 默认注入一个 + return new MidjourneyMessageHandler() { + @Override + public void messageHandler(MidjourneyMessage midjourneyMessage) { + log.info("default midjourney message: {}", midjourneyMessage); + } + }; + } + @Bean @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true") - public MidjourneyWebSocketStarter midjourneyWebSocketStarter(ApplicationContext applicationContext, YudaoAiProperties yudaoAiProperties) { + public MidjourneyWebSocketStarter midjourneyWebSocketStarter(ApplicationContext applicationContext, + MidjourneyMessageHandler midjourneyMessageHandler, + YudaoAiProperties yudaoAiProperties) { // 获取 midjourneyProperties YudaoAiProperties.MidjourneyProperties midjourneyProperties = yudaoAiProperties.getMidjourney(); // 获取 midjourneyConfig MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, midjourneyProperties); // 创建 socket messageListener - MidjourneyMessageListener messageListener = new MidjourneyMessageListener(midjourneyConfig); + MidjourneyMessageListener messageListener = new MidjourneyMessageListener(midjourneyConfig, midjourneyMessageHandler); // 创建 MidjourneyWebSocketStarter return new MidjourneyWebSocketStarter(midjourneyProperties.getWssUrl(), null, midjourneyConfig, messageListener); }