增加注释

This commit is contained in:
cherishsince 2024-04-05 10:15:59 +08:00
parent 5044a58118
commit 44e44dc4bb
3 changed files with 114 additions and 54 deletions

View File

@ -8,35 +8,7 @@ public final class MjNotifyCode {
* 成功.
*/
public static final int SUCCESS = 1;
/**
* 数据未找到.
*/
public static final int NOT_FOUND = 3;
/**
* 校验错误.
*/
public static final int VALIDATION_ERROR = 4;
/**
* 系统异常.
*/
public static final int FAILURE = 9;
/**
* 已存在.
*/
public static final int EXISTED = 21;
/**
* 排队中.
*/
public static final int IN_QUEUE = 22;
/**
* 队列已满.
*/
public static final int QUEUE_REJECTED = 23;
/**
* prompt包含敏感词.
*/
public static final int BANNED_PROMPT = 24;
}

View File

@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.websocket.Constants;
import org.jetbrains.annotations.NotNull;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHttpHeaders;
@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException;
@Slf4j
public class MjWebSocketStarter implements WebSocketStarter {
/**
* 链接重试次数
*/
private static final int CONNECT_RETRY_LIMIT = 5;
/**
* mj 配置文件
*/
private final MidjourneyConfig midjourneyConfig;
/**
* mj 监听(所有message 都会 callback到这里)
*/
private final MjMessageListener userMessageListener;
/**
* wss 服务器
*/
private final String wssServer;
/**
*
*/
private final String resumeWss;
private boolean running = false;
private WebSocketSession webSocketSession = null;
/**
*
*/
private ResumeData resumeData = null;
/**
* 是否运行成功
*/
private boolean running = false;
/**
* 链接成功的 session
*/
private WebSocketSession webSocketSession = null;
public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) {
public MjWebSocketStarter(String wssServer,
String resumeWss,
MidjourneyConfig midjourneyConfig,
MjMessageListener userMessageListener) {
this.wssServer = wssServer;
this.resumeWss = resumeWss;
this.midjourneyConfig = midjourneyConfig;
@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
}
@Override
public void start() throws Exception {
public void start() {
start(false);
}
private void start(boolean reconnect) {
// 设置header
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("Accept-Encoding", "gzip, deflate, br");
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter {
headers.add("Pragma", "no-cache");
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
// 创建 mjHeader
MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler(
this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
//
String gatewayUrl;
if (reconnect) {
gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
handler.setSessionId(this.resumeData.sessionId());
handler.setSequence(this.resumeData.sequence());
handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl());
gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
mjWebSocketHandler.setSessionId(this.resumeData.getSessionId());
mjWebSocketHandler.setSequence(this.resumeData.getSequence());
mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl());
} else {
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
}
var webSocketClient = new StandardWebSocketClient();
// 创建 StandardWebSocketClient
StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
// 设置 io timeout 时间
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl));
//
ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl));
// 添加 callback 进行回调
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
@Override
public void onFailure(@NotNull Throwable e) {
@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter {
}
private void onSocketFailure(int code, String reason) {
// 1001异常可以忽略
if (code == 1001) {
return;
}
// 关闭 socket
closeSocketSessionWhenIsOpen();
// 没有运行通知
if (!this.running) {
notifyWssLock(code, reason);
return;
}
// 已经运行先设置为false发起
this.running = false;
if (code >= 4000) {
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
@ -107,36 +145,34 @@ public class MjWebSocketStarter implements WebSocketStarter {
}
}
/**
* 重连
*/
private void tryReconnect() {
try {
tryStart(true);
} catch (Exception e) {
if (e instanceof TimeoutException) {
closeSocketSessionWhenIsOpen();
}
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
ThreadUtil.sleep(1000);
tryNewConnect();
}
}
private void tryNewConnect() {
// 链接重试次数5
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
try {
tryStart(false);
return;
} catch (Exception e) {
if (e instanceof TimeoutException) {
closeSocketSessionWhenIsOpen();
}
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
ThreadUtil.sleep(5000);
}
}
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
}
public void tryStart(boolean reconnect) throws Exception {
public void tryStart(boolean reconnect) {
start(reconnect);
}
@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter {
System.err.println("notifyWssLock: " + code + " - " + reason);
}
/**
* 关闭 socket session
*/
private void closeSocketSessionWhenIsOpen() {
try {
if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter {
return this.wssServer;
}
public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
@Getter
public static class ResumeData {
public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
this.sessionId = sessionId;
this.sequence = sequence;
this.resumeGatewayUrl = resumeGatewayUrl;
}
/**
* socket session
*/
private final String sessionId;
private final Object sequence;
private final String resumeGatewayUrl;
}
}

View File

@ -30,16 +30,41 @@ import java.util.concurrent.TimeUnit;
@Slf4j
public class MjWebSocketHandler implements WebSocketHandler {
/**
* close 错误码重连
*/
public static final int CLOSE_CODE_RECONNECT = 2001;
/**
* close 错误码无效作废
*/
public static final int CLOSE_CODE_INVALIDATE = 1009;
/**
* close 错误码异常
*/
public static final int CLOSE_CODE_EXCEPTION = 1011;
/**
* mj配置文件
*/
private final MidjourneyConfig midjourneyConfig;
/**
* mj 消息监听
*/
private final MjMessageListener userMessageListener;
/**
* 成功回调
*/
private final SuccessCallback successCallback;
/**
* 失败回调
*/
private final FailureCallback failureCallback;
/**
* 心跳执行器
*/
private final ScheduledExecutorService heartExecutor;
/**
* auth数据
*/
private final DataObject authData;
@Setter
@ -55,6 +80,9 @@ public class MjWebSocketHandler implements WebSocketHandler {
private Future<?> heartbeatInterval;
private Future<?> heartbeatTimeout;
/**
* 处理 message 消息的 Decompressor
*/
private final Decompressor decompressor = new ZlibDecompressor(2048);
public MjWebSocketHandler(MidjourneyConfig account,
@ -77,11 +105,13 @@ public class MjWebSocketHandler implements WebSocketHandler {
@Override
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
// 通知链接异常
onFailure(CLOSE_CODE_EXCEPTION, "transport error");
}
@Override
public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
// 链接关闭
onFailure(closeStatus.getCode(), closeStatus.getReason());
}
@ -92,13 +122,18 @@ public class MjWebSocketHandler implements WebSocketHandler {
@Override
public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
// 获取 message 消息
ByteBuffer buffer = (ByteBuffer) message.getPayload();
// 解析 message
byte[] decompressed = decompressor.decompress(buffer.array());
if (decompressed == null) {
return;
}
// 转换 json
String json = new String(decompressed, StandardCharsets.UTF_8);
// 转换 jda 自带的 dataObject(和json object 差不多)
DataObject data = DataObject.fromJson(json);
// 获取消息类型
int opCode = data.getInt("op");
switch (opCode) {
case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);