😄模拟发送、接收discard消息成功!

This commit is contained in:
cherishsince 2024-04-03 16:54:49 +08:00
parent bd8e6c2b40
commit 84825579b6
23 changed files with 1041 additions and 30 deletions

View File

@ -1,24 +0,0 @@
package cn.iocoder.yudao.framework.ai.Midjourney;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.web.client.RestTemplate;
public class MjHttpExecute implements MjExecute {
private static final String URL = "https://discord.com/";
private RestTemplate restTemplate = new RestTemplate();
@Override
public boolean execute(MjCommandEnum mjCommand, String prompt) {
// 发送的 uri
String uri = "api/v9/interactions";
// restTemplate 发送post请求
String result = restTemplate.postForObject(URL + uri, prompt, String.class);
// 加载当前目录下文件
return false;
}
}

View File

@ -0,0 +1,10 @@
package cn.iocoder.yudao.framework.ai.midjourney;
/**
* author: fansili
* time: 2024/4/3 15:54
*/
public class DiscordJadMain {
}

View File

@ -0,0 +1,141 @@
package cn.iocoder.yudao.framework.ai.midjourney;
import cn.hutool.http.useragent.UserAgent;
import cn.hutool.http.useragent.UserAgentUtil;
import cn.hutool.json.JSONObject;
import com.alibaba.fastjson.JSON;
import org.springframework.web.socket.*;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
/**
* https://blog.csdn.net/qq_38490457/article/details/125250135
*/
public class DiscordWebSocketClient {
private static final String DISCORD_GATEWAY_URL = "wss://gateway.discord.gg/?v=9&encoding=json";
public static void main(String[] args) throws InterruptedException, ExecutionException, IOException, URISyntaxException {
StandardWebSocketClient client = new StandardWebSocketClient();
DiscordWebSocketHandler handler = new DiscordWebSocketHandler();
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("Accept-Encoding", "gzip, deflate, br");
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
headers.add("Cache-Control", "no-cache");
headers.add("Pragma", "no-cache");
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
headers.add("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36");
Future<WebSocketSession> futureSession = client.doHandshake(handler, headers, new URI(DISCORD_GATEWAY_URL));
WebSocketSession session = futureSession.get(); // 这会阻塞直到连接建立
// 登录过程你需要替换 TOKEN 为你的 Discord Bot Token
// String token = "YOUR_DISCORD_BOT_TOKEN"; // 请替换为你的 Bot Token
// String identifyPayload = "{\"op\":2,\"d\":{\"token\":\"" + token + "\",\"properties\":{\"$os\":\"java\",\"$browser\":\"spring-websocket\",\"$device\":\"spring-websocket\"},\"compress\":false,\"large_threshold\":256,\"shard\":[0,1]}}";
// session.sendMessage(new TextMessage(identifyPayload));
// 发送心跳以保持连接活跃
Thread heartbeatThread = new Thread(() -> {
int interval = 0; // 初始心跳间隔后续从 Discord 服务器获取
while (!Thread.currentThread().isInterrupted()) {
try {
Thread.sleep(interval * 1000); // 等待指定的心跳间隔
session.sendMessage(new TextMessage("{\"op\":1,\"d\":null}")); // 发送心跳包
} catch (Exception e) {
e.printStackTrace();
break;
}
}
});
heartbeatThread.start();
// 等待用户输入来保持程序运行仅用于示例
System.in.read();
// 关闭连接和线程
session.close();
heartbeatThread.interrupt();
}
private static class DiscordWebSocketHandler implements WebSocketHandler {
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
Object payload = message.getPayload();
session.sendMessage(new TextMessage(JSON.toJSONString(createAuthData())));
String a= "";
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
}
@Override
public boolean supportsPartialMessages() {
return false;
}
private JSONObject createAuthData() {
String userAgentStr = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36";
UserAgent userAgent = UserAgentUtil.parse(userAgentStr);
String token = "NTY5MDY4NDAxNzEyOTU1Mzky.G4-Fu0.MzD-7ll-ElbXTTgDPHF-WS_UyhMAfbKN3WyyBc";
JSONObject connectionProperties = new JSONObject()
.put("browser", userAgent.getBrowser().getName())
.put("browser_user_agent", userAgentStr)
.put("browser_version", userAgent.getVersion())
.put("client_build_number", 222963)
.put("client_event_source", null)
.put("device", "")
.put("os", userAgent.getOs().getName())
.put("referer", "https://www.midjourney.com")
.put("referrer_current", "")
.put("referring_domain", "www.midjourney.com")
.put("referring_domain_current", "")
.put("release_channel", "stable")
.put("system_locale", "zh-CN");
JSONObject presence = new JSONObject()
.put("activities", "")
.put("afk", false)
.put("since", 0)
.put("status", "online");
JSONObject clientState = new JSONObject()
.put("api_code_version", 0)
.put("guild_versions", "")
.put("highest_last_message_id", "0")
.put("private_channels_version", "0")
.put("read_state_version", 0)
.put("user_guild_settings_version", -1)
.put("user_settings_version", -1);
return new JSONObject()
.put("capabilities", 16381)
.put("client_state", clientState)
.put("compress", false)
.put("presence", presence)
.put("properties", connectionProperties)
.put("token", token);
}
}
}

View File

@ -1,10 +1,12 @@
package cn.iocoder.yudao.framework.ai.Midjourney;
package cn.iocoder.yudao.framework.ai.midjourney;
/**
* 文档: https://www.xiubbs.com/t-401-1-1.html
*
* https://github.com/novicezk/midjourney-proxy/blob/main/README_CN.md
*
* discord4jhttps://github.com/discord-jda/JDA
*
*/
public class MidjourneyApi {

View File

@ -0,0 +1,52 @@
package cn.iocoder.yudao.framework.ai.midjourney;
import cn.hutool.core.io.FileUtil;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.web.client.RestTemplate;
import java.nio.charset.Charset;
public class MjHttpExecute implements MjExecute {
private static final String URL = "https://discord.com/";
@Override
public boolean execute(MjCommandEnum mjCommand, String prompt) {
// 发送的 uri
String uri = "api/v9/interactions";
// restTemplate 发送post请求
// String result = restTemplate.postForObject(URL + uri, prompt, String.class);
// 加载当前目录下文件
return false;
}
public static void main(String[] args) {
RestTemplate restTemplate = new RestTemplate();
String token = "NTY5MDY4NDAxNzEyOTU1Mzky.G4-Fu0.MzD-7ll-ElbXTTgDPHF-WS_UyhMAfbKN3WyyBc";
String body = FileUtil.readString("/Users/fansili/projects/github/ruoyi-vue-pro/yudao-module-ai" +
"/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney" +
"/interactions_type2.json", Charset.forName("utf-8"));
// 创建HTTP头部
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); // 设置内容类型为JSON
headers.set("Authorization", token); // 如果需要设置认证信息例如JWT令牌
headers.set("Referer", "https://discord.com/channels/1221445697157468200/1221445862962630706"); // 如果需要设置认证信息例如JWT令牌
headers.set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"); // 如果需要设置认证信息例如JWT令牌
headers.set("Cookie", "__dcfduid=6ca536c0e3fa11eeb7cbe34c31b49caf; __sdcfduid=6ca536c1e3fa11eeb7cbe34c31b49caf52cce5ffd8983d2a052cf6aba75fe5fe566f2c265902e283ce30dbf98b8c9c93; _gcl_au=1.1.245923998.1710853617; _ga=GA1.1.111061823.1710853617; __cfruid=6385bb3f48345a006b25992db7dcf984e395736d-1712124666; _cfuvid=O09la5ms0ypNptiG0iD8A6BKWlTxz1LG0WR7qRStD7o-1712124666575-0.0.1.1-604800000; locale=zh-CN; cf_clearance=l_YGod1_SUtYxpDVeZXiX7DLLPl1DYrquZe8WVltvYs-1712124668-1.0.1.1-Hl2.fToel23EpF2HCu9J20rB4D7OhhCzoajPSdo.9Up.wPxhvq22DP9RHzEBKuIUlKyH.kJLxXJfAt2N.LD5WQ; OptanonConsent=isIABGlobal=false&datestamp=Wed+Apr+03+2024+14%3A11%3A15+GMT%2B0800+(%E4%B8%AD%E5%9B%BD%E6%A0%87%E5%87%86%E6%97%B6%E9%97%B4)&version=6.33.0&hosts=&landingPath=https%3A%2F%2Fdiscord.com%2F&groups=C0001%3A1%2CC0002%3A1%2CC0003%3A1; _ga_Q149DFWHT7=GS1.1.1712124668.4.1.1712124679.0.0.0"); // 如果需要设置认证信息例如JWT令牌
// 封装请求体和头部信息
HttpEntity<String> requestEntity = new HttpEntity<>(body, headers);
// 定义请求URL和返回类型
String uri = "api/v9/interactions";
String res = restTemplate.postForObject(URL + uri, requestEntity, String.class);
System.err.println("11");
//
// MjHttpExecute mjHttpExecute = new MjHttpExecute();
// mjHttpExecute.execute(null, "童话世界应该是什么样?");
}
}

View File

@ -1,8 +1,8 @@
{
"type": 2,
"application_id": "936929561302675456",
"guild_id": "1224337694918971392",
"channel_id": "1224337694918971396",
"guild_id": "1221445697157468200",
"channel_id": "1221445862962630706",
"session_id": "696318caed5180a2210e358e44801449",
"data": {
"version": "1166847114203123795",
@ -13,7 +13,7 @@
{
"type": 3,
"name": "prompt",
"value": "中国的是什么样子"
"value": "童话世界应该是什么样?"
}
],
"application_command": {

View File

@ -0,0 +1,22 @@
package cn.iocoder.yudao.framework.ai.midjourney.jad;
import lombok.experimental.UtilityClass;
@UtilityClass
public final class Constants {
// 任务扩展属性 start
public static final String TASK_PROPERTY_NOTIFY_HOOK = "notifyHook";
public static final String TASK_PROPERTY_FINAL_PROMPT = "finalPrompt";
public static final String TASK_PROPERTY_MESSAGE_ID = "messageId";
public static final String TASK_PROPERTY_MESSAGE_HASH = "messageHash";
public static final String TASK_PROPERTY_PROGRESS_MESSAGE_ID = "progressMessageId";
public static final String TASK_PROPERTY_FLAGS = "flags";
public static final String TASK_PROPERTY_NONCE = "nonce";
public static final String TASK_PROPERTY_DISCORD_INSTANCE_ID = "discordInstanceId";
// 任务扩展属性 end
public static final String API_SECRET_HEADER_NAME = "mj-api-secret";
public static final String DEFAULT_DISCORD_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36";
public static final String MJ_MESSAGE_HANDLED = "mj_proxy_handled";
}

View File

@ -0,0 +1,35 @@
package cn.iocoder.yudao.framework.ai.midjourney.jad;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
public class DiscordAccount extends DomainObject {
private String guildId;
private String channelId;
private String userToken;
private String userAgent = Constants.DEFAULT_DISCORD_USER_AGENT;
private boolean enable = true;
private int coreSize = 3;
private int queueSize = 10;
private int timeoutMinutes = 5;
@JsonIgnore
public String getDisplay() {
return this.channelId;
}
}

View File

@ -0,0 +1,70 @@
package cn.iocoder.yudao.framework.ai.midjourney.jad;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
public class DomainObject implements Serializable {
@Getter
@Setter
protected String id;
@Setter
protected Map<String, Object> properties; // 扩展属性仅支持基本类型
@JsonIgnore
private final transient Object lock = new Object();
public void sleep() throws InterruptedException {
synchronized (this.lock) {
this.lock.wait();
}
}
public void awake() {
synchronized (this.lock) {
this.lock.notifyAll();
}
}
public DomainObject setProperty(String name, Object value) {
getProperties().put(name, value);
return this;
}
public DomainObject removeProperty(String name) {
getProperties().remove(name);
return this;
}
public Object getProperty(String name) {
return getProperties().get(name);
}
@SuppressWarnings("unchecked")
public <T> T getPropertyGeneric(String name) {
return (T) getProperty(name);
}
public <T> T getProperty(String name, Class<T> clz) {
return getProperty(name, clz, null);
}
public <T> T getProperty(String name, Class<T> clz, T defaultValue) {
Object value = getProperty(name);
return value == null ? defaultValue : clz.cast(value);
}
public Map<String, Object> getProperties() {
if (this.properties == null) {
this.properties = new HashMap<>();
}
return this.properties;
}
}

View File

@ -0,0 +1,61 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss;
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.TimedCache;
import cn.hutool.core.thread.ThreadUtil;
import cn.iocoder.yudao.framework.ai.midjourney.jad.DomainObject;
import lombok.experimental.UtilityClass;
import java.time.Duration;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
@UtilityClass
public class AsyncLockUtils {
private static final TimedCache<String, LockObject> LOCK_MAP = CacheUtil.newTimedCache(Duration.ofDays(1).toMillis());
public static synchronized LockObject getLock(String key) {
return LOCK_MAP.get(key);
}
public static LockObject waitForLock(String key, Duration duration) throws TimeoutException {
LockObject lockObject;
synchronized (LOCK_MAP) {
if (LOCK_MAP.containsKey(key)) {
lockObject = LOCK_MAP.get(key);
} else {
lockObject = new LockObject(key);
LOCK_MAP.put(key, lockObject);
}
}
Future<?> future = ThreadUtil.execAsync(() -> {
try {
lockObject.sleep();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
try {
future.get(duration.toMillis(), TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (ExecutionException e) {
// do nothing
} catch (TimeoutException e) {
future.cancel(true);
throw new TimeoutException("Wait Timeout");
} finally {
LOCK_MAP.remove(lockObject.getId());
}
return lockObject;
}
public static class LockObject extends DomainObject {
public LockObject(String id) {
this.id = id;
}
}
}

View File

@ -0,0 +1,57 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss;
import cn.hutool.core.text.CharSequenceUtil;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
@Component
@RequiredArgsConstructor
public class DiscordHelper {
/**
* DISCORD_SERVER_URL.
*/
public static final String DISCORD_SERVER_URL = "https://discord.com";
/**
* DISCORD_CDN_URL.
*/
public static final String DISCORD_CDN_URL = "https://cdn.discordapp.com";
/**
* DISCORD_WSS_URL.
*/
public static final String DISCORD_WSS_URL = "wss://gateway.discord.gg";
/**
* DISCORD_UPLOAD_URL.
*/
public static final String DISCORD_UPLOAD_URL = "https://discord-attachments-uploads-prd.storage.googleapis.com";
public String getServer() {
return DISCORD_SERVER_URL;
}
public String getCdn() {
return DISCORD_CDN_URL;
}
public String getWss() {
return DISCORD_WSS_URL;
}
public String getMessageHash(String imageUrl) {
if (CharSequenceUtil.isBlank(imageUrl)) {
return null;
}
if (CharSequenceUtil.endWith(imageUrl, "_grid_0.webp")) {
int hashStartIndex = imageUrl.lastIndexOf("/");
if (hashStartIndex < 0) {
return null;
}
return CharSequenceUtil.sub(imageUrl, hashStartIndex + 1, imageUrl.length() - "_grid_0.webp".length());
}
int hashStartIndex = imageUrl.lastIndexOf("_");
if (hashStartIndex < 0) {
return null;
}
return CharSequenceUtil.subBefore(imageUrl.substring(hashStartIndex + 1), ".", true);
}
}

View File

@ -0,0 +1,39 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss;
import cn.iocoder.yudao.framework.ai.midjourney.jad.DiscordAccount;
import cn.iocoder.yudao.framework.ai.midjourney.wss.user.SpringUserWebSocketStarter;
import cn.iocoder.yudao.framework.ai.midjourney.wss.user.UserMessageListener;
import java.util.Scanner;
/**
* author: fansili
* time: 2024/4/3 16:40
*/
public class Main {
public static void main(String[] args) {
String token = "NTY5MDY4NDAxNzEyOTU1Mzky.G4-Fu0.MzD-7ll-ElbXTTgDPHF-WS_UyhMAfbKN3WyyBc";
DiscordHelper discordHelper = new DiscordHelper();
DiscordAccount discordAccount = new DiscordAccount();
discordAccount.setUserToken(token);
discordAccount.setGuildId("1221445697157468200");
discordAccount.setChannelId("1221445862962630706");
var messageListener = new UserMessageListener();
var webSocketStarter = new SpringUserWebSocketStarter(discordHelper.getWss(), null, discordAccount, messageListener);
try {
webSocketStarter.start();
} catch (Exception e) {
throw new RuntimeException(e);
}
Scanner scanner = new Scanner(System.in);
scanner.nextLine();
}
}

View File

@ -0,0 +1,26 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss;
public enum MessageType {
/**
* 创建.
*/
CREATE,
/**
* 修改.
*/
UPDATE,
/**
* 删除.
*/
DELETE;
public static MessageType of(String type) {
return switch (type) {
case "MESSAGE_CREATE" -> CREATE;
case "MESSAGE_UPDATE" -> UPDATE;
case "MESSAGE_DELETE" -> DELETE;
default -> null;
};
}
}

View File

@ -0,0 +1,42 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss;
import lombok.experimental.UtilityClass;
@UtilityClass
public final class ReturnCode {
/**
* 成功.
*/
public static final int SUCCESS = 1;
/**
* 数据未找到.
*/
public static final int NOT_FOUND = 3;
/**
* 校验错误.
*/
public static final int VALIDATION_ERROR = 4;
/**
* 系统异常.
*/
public static final int FAILURE = 9;
/**
* 已存在.
*/
public static final int EXISTED = 21;
/**
* 排队中.
*/
public static final int IN_QUEUE = 22;
/**
* 队列已满.
*/
public static final int QUEUE_REJECTED = 23;
/**
* prompt包含敏感词.
*/
public static final int BANNED_PROMPT = 24;
}

View File

@ -0,0 +1,8 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss;
public interface WebSocketStarter {
void start() throws Exception;
}

View File

@ -0,0 +1,6 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss.user;
public interface FailureCallback {
void onFailure(int code, String reason);
}

View File

@ -0,0 +1,189 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss.user;
import cn.hutool.core.exceptions.ValidateException;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.thread.ThreadUtil;
import cn.iocoder.yudao.framework.ai.midjourney.jad.DiscordAccount;
import cn.iocoder.yudao.framework.ai.midjourney.wss.AsyncLockUtils;
import cn.iocoder.yudao.framework.ai.midjourney.wss.ReturnCode;
import cn.iocoder.yudao.framework.ai.midjourney.wss.WebSocketStarter;
import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.websocket.Constants;
import org.jetbrains.annotations.NotNull;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.concurrent.TimeoutException;
@Slf4j
public class SpringUserWebSocketStarter implements WebSocketStarter {
private static final int CONNECT_RETRY_LIMIT = 5;
private final DiscordAccount account;
private final UserMessageListener userMessageListener;
private final String wssServer;
private final String resumeWss;
private boolean running = false;
private WebSocketSession webSocketSession = null;
private ResumeData resumeData = null;
public SpringUserWebSocketStarter(String wssServer, String resumeWss, DiscordAccount account, UserMessageListener userMessageListener) {
this.wssServer = wssServer;
this.resumeWss = resumeWss;
this.account = account;
this.userMessageListener = userMessageListener;
}
@Override
public void start() throws Exception {
start(false);
}
private void start(boolean reconnect) {
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("Accept-Encoding", "gzip, deflate, br");
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
headers.add("Cache-Control", "no-cache");
headers.add("Pragma", "no-cache");
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
headers.add("User-Agent", this.account.getUserAgent());
var handler = new SpringWebSocketHandler(this.account, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
String gatewayUrl;
if (reconnect) {
gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
handler.setSessionId(this.resumeData.sessionId());
handler.setSequence(this.resumeData.sequence());
handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl());
} else {
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
}
var webSocketClient = new StandardWebSocketClient();
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl));
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
@Override
public void onFailure(@NotNull Throwable e) {
onSocketFailure(SpringWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage());
}
@Override
public void onSuccess(WebSocketSession session) {
SpringUserWebSocketStarter.this.webSocketSession = session;
}
});
}
private void onSocketSuccess(String sessionId, Object sequence, String resumeGatewayUrl) {
this.resumeData = new ResumeData(sessionId, sequence, resumeGatewayUrl);
this.running = true;
notifyWssLock(ReturnCode.SUCCESS, "");
}
private void onSocketFailure(int code, String reason) {
if (code == 1001) {
return;
}
closeSocketSessionWhenIsOpen();
if (!this.running) {
notifyWssLock(code, reason);
return;
}
this.running = false;
if (code >= 4000) {
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.account.getDisplay(), code, reason);
disableAccount();
} else if (code == 2001) {
log.warn("[wss-{}] Closed by {}({}). Try reconnect...", this.account.getDisplay(), code, reason);
tryReconnect();
} else {
log.warn("[wss-{}] Closed by {}({}). Try new connection...", this.account.getDisplay(), code, reason);
tryNewConnect();
}
}
private void tryReconnect() {
try {
tryStart(true);
} catch (Exception e) {
if (e instanceof TimeoutException) {
closeSocketSessionWhenIsOpen();
}
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.account.getDisplay(), e.getMessage());
ThreadUtil.sleep(1000);
tryNewConnect();
}
}
private void tryNewConnect() {
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
try {
tryStart(false);
return;
} catch (Exception e) {
if (e instanceof TimeoutException) {
closeSocketSessionWhenIsOpen();
}
log.warn("[wss-{}] New connect fail ({}): {}", this.account.getDisplay(), i, e.getMessage());
ThreadUtil.sleep(5000);
}
}
log.error("[wss-{}] Account disabled", this.account.getDisplay());
disableAccount();
}
public void tryStart(boolean reconnect) throws Exception {
start(reconnect);
AsyncLockUtils.LockObject lock = AsyncLockUtils.waitForLock("wss:" + this.account.getId(), Duration.ofSeconds(20));
int code = lock.getProperty("code", Integer.class, 0);
if (code == ReturnCode.SUCCESS) {
log.debug("[wss-{}] {} success.", this.account.getDisplay(), reconnect ? "Reconnect" : "New connect");
return;
}
throw new ValidateException(lock.getProperty("description", String.class));
}
private void notifyWssLock(int code, String reason) {
AsyncLockUtils.LockObject lock = AsyncLockUtils.getLock("wss:" + this.account.getId());
if (lock != null) {
lock.setProperty("code", code);
lock.setProperty("description", reason);
lock.awake();
}
}
private void disableAccount() {
if (Boolean.FALSE.equals(this.account.isEnable())) {
return;
}
this.account.setEnable(false);
}
private void closeSocketSessionWhenIsOpen() {
try {
if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
this.webSocketSession.close(CloseStatus.GOING_AWAY);
}
} catch (IOException e) {
// do nothing
}
}
private String getGatewayServer(String resumeGatewayUrl) {
if (CharSequenceUtil.isNotBlank(resumeGatewayUrl)) {
return CharSequenceUtil.isBlank(this.resumeWss) ? resumeGatewayUrl : this.resumeWss;
}
return this.wssServer;
}
public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
}
}

View File

@ -0,0 +1,240 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss.user;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.http.useragent.UserAgent;
import cn.hutool.http.useragent.UserAgentUtil;
import cn.iocoder.yudao.framework.ai.midjourney.jad.DiscordAccount;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataArray;
import net.dv8tion.jda.api.utils.data.DataObject;
import net.dv8tion.jda.api.utils.data.DataType;
import net.dv8tion.jda.internal.requests.WebSocketCode;
import net.dv8tion.jda.internal.utils.compress.Decompressor;
import net.dv8tion.jda.internal.utils.compress.ZlibDecompressor;
import org.jetbrains.annotations.NotNull;
import org.springframework.web.socket.*;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@Slf4j
public class SpringWebSocketHandler implements WebSocketHandler {
public static final int CLOSE_CODE_RECONNECT = 2001;
public static final int CLOSE_CODE_INVALIDATE = 1009;
public static final int CLOSE_CODE_EXCEPTION = 1011;
private final DiscordAccount account;
private final UserMessageListener userMessageListener;
private final SuccessCallback successCallback;
private final FailureCallback failureCallback;
private final ScheduledExecutorService heartExecutor;
private final DataObject authData;
@Setter
private String sessionId = null;
@Setter
private Object sequence = null;
@Setter
private String resumeGatewayUrl = null;
private long interval = 41250;
private boolean heartbeatAck = false;
private Future<?> heartbeatInterval;
private Future<?> heartbeatTimeout;
private final Decompressor decompressor = new ZlibDecompressor(2048);
public SpringWebSocketHandler(DiscordAccount account, UserMessageListener userMessageListener, SuccessCallback successCallback, FailureCallback failureCallback) {
this.account = account;
this.userMessageListener = userMessageListener;
this.successCallback = successCallback;
this.failureCallback = failureCallback;
this.heartExecutor = Executors.newSingleThreadScheduledExecutor();
this.authData = createAuthData();
}
@Override
public void afterConnectionEstablished(@NotNull WebSocketSession session) throws Exception {
// do nothing
}
@Override
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
log.error("[wss-{}] Transport error", this.account.getDisplay(), e);
onFailure(CLOSE_CODE_EXCEPTION, "transport error");
}
@Override
public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
onFailure(closeStatus.getCode(), closeStatus.getReason());
}
@Override
public boolean supportsPartialMessages() {
return true;
}
@Override
public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
ByteBuffer buffer = (ByteBuffer) message.getPayload();
byte[] decompressed = decompressor.decompress(buffer.array());
if (decompressed == null) {
return;
}
String json = new String(decompressed, StandardCharsets.UTF_8);
DataObject data = DataObject.fromJson(json);
int opCode = data.getInt("op");
switch (opCode) {
case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);
case WebSocketCode.HEARTBEAT_ACK -> {
this.heartbeatAck = true;
clearHeartbeatTimeout();
}
case WebSocketCode.HELLO -> {
handleHello(session, data);
doResumeOrIdentify(session);
}
case WebSocketCode.RESUME -> onSuccess();
case WebSocketCode.RECONNECT -> onFailure(CLOSE_CODE_RECONNECT, "receive server reconnect");
case WebSocketCode.INVALIDATE_SESSION -> onFailure(CLOSE_CODE_INVALIDATE, "receive session invalid");
case WebSocketCode.DISPATCH -> handleDispatch(data);
default -> log.debug("[wss-{}] Receive unknown code: {}.", account.getDisplay(), data);
}
}
private void handleDispatch(DataObject raw) {
this.sequence = raw.opt("s").orElse(null);
if (!raw.isType("d", DataType.OBJECT)) {
return;
}
DataObject content = raw.getObject("d");
String t = raw.getString("t", null);
if ("READY".equals(t)) {
this.sessionId = content.getString("session_id");
this.resumeGatewayUrl = content.getString("resume_gateway_url");
onSuccess();
} else if ("RESUMED".equals(t)) {
onSuccess();
} else {
try {
this.userMessageListener.onMessage(raw);
} catch (Exception e) {
log.error("[wss-{}] Handle message error", this.account.getDisplay(), e);
}
}
}
private void handleHeartbeat(WebSocketSession session) {
sendMessage(session, WebSocketCode.HEARTBEAT, this.sequence);
this.heartbeatTimeout = ThreadUtil.execAsync(() -> {
ThreadUtil.sleep(this.interval);
onFailure(CLOSE_CODE_RECONNECT, "heartbeat has not ack");
});
}
private void handleHello(WebSocketSession session, DataObject data) {
clearHeartbeatInterval();
this.interval = data.getObject("d").getLong("heartbeat_interval");
this.heartbeatAck = true;
this.heartbeatInterval = this.heartExecutor.scheduleAtFixedRate(() -> {
if (this.heartbeatAck) {
this.heartbeatAck = false;
sendMessage(session, WebSocketCode.HEARTBEAT, this.sequence);
} else {
onFailure(CLOSE_CODE_RECONNECT, "heartbeat has not ack interval");
}
}, (long) Math.floor(RandomUtil.randomDouble(0, 1) * this.interval), this.interval, TimeUnit.MILLISECONDS);
}
private void doResumeOrIdentify(WebSocketSession session) {
if (CharSequenceUtil.isBlank(this.sessionId)) {
sendMessage(session, WebSocketCode.IDENTIFY, this.authData);
} else {
var data = DataObject.empty().put("token", this.account.getUserToken())
.put("session_id", this.sessionId).put("seq", this.sequence);
sendMessage(session, WebSocketCode.RESUME, data);
}
}
private void sendMessage(WebSocketSession session, int op, Object d) {
var data = DataObject.empty().put("op", op).put("d", d);
try {
session.sendMessage(new TextMessage(data.toString()));
} catch (IOException e) {
log.error("[wss-{}] Send message error", this.account.getDisplay(), e);
onFailure(CLOSE_CODE_EXCEPTION, "send message error");
}
}
private void onSuccess() {
ThreadUtil.execute(() -> this.successCallback.onSuccess(this.sessionId, this.sequence, this.resumeGatewayUrl));
}
private void onFailure(int code, String reason) {
clearHeartbeatTimeout();
clearHeartbeatInterval();
ThreadUtil.execute(() -> this.failureCallback.onFailure(code, reason));
}
private void clearHeartbeatTimeout() {
if (this.heartbeatTimeout != null) {
this.heartbeatTimeout.cancel(true);
this.heartbeatTimeout = null;
}
}
private void clearHeartbeatInterval() {
if (this.heartbeatInterval != null) {
this.heartbeatInterval.cancel(true);
this.heartbeatInterval = null;
}
}
private DataObject createAuthData() {
UserAgent userAgent = UserAgentUtil.parse(this.account.getUserAgent());
DataObject connectionProperties = DataObject.empty()
.put("browser", userAgent.getBrowser().getName())
.put("browser_user_agent", this.account.getUserAgent())
.put("browser_version", userAgent.getVersion())
.put("client_build_number", 222963)
.put("client_event_source", null)
.put("device", "")
.put("os", userAgent.getOs().getName())
.put("referer", "https://www.midjourney.com")
.put("referrer_current", "")
.put("referring_domain", "www.midjourney.com")
.put("referring_domain_current", "")
.put("release_channel", "stable")
.put("system_locale", "zh-CN");
DataObject presence = DataObject.empty()
.put("activities", DataArray.empty())
.put("afk", false)
.put("since", 0)
.put("status", "online");
DataObject clientState = DataObject.empty()
.put("api_code_version", 0)
.put("guild_versions", DataObject.empty())
.put("highest_last_message_id", "0")
.put("private_channels_version", "0")
.put("read_state_version", 0)
.put("user_guild_settings_version", -1)
.put("user_settings_version", -1);
return DataObject.empty()
.put("capabilities", 16381)
.put("client_state", clientState)
.put("compress", false)
.put("presence", presence)
.put("properties", connectionProperties)
.put("token", this.account.getUserToken());
}
}

View File

@ -0,0 +1,7 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss.user;
public interface SuccessCallback {
void onSuccess(String sessionId, Object sequence, String resumeGatewayUrl);
}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.framework.ai.midjourney.wss.user;
import cn.hutool.core.thread.ThreadUtil;
import cn.iocoder.yudao.framework.ai.midjourney.wss.MessageType;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject;
@Slf4j
public class UserMessageListener {
public void onMessage(DataObject raw) {
MessageType messageType = MessageType.of(raw.getString("t"));
if (messageType == null || MessageType.DELETE == messageType) {
return;
}
DataObject data = raw.getObject("d");
System.err.println(data);
ThreadUtil.sleep(50);
// for (MessageHandler messageHandler : this.messageHandlers) {
// if (data.getBoolean(Constants.MJ_MESSAGE_HANDLED, false)) {
// return;
// }
// messageHandler.handle(this.instance, messageType, data);
// }
}
}