增加注释

This commit is contained in:
cherishsince 2024-04-05 10:15:59 +08:00
parent 5044a58118
commit 44e44dc4bb
3 changed files with 114 additions and 54 deletions

View File

@ -8,35 +8,7 @@ public final class MjNotifyCode {
* 成功. * 成功.
*/ */
public static final int SUCCESS = 1; public static final int SUCCESS = 1;
/**
* 数据未找到.
*/
public static final int NOT_FOUND = 3;
/**
* 校验错误.
*/
public static final int VALIDATION_ERROR = 4;
/**
* 系统异常.
*/
public static final int FAILURE = 9;
/**
* 已存在.
*/
public static final int EXISTED = 21;
/**
* 排队中.
*/
public static final int IN_QUEUE = 22;
/**
* 队列已满.
*/
public static final int QUEUE_REJECTED = 23;
/**
* prompt包含敏感词.
*/
public static final int BANNED_PROMPT = 24;
} }

View File

@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode; import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.websocket.Constants; import org.apache.tomcat.websocket.Constants;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketHttpHeaders;
@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException;
@Slf4j @Slf4j
public class MjWebSocketStarter implements WebSocketStarter { public class MjWebSocketStarter implements WebSocketStarter {
/**
* 链接重试次数
*/
private static final int CONNECT_RETRY_LIMIT = 5; private static final int CONNECT_RETRY_LIMIT = 5;
/**
* mj 配置文件
*/
private final MidjourneyConfig midjourneyConfig; private final MidjourneyConfig midjourneyConfig;
/**
* mj 监听(所有message 都会 callback到这里)
*/
private final MjMessageListener userMessageListener; private final MjMessageListener userMessageListener;
/**
* wss 服务器
*/
private final String wssServer; private final String wssServer;
/**
*
*/
private final String resumeWss; private final String resumeWss;
/**
private boolean running = false; *
*/
private WebSocketSession webSocketSession = null;
private ResumeData resumeData = null; private ResumeData resumeData = null;
/**
* 是否运行成功
*/
private boolean running = false;
/**
* 链接成功的 session
*/
private WebSocketSession webSocketSession = null;
public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) { public MjWebSocketStarter(String wssServer,
String resumeWss,
MidjourneyConfig midjourneyConfig,
MjMessageListener userMessageListener) {
this.wssServer = wssServer; this.wssServer = wssServer;
this.resumeWss = resumeWss; this.resumeWss = resumeWss;
this.midjourneyConfig = midjourneyConfig; this.midjourneyConfig = midjourneyConfig;
@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
} }
@Override @Override
public void start() throws Exception { public void start() {
start(false); start(false);
} }
private void start(boolean reconnect) { private void start(boolean reconnect) {
// 设置header
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("Accept-Encoding", "gzip, deflate, br"); headers.add("Accept-Encoding", "gzip, deflate, br");
headers.add("Accept-Language", "zh-CN,zh;q=0.9"); headers.add("Accept-Language", "zh-CN,zh;q=0.9");
@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter {
headers.add("Pragma", "no-cache"); headers.add("Pragma", "no-cache");
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits"); headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
headers.add("User-Agent", this.midjourneyConfig.getUserAage()); headers.add("User-Agent", this.midjourneyConfig.getUserAage());
var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure); // 创建 mjHeader
MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler(
this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
//
String gatewayUrl; String gatewayUrl;
if (reconnect) { if (reconnect) {
gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream"; gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
handler.setSessionId(this.resumeData.sessionId()); mjWebSocketHandler.setSessionId(this.resumeData.getSessionId());
handler.setSequence(this.resumeData.sequence()); mjWebSocketHandler.setSequence(this.resumeData.getSequence());
handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl()); mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl());
} else { } else {
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream"; gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
} }
var webSocketClient = new StandardWebSocketClient(); // 创建 StandardWebSocketClient
StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
// 设置 io timeout 时间
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000"); webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl)); //
ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl));
// 添加 callback 进行回调
socketSessionFuture.addCallback(new ListenableFutureCallback<>() { socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
@Override @Override
public void onFailure(@NotNull Throwable e) { public void onFailure(@NotNull Throwable e) {
@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter {
} }
private void onSocketFailure(int code, String reason) { private void onSocketFailure(int code, String reason) {
// 1001异常可以忽略
if (code == 1001) { if (code == 1001) {
return; return;
} }
// 关闭 socket
closeSocketSessionWhenIsOpen(); closeSocketSessionWhenIsOpen();
// 没有运行通知
if (!this.running) { if (!this.running) {
notifyWssLock(code, reason); notifyWssLock(code, reason);
return; return;
} }
// 已经运行先设置为false发起
this.running = false; this.running = false;
if (code >= 4000) { if (code >= 4000) {
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason); log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
@ -107,13 +145,13 @@ public class MjWebSocketStarter implements WebSocketStarter {
} }
} }
/**
* 重连
*/
private void tryReconnect() { private void tryReconnect() {
try { try {
tryStart(true); tryStart(true);
} catch (Exception e) { } catch (Exception e) {
if (e instanceof TimeoutException) {
closeSocketSessionWhenIsOpen();
}
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage()); log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
ThreadUtil.sleep(1000); ThreadUtil.sleep(1000);
tryNewConnect(); tryNewConnect();
@ -121,14 +159,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
} }
private void tryNewConnect() { private void tryNewConnect() {
// 链接重试次数5
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) { for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
try { try {
tryStart(false); tryStart(false);
return; return;
} catch (Exception e) { } catch (Exception e) {
if (e instanceof TimeoutException) {
closeSocketSessionWhenIsOpen();
}
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage()); log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
ThreadUtil.sleep(5000); ThreadUtil.sleep(5000);
} }
@ -136,7 +172,7 @@ public class MjWebSocketStarter implements WebSocketStarter {
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId()); log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
} }
public void tryStart(boolean reconnect) throws Exception { public void tryStart(boolean reconnect) {
start(reconnect); start(reconnect);
} }
@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter {
System.err.println("notifyWssLock: " + code + " - " + reason); System.err.println("notifyWssLock: " + code + " - " + reason);
} }
/**
* 关闭 socket session
*/
private void closeSocketSessionWhenIsOpen() { private void closeSocketSessionWhenIsOpen() {
try { try {
if (this.webSocketSession != null && this.webSocketSession.isOpen()) { if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter {
return this.wssServer; return this.wssServer;
} }
public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) { @Getter
public static class ResumeData {
public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
this.sessionId = sessionId;
this.sequence = sequence;
this.resumeGatewayUrl = resumeGatewayUrl;
}
/**
* socket session
*/
private final String sessionId;
private final Object sequence;
private final String resumeGatewayUrl;
} }
} }

View File

@ -30,16 +30,41 @@ import java.util.concurrent.TimeUnit;
@Slf4j @Slf4j
public class MjWebSocketHandler implements WebSocketHandler { public class MjWebSocketHandler implements WebSocketHandler {
/**
* close 错误码重连
*/
public static final int CLOSE_CODE_RECONNECT = 2001; public static final int CLOSE_CODE_RECONNECT = 2001;
/**
* close 错误码无效作废
*/
public static final int CLOSE_CODE_INVALIDATE = 1009; public static final int CLOSE_CODE_INVALIDATE = 1009;
/**
* close 错误码异常
*/
public static final int CLOSE_CODE_EXCEPTION = 1011; public static final int CLOSE_CODE_EXCEPTION = 1011;
/**
* mj配置文件
*/
private final MidjourneyConfig midjourneyConfig; private final MidjourneyConfig midjourneyConfig;
/**
* mj 消息监听
*/
private final MjMessageListener userMessageListener; private final MjMessageListener userMessageListener;
/**
* 成功回调
*/
private final SuccessCallback successCallback; private final SuccessCallback successCallback;
/**
* 失败回调
*/
private final FailureCallback failureCallback; private final FailureCallback failureCallback;
/**
* 心跳执行器
*/
private final ScheduledExecutorService heartExecutor; private final ScheduledExecutorService heartExecutor;
/**
* auth数据
*/
private final DataObject authData; private final DataObject authData;
@Setter @Setter
@ -55,6 +80,9 @@ public class MjWebSocketHandler implements WebSocketHandler {
private Future<?> heartbeatInterval; private Future<?> heartbeatInterval;
private Future<?> heartbeatTimeout; private Future<?> heartbeatTimeout;
/**
* 处理 message 消息的 Decompressor
*/
private final Decompressor decompressor = new ZlibDecompressor(2048); private final Decompressor decompressor = new ZlibDecompressor(2048);
public MjWebSocketHandler(MidjourneyConfig account, public MjWebSocketHandler(MidjourneyConfig account,
@ -77,11 +105,13 @@ public class MjWebSocketHandler implements WebSocketHandler {
@Override @Override
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception { public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e); log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
// 通知链接异常
onFailure(CLOSE_CODE_EXCEPTION, "transport error"); onFailure(CLOSE_CODE_EXCEPTION, "transport error");
} }
@Override @Override
public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception { public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
// 链接关闭
onFailure(closeStatus.getCode(), closeStatus.getReason()); onFailure(closeStatus.getCode(), closeStatus.getReason());
} }
@ -92,13 +122,18 @@ public class MjWebSocketHandler implements WebSocketHandler {
@Override @Override
public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception { public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
// 获取 message 消息
ByteBuffer buffer = (ByteBuffer) message.getPayload(); ByteBuffer buffer = (ByteBuffer) message.getPayload();
// 解析 message
byte[] decompressed = decompressor.decompress(buffer.array()); byte[] decompressed = decompressor.decompress(buffer.array());
if (decompressed == null) { if (decompressed == null) {
return; return;
} }
// 转换 json
String json = new String(decompressed, StandardCharsets.UTF_8); String json = new String(decompressed, StandardCharsets.UTF_8);
// 转换 jda 自带的 dataObject(和json object 差不多)
DataObject data = DataObject.fromJson(json); DataObject data = DataObject.fromJson(json);
// 获取消息类型
int opCode = data.getInt("op"); int opCode = data.getInt("op");
switch (opCode) { switch (opCode) {
case WebSocketCode.HEARTBEAT -> handleHeartbeat(session); case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);