mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-23 15:51:52 +08:00
增加注释
This commit is contained in:
parent
5044a58118
commit
44e44dc4bb
@ -8,35 +8,7 @@ public final class MjNotifyCode {
|
|||||||
* 成功.
|
* 成功.
|
||||||
*/
|
*/
|
||||||
public static final int SUCCESS = 1;
|
public static final int SUCCESS = 1;
|
||||||
/**
|
|
||||||
* 数据未找到.
|
|
||||||
*/
|
|
||||||
public static final int NOT_FOUND = 3;
|
|
||||||
/**
|
|
||||||
* 校验错误.
|
|
||||||
*/
|
|
||||||
public static final int VALIDATION_ERROR = 4;
|
|
||||||
/**
|
|
||||||
* 系统异常.
|
|
||||||
*/
|
|
||||||
public static final int FAILURE = 9;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 已存在.
|
|
||||||
*/
|
|
||||||
public static final int EXISTED = 21;
|
|
||||||
/**
|
|
||||||
* 排队中.
|
|
||||||
*/
|
|
||||||
public static final int IN_QUEUE = 22;
|
|
||||||
/**
|
|
||||||
* 队列已满.
|
|
||||||
*/
|
|
||||||
public static final int QUEUE_REJECTED = 23;
|
|
||||||
/**
|
|
||||||
* prompt包含敏感词.
|
|
||||||
*/
|
|
||||||
public static final int BANNED_PROMPT = 24;
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
|||||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
|
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
|
||||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
|
||||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.tomcat.websocket.Constants;
|
import org.apache.tomcat.websocket.Constants;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.springframework.util.concurrent.ListenableFuture;
|
||||||
import org.springframework.util.concurrent.ListenableFutureCallback;
|
import org.springframework.util.concurrent.ListenableFutureCallback;
|
||||||
import org.springframework.web.socket.CloseStatus;
|
import org.springframework.web.socket.CloseStatus;
|
||||||
import org.springframework.web.socket.WebSocketHttpHeaders;
|
import org.springframework.web.socket.WebSocketHttpHeaders;
|
||||||
@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException;
|
|||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MjWebSocketStarter implements WebSocketStarter {
|
public class MjWebSocketStarter implements WebSocketStarter {
|
||||||
|
/**
|
||||||
|
* 链接重试次数
|
||||||
|
*/
|
||||||
private static final int CONNECT_RETRY_LIMIT = 5;
|
private static final int CONNECT_RETRY_LIMIT = 5;
|
||||||
|
/**
|
||||||
|
* mj 配置文件
|
||||||
|
*/
|
||||||
private final MidjourneyConfig midjourneyConfig;
|
private final MidjourneyConfig midjourneyConfig;
|
||||||
|
/**
|
||||||
|
* mj 监听(所有message 都会 callback到这里)
|
||||||
|
*/
|
||||||
private final MjMessageListener userMessageListener;
|
private final MjMessageListener userMessageListener;
|
||||||
|
/**
|
||||||
|
* wss 服务器
|
||||||
|
*/
|
||||||
private final String wssServer;
|
private final String wssServer;
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
private final String resumeWss;
|
private final String resumeWss;
|
||||||
|
/**
|
||||||
private boolean running = false;
|
*
|
||||||
|
*/
|
||||||
private WebSocketSession webSocketSession = null;
|
|
||||||
private ResumeData resumeData = null;
|
private ResumeData resumeData = null;
|
||||||
|
/**
|
||||||
|
* 是否运行成功
|
||||||
|
*/
|
||||||
|
private boolean running = false;
|
||||||
|
/**
|
||||||
|
* 链接成功的 session
|
||||||
|
*/
|
||||||
|
private WebSocketSession webSocketSession = null;
|
||||||
|
|
||||||
public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) {
|
public MjWebSocketStarter(String wssServer,
|
||||||
|
String resumeWss,
|
||||||
|
MidjourneyConfig midjourneyConfig,
|
||||||
|
MjMessageListener userMessageListener) {
|
||||||
this.wssServer = wssServer;
|
this.wssServer = wssServer;
|
||||||
this.resumeWss = resumeWss;
|
this.resumeWss = resumeWss;
|
||||||
this.midjourneyConfig = midjourneyConfig;
|
this.midjourneyConfig = midjourneyConfig;
|
||||||
@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void start() throws Exception {
|
public void start() {
|
||||||
start(false);
|
start(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void start(boolean reconnect) {
|
private void start(boolean reconnect) {
|
||||||
|
// 设置header
|
||||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||||
headers.add("Accept-Encoding", "gzip, deflate, br");
|
headers.add("Accept-Encoding", "gzip, deflate, br");
|
||||||
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
|
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
|
||||||
@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||||||
headers.add("Pragma", "no-cache");
|
headers.add("Pragma", "no-cache");
|
||||||
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
|
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
|
||||||
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
|
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
|
||||||
var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
|
// 创建 mjHeader
|
||||||
|
MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler(
|
||||||
|
this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
|
||||||
|
//
|
||||||
String gatewayUrl;
|
String gatewayUrl;
|
||||||
if (reconnect) {
|
if (reconnect) {
|
||||||
gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
|
gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
|
||||||
handler.setSessionId(this.resumeData.sessionId());
|
mjWebSocketHandler.setSessionId(this.resumeData.getSessionId());
|
||||||
handler.setSequence(this.resumeData.sequence());
|
mjWebSocketHandler.setSequence(this.resumeData.getSequence());
|
||||||
handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl());
|
mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl());
|
||||||
} else {
|
} else {
|
||||||
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
|
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
|
||||||
}
|
}
|
||||||
var webSocketClient = new StandardWebSocketClient();
|
// 创建 StandardWebSocketClient
|
||||||
|
StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
|
||||||
|
// 设置 io timeout 时间
|
||||||
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
|
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
|
||||||
var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl));
|
//
|
||||||
|
ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl));
|
||||||
|
// 添加 callback 进行回调
|
||||||
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
|
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
|
||||||
@Override
|
@Override
|
||||||
public void onFailure(@NotNull Throwable e) {
|
public void onFailure(@NotNull Throwable e) {
|
||||||
@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void onSocketFailure(int code, String reason) {
|
private void onSocketFailure(int code, String reason) {
|
||||||
|
// 1001异常可以忽略
|
||||||
if (code == 1001) {
|
if (code == 1001) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// 关闭 socket
|
||||||
closeSocketSessionWhenIsOpen();
|
closeSocketSessionWhenIsOpen();
|
||||||
|
// 没有运行通知
|
||||||
if (!this.running) {
|
if (!this.running) {
|
||||||
notifyWssLock(code, reason);
|
notifyWssLock(code, reason);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// 已经运行先设置为false,发起
|
||||||
this.running = false;
|
this.running = false;
|
||||||
if (code >= 4000) {
|
if (code >= 4000) {
|
||||||
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
|
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
|
||||||
@ -107,36 +145,34 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重连
|
||||||
|
*/
|
||||||
private void tryReconnect() {
|
private void tryReconnect() {
|
||||||
try {
|
try {
|
||||||
tryStart(true);
|
tryStart(true);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
if (e instanceof TimeoutException) {
|
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
|
||||||
closeSocketSessionWhenIsOpen();
|
|
||||||
}
|
|
||||||
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
|
|
||||||
ThreadUtil.sleep(1000);
|
ThreadUtil.sleep(1000);
|
||||||
tryNewConnect();
|
tryNewConnect();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void tryNewConnect() {
|
private void tryNewConnect() {
|
||||||
|
// 链接重试次数5
|
||||||
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
|
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
|
||||||
try {
|
try {
|
||||||
tryStart(false);
|
tryStart(false);
|
||||||
return;
|
return;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
if (e instanceof TimeoutException) {
|
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
|
||||||
closeSocketSessionWhenIsOpen();
|
|
||||||
}
|
|
||||||
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
|
|
||||||
ThreadUtil.sleep(5000);
|
ThreadUtil.sleep(5000);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
|
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void tryStart(boolean reconnect) throws Exception {
|
public void tryStart(boolean reconnect) {
|
||||||
start(reconnect);
|
start(reconnect);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||||||
System.err.println("notifyWssLock: " + code + " - " + reason);
|
System.err.println("notifyWssLock: " + code + " - " + reason);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 关闭 socket session
|
||||||
|
*/
|
||||||
private void closeSocketSessionWhenIsOpen() {
|
private void closeSocketSessionWhenIsOpen() {
|
||||||
try {
|
try {
|
||||||
if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
|
if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
|
||||||
@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||||||
return this.wssServer;
|
return this.wssServer;
|
||||||
}
|
}
|
||||||
|
|
||||||
public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
|
@Getter
|
||||||
|
public static class ResumeData {
|
||||||
|
|
||||||
|
public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
|
||||||
|
this.sessionId = sessionId;
|
||||||
|
this.sequence = sequence;
|
||||||
|
this.resumeGatewayUrl = resumeGatewayUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* socket session
|
||||||
|
*/
|
||||||
|
private final String sessionId;
|
||||||
|
private final Object sequence;
|
||||||
|
private final String resumeGatewayUrl;
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -30,16 +30,41 @@ import java.util.concurrent.TimeUnit;
|
|||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MjWebSocketHandler implements WebSocketHandler {
|
public class MjWebSocketHandler implements WebSocketHandler {
|
||||||
|
/**
|
||||||
|
* close 错误码:重连
|
||||||
|
*/
|
||||||
public static final int CLOSE_CODE_RECONNECT = 2001;
|
public static final int CLOSE_CODE_RECONNECT = 2001;
|
||||||
|
/**
|
||||||
|
* close 错误码:无效、作废
|
||||||
|
*/
|
||||||
public static final int CLOSE_CODE_INVALIDATE = 1009;
|
public static final int CLOSE_CODE_INVALIDATE = 1009;
|
||||||
|
/**
|
||||||
|
* close 错误码:异常
|
||||||
|
*/
|
||||||
public static final int CLOSE_CODE_EXCEPTION = 1011;
|
public static final int CLOSE_CODE_EXCEPTION = 1011;
|
||||||
|
/**
|
||||||
|
* mj配置文件
|
||||||
|
*/
|
||||||
private final MidjourneyConfig midjourneyConfig;
|
private final MidjourneyConfig midjourneyConfig;
|
||||||
|
/**
|
||||||
|
* mj 消息监听
|
||||||
|
*/
|
||||||
private final MjMessageListener userMessageListener;
|
private final MjMessageListener userMessageListener;
|
||||||
|
/**
|
||||||
|
* 成功回调
|
||||||
|
*/
|
||||||
private final SuccessCallback successCallback;
|
private final SuccessCallback successCallback;
|
||||||
|
/**
|
||||||
|
* 失败回调
|
||||||
|
*/
|
||||||
private final FailureCallback failureCallback;
|
private final FailureCallback failureCallback;
|
||||||
|
/**
|
||||||
|
* 心跳执行器
|
||||||
|
*/
|
||||||
private final ScheduledExecutorService heartExecutor;
|
private final ScheduledExecutorService heartExecutor;
|
||||||
|
/**
|
||||||
|
* auth数据
|
||||||
|
*/
|
||||||
private final DataObject authData;
|
private final DataObject authData;
|
||||||
|
|
||||||
@Setter
|
@Setter
|
||||||
@ -55,6 +80,9 @@ public class MjWebSocketHandler implements WebSocketHandler {
|
|||||||
private Future<?> heartbeatInterval;
|
private Future<?> heartbeatInterval;
|
||||||
private Future<?> heartbeatTimeout;
|
private Future<?> heartbeatTimeout;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理 message 消息的 Decompressor
|
||||||
|
*/
|
||||||
private final Decompressor decompressor = new ZlibDecompressor(2048);
|
private final Decompressor decompressor = new ZlibDecompressor(2048);
|
||||||
|
|
||||||
public MjWebSocketHandler(MidjourneyConfig account,
|
public MjWebSocketHandler(MidjourneyConfig account,
|
||||||
@ -77,11 +105,13 @@ public class MjWebSocketHandler implements WebSocketHandler {
|
|||||||
@Override
|
@Override
|
||||||
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
|
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
|
||||||
log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
|
log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
|
||||||
|
// 通知链接异常
|
||||||
onFailure(CLOSE_CODE_EXCEPTION, "transport error");
|
onFailure(CLOSE_CODE_EXCEPTION, "transport error");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
|
public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
|
||||||
|
// 链接关闭
|
||||||
onFailure(closeStatus.getCode(), closeStatus.getReason());
|
onFailure(closeStatus.getCode(), closeStatus.getReason());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,13 +122,18 @@ public class MjWebSocketHandler implements WebSocketHandler {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
|
public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
|
||||||
|
// 获取 message 消息
|
||||||
ByteBuffer buffer = (ByteBuffer) message.getPayload();
|
ByteBuffer buffer = (ByteBuffer) message.getPayload();
|
||||||
|
// 解析 message
|
||||||
byte[] decompressed = decompressor.decompress(buffer.array());
|
byte[] decompressed = decompressor.decompress(buffer.array());
|
||||||
if (decompressed == null) {
|
if (decompressed == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// 转换 json
|
||||||
String json = new String(decompressed, StandardCharsets.UTF_8);
|
String json = new String(decompressed, StandardCharsets.UTF_8);
|
||||||
|
// 转换 jda 自带的 dataObject(和json object 差不多)
|
||||||
DataObject data = DataObject.fromJson(json);
|
DataObject data = DataObject.fromJson(json);
|
||||||
|
// 获取消息类型
|
||||||
int opCode = data.getInt("op");
|
int opCode = data.getInt("op");
|
||||||
switch (opCode) {
|
switch (opCode) {
|
||||||
case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);
|
case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);
|
||||||
|
Loading…
Reference in New Issue
Block a user