mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-23 07:41:53 +08:00
增加注释
This commit is contained in:
parent
5044a58118
commit
44e44dc4bb
@ -8,35 +8,7 @@ public final class MjNotifyCode {
|
||||
* 成功.
|
||||
*/
|
||||
public static final int SUCCESS = 1;
|
||||
/**
|
||||
* 数据未找到.
|
||||
*/
|
||||
public static final int NOT_FOUND = 3;
|
||||
/**
|
||||
* 校验错误.
|
||||
*/
|
||||
public static final int VALIDATION_ERROR = 4;
|
||||
/**
|
||||
* 系统异常.
|
||||
*/
|
||||
public static final int FAILURE = 9;
|
||||
|
||||
/**
|
||||
* 已存在.
|
||||
*/
|
||||
public static final int EXISTED = 21;
|
||||
/**
|
||||
* 排队中.
|
||||
*/
|
||||
public static final int IN_QUEUE = 22;
|
||||
/**
|
||||
* 队列已满.
|
||||
*/
|
||||
public static final int QUEUE_REJECTED = 23;
|
||||
/**
|
||||
* prompt包含敏感词.
|
||||
*/
|
||||
public static final int BANNED_PROMPT = 24;
|
||||
|
||||
|
||||
}
|
@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.tomcat.websocket.Constants;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.util.concurrent.ListenableFutureCallback;
|
||||
import org.springframework.web.socket.CloseStatus;
|
||||
import org.springframework.web.socket.WebSocketHttpHeaders;
|
||||
@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException;
|
||||
|
||||
@Slf4j
|
||||
public class MjWebSocketStarter implements WebSocketStarter {
|
||||
/**
|
||||
* 链接重试次数
|
||||
*/
|
||||
private static final int CONNECT_RETRY_LIMIT = 5;
|
||||
|
||||
/**
|
||||
* mj 配置文件
|
||||
*/
|
||||
private final MidjourneyConfig midjourneyConfig;
|
||||
/**
|
||||
* mj 监听(所有message 都会 callback到这里)
|
||||
*/
|
||||
private final MjMessageListener userMessageListener;
|
||||
/**
|
||||
* wss 服务器
|
||||
*/
|
||||
private final String wssServer;
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private final String resumeWss;
|
||||
|
||||
private boolean running = false;
|
||||
|
||||
private WebSocketSession webSocketSession = null;
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private ResumeData resumeData = null;
|
||||
/**
|
||||
* 是否运行成功
|
||||
*/
|
||||
private boolean running = false;
|
||||
/**
|
||||
* 链接成功的 session
|
||||
*/
|
||||
private WebSocketSession webSocketSession = null;
|
||||
|
||||
public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) {
|
||||
public MjWebSocketStarter(String wssServer,
|
||||
String resumeWss,
|
||||
MidjourneyConfig midjourneyConfig,
|
||||
MjMessageListener userMessageListener) {
|
||||
this.wssServer = wssServer;
|
||||
this.resumeWss = resumeWss;
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws Exception {
|
||||
public void start() {
|
||||
start(false);
|
||||
}
|
||||
|
||||
private void start(boolean reconnect) {
|
||||
// 设置header
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||
headers.add("Accept-Encoding", "gzip, deflate, br");
|
||||
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
|
||||
@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
||||
headers.add("Pragma", "no-cache");
|
||||
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
|
||||
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
|
||||
var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
|
||||
// 创建 mjHeader
|
||||
MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler(
|
||||
this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
|
||||
//
|
||||
String gatewayUrl;
|
||||
if (reconnect) {
|
||||
gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
|
||||
handler.setSessionId(this.resumeData.sessionId());
|
||||
handler.setSequence(this.resumeData.sequence());
|
||||
handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl());
|
||||
gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
|
||||
mjWebSocketHandler.setSessionId(this.resumeData.getSessionId());
|
||||
mjWebSocketHandler.setSequence(this.resumeData.getSequence());
|
||||
mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl());
|
||||
} else {
|
||||
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
|
||||
}
|
||||
var webSocketClient = new StandardWebSocketClient();
|
||||
// 创建 StandardWebSocketClient
|
||||
StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
|
||||
// 设置 io timeout 时间
|
||||
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
|
||||
var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl));
|
||||
//
|
||||
ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl));
|
||||
// 添加 callback 进行回调
|
||||
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
|
||||
@Override
|
||||
public void onFailure(@NotNull Throwable e) {
|
||||
@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
||||
}
|
||||
|
||||
private void onSocketFailure(int code, String reason) {
|
||||
// 1001异常可以忽略
|
||||
if (code == 1001) {
|
||||
return;
|
||||
}
|
||||
// 关闭 socket
|
||||
closeSocketSessionWhenIsOpen();
|
||||
// 没有运行通知
|
||||
if (!this.running) {
|
||||
notifyWssLock(code, reason);
|
||||
return;
|
||||
}
|
||||
// 已经运行先设置为false,发起
|
||||
this.running = false;
|
||||
if (code >= 4000) {
|
||||
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
|
||||
@ -107,13 +145,13 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 重连
|
||||
*/
|
||||
private void tryReconnect() {
|
||||
try {
|
||||
tryStart(true);
|
||||
} catch (Exception e) {
|
||||
if (e instanceof TimeoutException) {
|
||||
closeSocketSessionWhenIsOpen();
|
||||
}
|
||||
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
|
||||
ThreadUtil.sleep(1000);
|
||||
tryNewConnect();
|
||||
@ -121,14 +159,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
||||
}
|
||||
|
||||
private void tryNewConnect() {
|
||||
// 链接重试次数5
|
||||
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
|
||||
try {
|
||||
tryStart(false);
|
||||
return;
|
||||
} catch (Exception e) {
|
||||
if (e instanceof TimeoutException) {
|
||||
closeSocketSessionWhenIsOpen();
|
||||
}
|
||||
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
|
||||
ThreadUtil.sleep(5000);
|
||||
}
|
||||
@ -136,7 +172,7 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
||||
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
|
||||
}
|
||||
|
||||
public void tryStart(boolean reconnect) throws Exception {
|
||||
public void tryStart(boolean reconnect) {
|
||||
start(reconnect);
|
||||
}
|
||||
|
||||
@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
||||
System.err.println("notifyWssLock: " + code + " - " + reason);
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭 socket session
|
||||
*/
|
||||
private void closeSocketSessionWhenIsOpen() {
|
||||
try {
|
||||
if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
|
||||
@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
||||
return this.wssServer;
|
||||
}
|
||||
|
||||
public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
|
||||
@Getter
|
||||
public static class ResumeData {
|
||||
|
||||
public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
|
||||
this.sessionId = sessionId;
|
||||
this.sequence = sequence;
|
||||
this.resumeGatewayUrl = resumeGatewayUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* socket session
|
||||
*/
|
||||
private final String sessionId;
|
||||
private final Object sequence;
|
||||
private final String resumeGatewayUrl;
|
||||
}
|
||||
}
|
@ -30,16 +30,41 @@ import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
public class MjWebSocketHandler implements WebSocketHandler {
|
||||
/**
|
||||
* close 错误码:重连
|
||||
*/
|
||||
public static final int CLOSE_CODE_RECONNECT = 2001;
|
||||
/**
|
||||
* close 错误码:无效、作废
|
||||
*/
|
||||
public static final int CLOSE_CODE_INVALIDATE = 1009;
|
||||
/**
|
||||
* close 错误码:异常
|
||||
*/
|
||||
public static final int CLOSE_CODE_EXCEPTION = 1011;
|
||||
|
||||
/**
|
||||
* mj配置文件
|
||||
*/
|
||||
private final MidjourneyConfig midjourneyConfig;
|
||||
/**
|
||||
* mj 消息监听
|
||||
*/
|
||||
private final MjMessageListener userMessageListener;
|
||||
/**
|
||||
* 成功回调
|
||||
*/
|
||||
private final SuccessCallback successCallback;
|
||||
/**
|
||||
* 失败回调
|
||||
*/
|
||||
private final FailureCallback failureCallback;
|
||||
|
||||
/**
|
||||
* 心跳执行器
|
||||
*/
|
||||
private final ScheduledExecutorService heartExecutor;
|
||||
/**
|
||||
* auth数据
|
||||
*/
|
||||
private final DataObject authData;
|
||||
|
||||
@Setter
|
||||
@ -55,6 +80,9 @@ public class MjWebSocketHandler implements WebSocketHandler {
|
||||
private Future<?> heartbeatInterval;
|
||||
private Future<?> heartbeatTimeout;
|
||||
|
||||
/**
|
||||
* 处理 message 消息的 Decompressor
|
||||
*/
|
||||
private final Decompressor decompressor = new ZlibDecompressor(2048);
|
||||
|
||||
public MjWebSocketHandler(MidjourneyConfig account,
|
||||
@ -77,11 +105,13 @@ public class MjWebSocketHandler implements WebSocketHandler {
|
||||
@Override
|
||||
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
|
||||
log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
|
||||
// 通知链接异常
|
||||
onFailure(CLOSE_CODE_EXCEPTION, "transport error");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
|
||||
// 链接关闭
|
||||
onFailure(closeStatus.getCode(), closeStatus.getReason());
|
||||
}
|
||||
|
||||
@ -92,13 +122,18 @@ public class MjWebSocketHandler implements WebSocketHandler {
|
||||
|
||||
@Override
|
||||
public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
|
||||
// 获取 message 消息
|
||||
ByteBuffer buffer = (ByteBuffer) message.getPayload();
|
||||
// 解析 message
|
||||
byte[] decompressed = decompressor.decompress(buffer.array());
|
||||
if (decompressed == null) {
|
||||
return;
|
||||
}
|
||||
// 转换 json
|
||||
String json = new String(decompressed, StandardCharsets.UTF_8);
|
||||
// 转换 jda 自带的 dataObject(和json object 差不多)
|
||||
DataObject data = DataObject.fromJson(json);
|
||||
// 获取消息类型
|
||||
int opCode = data.getInt("op");
|
||||
switch (opCode) {
|
||||
case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);
|
||||
|
Loading…
Reference in New Issue
Block a user