diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MjNotifyCode.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MjNotifyCode.java index 823d13fc2..103bd6fdf 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MjNotifyCode.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MjNotifyCode.java @@ -8,35 +8,7 @@ public final class MjNotifyCode { * 成功. */ public static final int SUCCESS = 1; - /** - * 数据未找到. - */ - public static final int NOT_FOUND = 3; - /** - * 校验错误. - */ - public static final int VALIDATION_ERROR = 4; - /** - * 系统异常. - */ - public static final int FAILURE = 9; - /** - * 已存在. - */ - public static final int EXISTED = 21; - /** - * 排队中. - */ - public static final int IN_QUEUE = 22; - /** - * 队列已满. - */ - public static final int QUEUE_REJECTED = 23; - /** - * prompt包含敏感词. - */ - public static final int BANNED_PROMPT = 24; } \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/MjWebSocketStarter.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/MjWebSocketStarter.java index e6dcafc1a..0d6a7d0c6 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/MjWebSocketStarter.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/MjWebSocketStarter.java @@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener; +import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.apache.tomcat.websocket.Constants; import org.jetbrains.annotations.NotNull; +import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketHttpHeaders; @@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException; @Slf4j public class MjWebSocketStarter implements WebSocketStarter { + /** + * 链接重试次数 + */ private static final int CONNECT_RETRY_LIMIT = 5; - + /** + * mj 配置文件 + */ private final MidjourneyConfig midjourneyConfig; + /** + * mj 监听(所有message 都会 callback到这里) + */ private final MjMessageListener userMessageListener; + /** + * wss 服务器 + */ private final String wssServer; + /** + * + */ private final String resumeWss; - - private boolean running = false; - - private WebSocketSession webSocketSession = null; + /** + * + */ private ResumeData resumeData = null; + /** + * 是否运行成功 + */ + private boolean running = false; + /** + * 链接成功的 session + */ + private WebSocketSession webSocketSession = null; - public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) { + public MjWebSocketStarter(String wssServer, + String resumeWss, + MidjourneyConfig midjourneyConfig, + MjMessageListener userMessageListener) { this.wssServer = wssServer; this.resumeWss = resumeWss; this.midjourneyConfig = midjourneyConfig; @@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter { } @Override - public void start() throws Exception { + public void start() { start(false); } private void start(boolean reconnect) { + // 设置header WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); headers.add("Accept-Encoding", "gzip, deflate, br"); headers.add("Accept-Language", "zh-CN,zh;q=0.9"); @@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter { headers.add("Pragma", "no-cache"); headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits"); headers.add("User-Agent", this.midjourneyConfig.getUserAage()); - var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure); + // 创建 mjHeader + MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler( + this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure); + // String gatewayUrl; if (reconnect) { - gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream"; - handler.setSessionId(this.resumeData.sessionId()); - handler.setSequence(this.resumeData.sequence()); - handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl()); + gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream"; + mjWebSocketHandler.setSessionId(this.resumeData.getSessionId()); + mjWebSocketHandler.setSequence(this.resumeData.getSequence()); + mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl()); } else { gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream"; } - var webSocketClient = new StandardWebSocketClient(); + // 创建 StandardWebSocketClient + StandardWebSocketClient webSocketClient = new StandardWebSocketClient(); + // 设置 io timeout 时间 webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000"); - var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl)); + // + ListenableFuture socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl)); + // 添加 callback 进行回调 socketSessionFuture.addCallback(new ListenableFutureCallback<>() { @Override public void onFailure(@NotNull Throwable e) { @@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter { } private void onSocketFailure(int code, String reason) { + // 1001异常可以忽略 if (code == 1001) { return; } + // 关闭 socket closeSocketSessionWhenIsOpen(); + // 没有运行通知 if (!this.running) { notifyWssLock(code, reason); return; } + // 已经运行先设置为false,发起 this.running = false; if (code >= 4000) { log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason); @@ -107,36 +145,34 @@ public class MjWebSocketStarter implements WebSocketStarter { } } + /** + * 重连 + */ private void tryReconnect() { try { tryStart(true); } catch (Exception e) { - if (e instanceof TimeoutException) { - closeSocketSessionWhenIsOpen(); - } - log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage()); + log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage()); ThreadUtil.sleep(1000); tryNewConnect(); } } private void tryNewConnect() { + // 链接重试次数5 for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) { try { tryStart(false); return; } catch (Exception e) { - if (e instanceof TimeoutException) { - closeSocketSessionWhenIsOpen(); - } - log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage()); + log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage()); ThreadUtil.sleep(5000); } } log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId()); } - public void tryStart(boolean reconnect) throws Exception { + public void tryStart(boolean reconnect) { start(reconnect); } @@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter { System.err.println("notifyWssLock: " + code + " - " + reason); } + /** + * 关闭 socket session + */ private void closeSocketSessionWhenIsOpen() { try { if (this.webSocketSession != null && this.webSocketSession.isOpen()) { @@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter { return this.wssServer; } - public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) { + @Getter + public static class ResumeData { + + public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) { + this.sessionId = sessionId; + this.sequence = sequence; + this.resumeGatewayUrl = resumeGatewayUrl; + } + + /** + * socket session + */ + private final String sessionId; + private final Object sequence; + private final String resumeGatewayUrl; } } \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/handler/MjWebSocketHandler.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/handler/MjWebSocketHandler.java index 481fa1b0c..55fc7e4bc 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/handler/MjWebSocketHandler.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/handler/MjWebSocketHandler.java @@ -30,16 +30,41 @@ import java.util.concurrent.TimeUnit; @Slf4j public class MjWebSocketHandler implements WebSocketHandler { + /** + * close 错误码:重连 + */ public static final int CLOSE_CODE_RECONNECT = 2001; + /** + * close 错误码:无效、作废 + */ public static final int CLOSE_CODE_INVALIDATE = 1009; + /** + * close 错误码:异常 + */ public static final int CLOSE_CODE_EXCEPTION = 1011; - + /** + * mj配置文件 + */ private final MidjourneyConfig midjourneyConfig; + /** + * mj 消息监听 + */ private final MjMessageListener userMessageListener; + /** + * 成功回调 + */ private final SuccessCallback successCallback; + /** + * 失败回调 + */ private final FailureCallback failureCallback; - + /** + * 心跳执行器 + */ private final ScheduledExecutorService heartExecutor; + /** + * auth数据 + */ private final DataObject authData; @Setter @@ -55,6 +80,9 @@ public class MjWebSocketHandler implements WebSocketHandler { private Future heartbeatInterval; private Future heartbeatTimeout; + /** + * 处理 message 消息的 Decompressor + */ private final Decompressor decompressor = new ZlibDecompressor(2048); public MjWebSocketHandler(MidjourneyConfig account, @@ -77,11 +105,13 @@ public class MjWebSocketHandler implements WebSocketHandler { @Override public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception { log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e); + // 通知链接异常 onFailure(CLOSE_CODE_EXCEPTION, "transport error"); } @Override public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception { + // 链接关闭 onFailure(closeStatus.getCode(), closeStatus.getReason()); } @@ -92,13 +122,18 @@ public class MjWebSocketHandler implements WebSocketHandler { @Override public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage message) throws Exception { + // 获取 message 消息 ByteBuffer buffer = (ByteBuffer) message.getPayload(); + // 解析 message byte[] decompressed = decompressor.decompress(buffer.array()); if (decompressed == null) { return; } + // 转换 json String json = new String(decompressed, StandardCharsets.UTF_8); + // 转换 jda 自带的 dataObject(和json object 差不多) DataObject data = DataObject.fromJson(json); + // 获取消息类型 int opCode = data.getInt("op"); switch (opCode) { case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);