【优化】Midjourney 自动配置,增加 messageHanlder

This commit is contained in:
cherishsince 2024-04-29 14:50:13 +08:00
parent 80787d1dcc
commit ab0a49a1a7

View File

@ -15,11 +15,15 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageApi;
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi; import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
@ -37,6 +41,7 @@ import java.util.Map;
* @time 2024/4/12 16:29 * @time 2024/4/12 16:29
* @since 1.0 * @since 1.0
*/ */
@Slf4j
@AutoConfiguration @AutoConfiguration
@EnableConfigurationProperties(YudaoAiProperties.class) @EnableConfigurationProperties(YudaoAiProperties.class)
public class YudaoAiAutoConfiguration { public class YudaoAiAutoConfiguration {
@ -116,15 +121,29 @@ public class YudaoAiAutoConfiguration {
); );
} }
@Bean
@ConditionalOnMissingBean(value = MidjourneyMessageHandler.class)
public MidjourneyMessageHandler defaultMidjourneyMessageHandler() {
// 如果没有实现 MidjourneyMessageHandler 默认注入一个
return new MidjourneyMessageHandler() {
@Override
public void messageHandler(MidjourneyMessage midjourneyMessage) {
log.info("default midjourney message: {}", midjourneyMessage);
}
};
}
@Bean @Bean
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true") @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
public MidjourneyWebSocketStarter midjourneyWebSocketStarter(ApplicationContext applicationContext, YudaoAiProperties yudaoAiProperties) { public MidjourneyWebSocketStarter midjourneyWebSocketStarter(ApplicationContext applicationContext,
MidjourneyMessageHandler midjourneyMessageHandler,
YudaoAiProperties yudaoAiProperties) {
// 获取 midjourneyProperties // 获取 midjourneyProperties
YudaoAiProperties.MidjourneyProperties midjourneyProperties = yudaoAiProperties.getMidjourney(); YudaoAiProperties.MidjourneyProperties midjourneyProperties = yudaoAiProperties.getMidjourney();
// 获取 midjourneyConfig // 获取 midjourneyConfig
MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, midjourneyProperties); MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, midjourneyProperties);
// 创建 socket messageListener // 创建 socket messageListener
MidjourneyMessageListener messageListener = new MidjourneyMessageListener(midjourneyConfig); MidjourneyMessageListener messageListener = new MidjourneyMessageListener(midjourneyConfig, midjourneyMessageHandler);
// 创建 MidjourneyWebSocketStarter // 创建 MidjourneyWebSocketStarter
return new MidjourneyWebSocketStarter(midjourneyProperties.getWssUrl(), null, midjourneyConfig, messageListener); return new MidjourneyWebSocketStarter(midjourneyProperties.getWssUrl(), null, midjourneyConfig, messageListener);
} }