增加mj web socket

This commit is contained in:
cherishsince 2024-04-05 09:27:09 +08:00
parent 28b2fad7b6
commit 5044a58118
26 changed files with 110 additions and 591 deletions

View File

@ -0,0 +1,8 @@
package cn.iocoder.yudao.framework.ai.midjourney.constants;
public final class MjConstants {
public static final String CHANNEL_ID = "channel_id";
}

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.framework.ai.midjourney;
package cn.iocoder.yudao.framework.ai.midjourney.constants;
import lombok.Getter;
@ -6,7 +6,7 @@ import lombok.Getter;
* MJ 命令
*/
@Getter
public enum MidjourneyInteractionsEnum {
public enum MjInteractionsEnum {
IMAGINE("imagine", "生成图片"),
DESCRIBE("describe", "生成描述"),
@ -17,7 +17,7 @@ public enum MidjourneyInteractionsEnum {
;
MidjourneyInteractionsEnum(String value, String message) {
MjInteractionsEnum(String value, String message) {
this.value =value;
this.message =message;
}

View File

@ -1,7 +1,7 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss;
package cn.iocoder.yudao.framework.ai.midjourney.constants;
public enum MessageType {
public enum MjMessageTypeEnum {
/**
* 创建.
*/
@ -15,7 +15,7 @@ public enum MessageType {
*/
DELETE;
public static MessageType of(String type) {
public static MjMessageTypeEnum of(String type) {
return switch (type) {
case "MESSAGE_CREATE" -> CREATE;
case "MESSAGE_UPDATE" -> UPDATE;

View File

@ -1,9 +1,9 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss;
package cn.iocoder.yudao.framework.ai.midjourney.constants;
import lombok.experimental.UtilityClass;
@UtilityClass
public final class ReturnCode {
public final class MjNotifyCode {
/**
* 成功.
*/

View File

@ -1,10 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo;
/**
* author: fansili
* time: 2024/4/3 15:54
*/
public class DiscordJadMain {
}

View File

@ -1,141 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo;
import cn.hutool.http.useragent.UserAgent;
import cn.hutool.http.useragent.UserAgentUtil;
import cn.hutool.json.JSONObject;
import com.alibaba.fastjson.JSON;
import org.springframework.web.socket.*;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
/**
* https://blog.csdn.net/qq_38490457/article/details/125250135
*/
public class DiscordWebSocketClient {
private static final String DISCORD_GATEWAY_URL = "wss://gateway.discord.gg/?v=9&encoding=json";
public static void main(String[] args) throws InterruptedException, ExecutionException, IOException, URISyntaxException {
StandardWebSocketClient client = new StandardWebSocketClient();
DiscordWebSocketHandler handler = new DiscordWebSocketHandler();
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("Accept-Encoding", "gzip, deflate, br");
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
headers.add("Cache-Control", "no-cache");
headers.add("Pragma", "no-cache");
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
headers.add("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36");
Future<WebSocketSession> futureSession = client.doHandshake(handler, headers, new URI(DISCORD_GATEWAY_URL));
WebSocketSession session = futureSession.get(); // 这会阻塞直到连接建立
// 登录过程你需要替换 TOKEN 为你的 Discord Bot Token
// String token = "YOUR_DISCORD_BOT_TOKEN"; // 请替换为你的 Bot Token
// String identifyPayload = "{\"op\":2,\"d\":{\"token\":\"" + token + "\",\"properties\":{\"$os\":\"java\",\"$browser\":\"spring-websocket\",\"$device\":\"spring-websocket\"},\"compress\":false,\"large_threshold\":256,\"shard\":[0,1]}}";
// session.sendMessage(new TextMessage(identifyPayload));
// 发送心跳以保持连接活跃
Thread heartbeatThread = new Thread(() -> {
int interval = 0; // 初始心跳间隔后续从 Discord 服务器获取
while (!Thread.currentThread().isInterrupted()) {
try {
Thread.sleep(interval * 1000); // 等待指定的心跳间隔
session.sendMessage(new TextMessage("{\"op\":1,\"d\":null}")); // 发送心跳包
} catch (Exception e) {
e.printStackTrace();
break;
}
}
});
heartbeatThread.start();
// 等待用户输入来保持程序运行仅用于示例
System.in.read();
// 关闭连接和线程
session.close();
heartbeatThread.interrupt();
}
private static class DiscordWebSocketHandler implements WebSocketHandler {
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
Object payload = message.getPayload();
session.sendMessage(new TextMessage(JSON.toJSONString(createAuthData())));
String a= "";
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
}
@Override
public boolean supportsPartialMessages() {
return false;
}
private JSONObject createAuthData() {
String userAgentStr = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36";
UserAgent userAgent = UserAgentUtil.parse(userAgentStr);
String token = "NTY5MDY4NDAxNzEyOTU1Mzky.G4-Fu0.MzD-7ll-ElbXTTgDPHF-WS_UyhMAfbKN3WyyBc";
JSONObject connectionProperties = new JSONObject()
.put("browser", userAgent.getBrowser().getName())
.put("browser_user_agent", userAgentStr)
.put("browser_version", userAgent.getVersion())
.put("client_build_number", 222963)
.put("client_event_source", null)
.put("device", "")
.put("os", userAgent.getOs().getName())
.put("referer", "https://www.midjourney.com")
.put("referrer_current", "")
.put("referring_domain", "www.midjourney.com")
.put("referring_domain_current", "")
.put("release_channel", "stable")
.put("system_locale", "zh-CN");
JSONObject presence = new JSONObject()
.put("activities", "")
.put("afk", false)
.put("since", 0)
.put("status", "online");
JSONObject clientState = new JSONObject()
.put("api_code_version", 0)
.put("guild_versions", "")
.put("highest_last_message_id", "0")
.put("private_channels_version", "0")
.put("read_state_version", 0)
.put("user_guild_settings_version", -1)
.put("user_settings_version", -1);
return new JSONObject()
.put("capabilities", 16381)
.put("client_state", clientState)
.put("compress", false)
.put("presence", presence)
.put("properties", connectionProperties)
.put("token", token);
}
}
}

View File

@ -1,19 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo;
/**
* 文档: https://www.xiubbs.com/t-401-1-1.html
*
* https://github.com/novicezk/midjourney-proxy/blob/main/README_CN.md
*
* discord4jhttps://github.com/discord-jda/JDA
*
*/
public class MidjourneyApi {
// https://discord.com/api/v9/interactions
//
// payload_json: {"type":2,"application_id":"936929561302675456","guild_id":"1224337694918971392","channel_id":"1224337694918971396","session_id":"696318caed5180a2210e358e44801449","data":{"version":"1166847114203123795","id":"938956540159881230","name":"imagine","type":1,"options":[{"type":3,"name":"prompt","value":"中国的是什么样子"}],"application_command":{"id":"938956540159881230","type":1,"application_id":"936929561302675456","version":"1166847114203123795","name":"imagine","description":"Create images with Midjourney","options":[{"type":3,"name":"prompt","description":"The prompt to imagine","required":true,"description_localized":"The prompt to imagine","name_localized":"prompt"}],"dm_permission":true,"integration_types":[0],"global_popularity_rank":1,"description_localized":"Create images with Midjourney","name_localized":"imagine"},"attachments":[]},"nonce":"1224342266261274624","analytics_location":"slash_ui"}
//
}

View File

@ -1,19 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyInteractionsEnum;
/**
* mj 命令
*/
public interface MjExecute {
/**
* 执行命令
*
* @param mjCommand
* @param prompt
* @return
*/
boolean execute(MidjourneyInteractionsEnum mjCommand, String prompt);
}

View File

@ -1,50 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo;
import cn.hutool.core.io.FileUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyInteractionsEnum;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.web.client.RestTemplate;
import java.nio.charset.Charset;
public class MjHttpExecute implements MjExecute {
private static final String URL = "https://discord.com/";
@Override
public boolean execute(MidjourneyInteractionsEnum mjCommand, String prompt) {
// 发送的 uri
String uri = "api/v9/interactions";
// restTemplate 发送post请求
// String result = restTemplate.postForObject(URL + uri, prompt, String.class);
// 加载当前目录下文件
return false;
}
public static void main(String[] args) {
RestTemplate restTemplate = new RestTemplate();
String token = "NTY5MDY4NDAxNzEyOTU1Mzky.G4-Fu0.MzD-7ll-ElbXTTgDPHF-WS_UyhMAfbKN3WyyBc";
String body = FileUtil.readString("/Users/fansili/projects/github/ruoyi-vue-pro/yudao-module-ai" +
"/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney" +
"/interactions_type2.json", Charset.forName("utf-8"));
// 创建HTTP头部
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); // 设置内容类型为JSON
headers.set("Authorization", token); // 如果需要设置认证信息例如JWT令牌
headers.set("Referer", "https://discord.com/channels/1221445697157468200/1221445862962630706"); // 如果需要设置认证信息例如JWT令牌
headers.set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"); // 如果需要设置认证信息例如JWT令牌
headers.set("Cookie", "__dcfduid=6ca536c0e3fa11eeb7cbe34c31b49caf; __sdcfduid=6ca536c1e3fa11eeb7cbe34c31b49caf52cce5ffd8983d2a052cf6aba75fe5fe566f2c265902e283ce30dbf98b8c9c93; _gcl_au=1.1.245923998.1710853617; _ga=GA1.1.111061823.1710853617; __cfruid=6385bb3f48345a006b25992db7dcf984e395736d-1712124666; _cfuvid=O09la5ms0ypNptiG0iD8A6BKWlTxz1LG0WR7qRStD7o-1712124666575-0.0.1.1-604800000; locale=zh-CN; cf_clearance=l_YGod1_SUtYxpDVeZXiX7DLLPl1DYrquZe8WVltvYs-1712124668-1.0.1.1-Hl2.fToel23EpF2HCu9J20rB4D7OhhCzoajPSdo.9Up.wPxhvq22DP9RHzEBKuIUlKyH.kJLxXJfAt2N.LD5WQ; OptanonConsent=isIABGlobal=false&datestamp=Wed+Apr+03+2024+14%3A11%3A15+GMT%2B0800+(%E4%B8%AD%E5%9B%BD%E6%A0%87%E5%87%86%E6%97%B6%E9%97%B4)&version=6.33.0&hosts=&landingPath=https%3A%2F%2Fdiscord.com%2F&groups=C0001%3A1%2CC0002%3A1%2CC0003%3A1; _ga_Q149DFWHT7=GS1.1.1712124668.4.1.1712124679.0.0.0"); // 如果需要设置认证信息例如JWT令牌
// 封装请求体和头部信息
HttpEntity<String> requestEntity = new HttpEntity<>(body, headers);
// 定义请求URL和返回类型
String uri = "api/v9/interactions";
String res = restTemplate.postForObject(URL + uri, requestEntity, String.class);
System.err.println("11");
}
}

View File

@ -1,22 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.jad;
import lombok.experimental.UtilityClass;
@UtilityClass
public final class Constants {
// 任务扩展属性 start
public static final String TASK_PROPERTY_NOTIFY_HOOK = "notifyHook";
public static final String TASK_PROPERTY_FINAL_PROMPT = "finalPrompt";
public static final String TASK_PROPERTY_MESSAGE_ID = "messageId";
public static final String TASK_PROPERTY_MESSAGE_HASH = "messageHash";
public static final String TASK_PROPERTY_PROGRESS_MESSAGE_ID = "progressMessageId";
public static final String TASK_PROPERTY_FLAGS = "flags";
public static final String TASK_PROPERTY_NONCE = "nonce";
public static final String TASK_PROPERTY_DISCORD_INSTANCE_ID = "discordInstanceId";
// 任务扩展属性 end
public static final String API_SECRET_HEADER_NAME = "mj-api-secret";
public static final String DEFAULT_DISCORD_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36";
public static final String MJ_MESSAGE_HANDLED = "mj_proxy_handled";
}

View File

@ -1,35 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.jad;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
public class DiscordAccount extends DomainObject {
private String guildId;
private String channelId;
private String userToken;
private String userAgent = Constants.DEFAULT_DISCORD_USER_AGENT;
private boolean enable = true;
private int coreSize = 3;
private int queueSize = 10;
private int timeoutMinutes = 5;
@JsonIgnore
public String getDisplay() {
return this.channelId;
}
}

View File

@ -1,70 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.jad;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
public class DomainObject implements Serializable {
@Getter
@Setter
protected String id;
@Setter
protected Map<String, Object> properties; // 扩展属性仅支持基本类型
@JsonIgnore
private final transient Object lock = new Object();
public void sleep() throws InterruptedException {
synchronized (this.lock) {
this.lock.wait();
}
}
public void awake() {
synchronized (this.lock) {
this.lock.notifyAll();
}
}
public DomainObject setProperty(String name, Object value) {
getProperties().put(name, value);
return this;
}
public DomainObject removeProperty(String name) {
getProperties().remove(name);
return this;
}
public Object getProperty(String name) {
return getProperties().get(name);
}
@SuppressWarnings("unchecked")
public <T> T getPropertyGeneric(String name) {
return (T) getProperty(name);
}
public <T> T getProperty(String name, Class<T> clz) {
return getProperty(name, clz, null);
}
public <T> T getProperty(String name, Class<T> clz, T defaultValue) {
Object value = getProperty(name);
return value == null ? defaultValue : clz.cast(value);
}
public Map<String, Object> getProperties() {
if (this.properties == null) {
this.properties = new HashMap<>();
}
return this.properties;
}
}

View File

@ -1,61 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss;
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.TimedCache;
import cn.hutool.core.thread.ThreadUtil;
import cn.iocoder.yudao.framework.ai.midjourney.demo.jad.DomainObject;
import lombok.experimental.UtilityClass;
import java.time.Duration;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
@UtilityClass
public class AsyncLockUtils {
private static final TimedCache<String, LockObject> LOCK_MAP = CacheUtil.newTimedCache(Duration.ofDays(1).toMillis());
public static synchronized LockObject getLock(String key) {
return LOCK_MAP.get(key);
}
public static LockObject waitForLock(String key, Duration duration) throws TimeoutException {
LockObject lockObject;
synchronized (LOCK_MAP) {
if (LOCK_MAP.containsKey(key)) {
lockObject = LOCK_MAP.get(key);
} else {
lockObject = new LockObject(key);
LOCK_MAP.put(key, lockObject);
}
}
Future<?> future = ThreadUtil.execAsync(() -> {
try {
lockObject.sleep();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
try {
future.get(duration.toMillis(), TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (ExecutionException e) {
// do nothing
} catch (TimeoutException e) {
future.cancel(true);
throw new TimeoutException("Wait Timeout");
} finally {
LOCK_MAP.remove(lockObject.getId());
}
return lockObject;
}
public static class LockObject extends DomainObject {
public LockObject(String id) {
this.id = id;
}
}
}

View File

@ -1,57 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss;
import cn.hutool.core.text.CharSequenceUtil;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
@Component
@RequiredArgsConstructor
public class DiscordHelper {
/**
* DISCORD_SERVER_URL.
*/
public static final String DISCORD_SERVER_URL = "https://discord.com";
/**
* DISCORD_CDN_URL.
*/
public static final String DISCORD_CDN_URL = "https://cdn.discordapp.com";
/**
* DISCORD_WSS_URL.
*/
public static final String DISCORD_WSS_URL = "wss://gateway.discord.gg";
/**
* DISCORD_UPLOAD_URL.
*/
public static final String DISCORD_UPLOAD_URL = "https://discord-attachments-uploads-prd.storage.googleapis.com";
public String getServer() {
return DISCORD_SERVER_URL;
}
public String getCdn() {
return DISCORD_CDN_URL;
}
public String getWss() {
return DISCORD_WSS_URL;
}
public String getMessageHash(String imageUrl) {
if (CharSequenceUtil.isBlank(imageUrl)) {
return null;
}
if (CharSequenceUtil.endWith(imageUrl, "_grid_0.webp")) {
int hashStartIndex = imageUrl.lastIndexOf("/");
if (hashStartIndex < 0) {
return null;
}
return CharSequenceUtil.sub(imageUrl, hashStartIndex + 1, imageUrl.length() - "_grid_0.webp".length());
}
int hashStartIndex = imageUrl.lastIndexOf("_");
if (hashStartIndex < 0) {
return null;
}
return CharSequenceUtil.subBefore(imageUrl.substring(hashStartIndex + 1), ".", true);
}
}

View File

@ -1,27 +0,0 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss.user;
import cn.hutool.core.thread.ThreadUtil;
import cn.iocoder.yudao.framework.ai.midjourney.demo.wss.MessageType;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject;
@Slf4j
public class UserMessageListener {
public void onMessage(DataObject raw) {
MessageType messageType = MessageType.of(raw.getString("t"));
if (messageType == null || MessageType.DELETE == messageType) {
return;
}
DataObject data = raw.getObject("d");
System.err.println(data);
ThreadUtil.sleep(50);
// for (MessageHandler messageHandler : this.messageHandlers) {
// if (data.getBoolean(Constants.MJ_MESSAGE_HANDLED, false)) {
// return;
// }
// messageHandler.handle(this.instance, messageType, data);
// }
}
}

View File

@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.midjourney.interactions;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyInteractionsEnum;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjInteractionsEnum;
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
@ -27,7 +27,7 @@ public class MjImagineInteractions implements MjInteractions {
}
@Override
public List<MidjourneyInteractionsEnum> supperInteractions() {
public List<MjInteractionsEnum> supperInteractions() {
return null;
}

View File

@ -1,6 +1,6 @@
package cn.iocoder.yudao.framework.ai.midjourney.interactions;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyInteractionsEnum;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjInteractionsEnum;
import java.util.List;
@ -12,7 +12,7 @@ import java.util.List;
*/
public interface MjInteractions {
List<MidjourneyInteractionsEnum> supperInteractions();
List<MjInteractionsEnum> supperInteractions();
Boolean execute(String prompt);
}

View File

@ -1,5 +0,0 @@
/**
* author: fansili
* time: 2024/4/3 17:08
*/
package cn.iocoder.yudao.framework.ai.midjourney;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss.user;
package cn.iocoder.yudao.framework.ai.midjourney.webSocket;
public interface FailureCallback {

View File

@ -1,13 +1,12 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss.user;
package cn.iocoder.yudao.framework.ai.midjourney.webSocket;
import cn.hutool.core.exceptions.ValidateException;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.thread.ThreadUtil;
import cn.iocoder.yudao.framework.ai.midjourney.demo.jad.DiscordAccount;
import cn.iocoder.yudao.framework.ai.midjourney.demo.wss.AsyncLockUtils;
import cn.iocoder.yudao.framework.ai.midjourney.demo.wss.ReturnCode;
import cn.iocoder.yudao.framework.ai.midjourney.demo.wss.WebSocketStarter;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.websocket.Constants;
import org.jetbrains.annotations.NotNull;
@ -19,15 +18,14 @@ import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.concurrent.TimeoutException;
@Slf4j
public class SpringUserWebSocketStarter implements WebSocketStarter {
public class MjWebSocketStarter implements WebSocketStarter {
private static final int CONNECT_RETRY_LIMIT = 5;
private final DiscordAccount account;
private final UserMessageListener userMessageListener;
private final MidjourneyConfig midjourneyConfig;
private final MjMessageListener userMessageListener;
private final String wssServer;
private final String resumeWss;
@ -36,10 +34,10 @@ public class SpringUserWebSocketStarter implements WebSocketStarter {
private WebSocketSession webSocketSession = null;
private ResumeData resumeData = null;
public SpringUserWebSocketStarter(String wssServer, String resumeWss, DiscordAccount account, UserMessageListener userMessageListener) {
public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) {
this.wssServer = wssServer;
this.resumeWss = resumeWss;
this.account = account;
this.midjourneyConfig = midjourneyConfig;
this.userMessageListener = userMessageListener;
}
@ -55,8 +53,8 @@ public class SpringUserWebSocketStarter implements WebSocketStarter {
headers.add("Cache-Control", "no-cache");
headers.add("Pragma", "no-cache");
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
headers.add("User-Agent", this.account.getUserAgent());
var handler = new SpringWebSocketHandler(this.account, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
String gatewayUrl;
if (reconnect) {
gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
@ -72,12 +70,12 @@ public class SpringUserWebSocketStarter implements WebSocketStarter {
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
@Override
public void onFailure(@NotNull Throwable e) {
onSocketFailure(SpringWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage());
onSocketFailure(MjWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage());
}
@Override
public void onSuccess(WebSocketSession session) {
SpringUserWebSocketStarter.this.webSocketSession = session;
MjWebSocketStarter.this.webSocketSession = session;
}
});
}
@ -85,7 +83,7 @@ public class SpringUserWebSocketStarter implements WebSocketStarter {
private void onSocketSuccess(String sessionId, Object sequence, String resumeGatewayUrl) {
this.resumeData = new ResumeData(sessionId, sequence, resumeGatewayUrl);
this.running = true;
notifyWssLock(ReturnCode.SUCCESS, "");
notifyWssLock(MjNotifyCode.SUCCESS, "");
}
private void onSocketFailure(int code, String reason) {
@ -99,13 +97,12 @@ public class SpringUserWebSocketStarter implements WebSocketStarter {
}
this.running = false;
if (code >= 4000) {
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.account.getDisplay(), code, reason);
disableAccount();
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
} else if (code == 2001) {
log.warn("[wss-{}] Closed by {}({}). Try reconnect...", this.account.getDisplay(), code, reason);
log.warn("[wss-{}] Closed by {}({}). Try reconnect...", this.midjourneyConfig.getChannelId(), code, reason);
tryReconnect();
} else {
log.warn("[wss-{}] Closed by {}({}). Try new connection...", this.account.getDisplay(), code, reason);
log.warn("[wss-{}] Closed by {}({}). Try new connection...", this.midjourneyConfig.getChannelId(), code, reason);
tryNewConnect();
}
}
@ -117,7 +114,7 @@ public class SpringUserWebSocketStarter implements WebSocketStarter {
if (e instanceof TimeoutException) {
closeSocketSessionWhenIsOpen();
}
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.account.getDisplay(), e.getMessage());
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
ThreadUtil.sleep(1000);
tryNewConnect();
}
@ -132,39 +129,19 @@ public class SpringUserWebSocketStarter implements WebSocketStarter {
if (e instanceof TimeoutException) {
closeSocketSessionWhenIsOpen();
}
log.warn("[wss-{}] New connect fail ({}): {}", this.account.getDisplay(), i, e.getMessage());
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
ThreadUtil.sleep(5000);
}
}
log.error("[wss-{}] Account disabled", this.account.getDisplay());
disableAccount();
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
}
public void tryStart(boolean reconnect) throws Exception {
start(reconnect);
AsyncLockUtils.LockObject lock = AsyncLockUtils.waitForLock("wss:" + this.account.getId(), Duration.ofSeconds(20));
int code = lock.getProperty("code", Integer.class, 0);
if (code == ReturnCode.SUCCESS) {
log.debug("[wss-{}] {} success.", this.account.getDisplay(), reconnect ? "Reconnect" : "New connect");
return;
}
throw new ValidateException(lock.getProperty("description", String.class));
}
private void notifyWssLock(int code, String reason) {
AsyncLockUtils.LockObject lock = AsyncLockUtils.getLock("wss:" + this.account.getId());
if (lock != null) {
lock.setProperty("code", code);
lock.setProperty("description", reason);
lock.awake();
}
}
private void disableAccount() {
if (Boolean.FALSE.equals(this.account.isEnable())) {
return;
}
this.account.setEnable(false);
System.err.println("notifyWssLock: " + code + " - " + reason);
}
private void closeSocketSessionWhenIsOpen() {

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss.user;
package cn.iocoder.yudao.framework.ai.midjourney.webSocket;
public interface SuccessCallback {

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss;
package cn.iocoder.yudao.framework.ai.midjourney.webSocket;
public interface WebSocketStarter {

View File

@ -1,11 +1,14 @@
package cn.iocoder.yudao.framework.ai.midjourney.demo.wss.user;
package cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.http.useragent.UserAgent;
import cn.hutool.http.useragent.UserAgentUtil;
import cn.iocoder.yudao.framework.ai.midjourney.demo.jad.DiscordAccount;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.FailureCallback;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.SuccessCallback;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataArray;
@ -26,13 +29,13 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@Slf4j
public class SpringWebSocketHandler implements WebSocketHandler {
public class MjWebSocketHandler implements WebSocketHandler {
public static final int CLOSE_CODE_RECONNECT = 2001;
public static final int CLOSE_CODE_INVALIDATE = 1009;
public static final int CLOSE_CODE_EXCEPTION = 1011;
private final DiscordAccount account;
private final UserMessageListener userMessageListener;
private final MidjourneyConfig midjourneyConfig;
private final MjMessageListener userMessageListener;
private final SuccessCallback successCallback;
private final FailureCallback failureCallback;
@ -54,8 +57,11 @@ public class SpringWebSocketHandler implements WebSocketHandler {
private final Decompressor decompressor = new ZlibDecompressor(2048);
public SpringWebSocketHandler(DiscordAccount account, UserMessageListener userMessageListener, SuccessCallback successCallback, FailureCallback failureCallback) {
this.account = account;
public MjWebSocketHandler(MidjourneyConfig account,
MjMessageListener userMessageListener,
SuccessCallback successCallback,
FailureCallback failureCallback) {
this.midjourneyConfig = account;
this.userMessageListener = userMessageListener;
this.successCallback = successCallback;
this.failureCallback = failureCallback;
@ -70,7 +76,7 @@ public class SpringWebSocketHandler implements WebSocketHandler {
@Override
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
log.error("[wss-{}] Transport error", this.account.getDisplay(), e);
log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
onFailure(CLOSE_CODE_EXCEPTION, "transport error");
}
@ -108,7 +114,7 @@ public class SpringWebSocketHandler implements WebSocketHandler {
case WebSocketCode.RECONNECT -> onFailure(CLOSE_CODE_RECONNECT, "receive server reconnect");
case WebSocketCode.INVALIDATE_SESSION -> onFailure(CLOSE_CODE_INVALIDATE, "receive session invalid");
case WebSocketCode.DISPATCH -> handleDispatch(data);
default -> log.debug("[wss-{}] Receive unknown code: {}.", account.getDisplay(), data);
default -> log.debug("[wss-{}] Receive unknown code: {}.", midjourneyConfig.getChannelId(), data);
}
}
@ -129,7 +135,7 @@ public class SpringWebSocketHandler implements WebSocketHandler {
try {
this.userMessageListener.onMessage(raw);
} catch (Exception e) {
log.error("[wss-{}] Handle message error", this.account.getDisplay(), e);
log.error("[wss-{}] Handle message error", this.midjourneyConfig.getChannelId(), e);
}
}
}
@ -160,7 +166,7 @@ public class SpringWebSocketHandler implements WebSocketHandler {
if (CharSequenceUtil.isBlank(this.sessionId)) {
sendMessage(session, WebSocketCode.IDENTIFY, this.authData);
} else {
var data = DataObject.empty().put("token", this.account.getUserToken())
var data = DataObject.empty().put("token", this.midjourneyConfig.getToken())
.put("session_id", this.sessionId).put("seq", this.sequence);
sendMessage(session, WebSocketCode.RESUME, data);
}
@ -171,7 +177,7 @@ public class SpringWebSocketHandler implements WebSocketHandler {
try {
session.sendMessage(new TextMessage(data.toString()));
} catch (IOException e) {
log.error("[wss-{}] Send message error", this.account.getDisplay(), e);
log.error("[wss-{}] Send message error", this.midjourneyConfig.getChannelId(), e);
onFailure(CLOSE_CODE_EXCEPTION, "send message error");
}
}
@ -201,10 +207,10 @@ public class SpringWebSocketHandler implements WebSocketHandler {
}
private DataObject createAuthData() {
UserAgent userAgent = UserAgentUtil.parse(this.account.getUserAgent());
UserAgent userAgent = UserAgentUtil.parse(this.midjourneyConfig.getUserAage());
DataObject connectionProperties = DataObject.empty()
.put("browser", userAgent.getBrowser().getName())
.put("browser_user_agent", this.account.getUserAgent())
.put("browser_user_agent", this.midjourneyConfig.getUserAage())
.put("browser_version", userAgent.getVersion())
.put("client_build_number", 222963)
.put("client_event_source", null)
@ -235,6 +241,6 @@ public class SpringWebSocketHandler implements WebSocketHandler {
.put("compress", false)
.put("presence", presence)
.put("properties", connectionProperties)
.put("token", this.account.getUserToken());
.put("token", this.midjourneyConfig.getToken());
}
}

View File

@ -0,0 +1,44 @@
package cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener;
import cn.hutool.core.text.CharSequenceUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjConstants;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjMessageTypeEnum;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject;
@Slf4j
public class MjMessageListener {
private MidjourneyConfig midjourneyConfig;
public MjMessageListener(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
}
public void onMessage(DataObject raw) {
MjMessageTypeEnum messageType = MjMessageTypeEnum.of(raw.getString("t"));
if (messageType == null || MjMessageTypeEnum.DELETE == messageType) {
return;
}
DataObject data = raw.getObject("d");
if (ignoreAndLogMessage(data, messageType)) {
return;
}
System.err.println(data);
// if (data.getBoolean(Constants.MJ_MESSAGE_HANDLED, false)) {
// return;
// }
}
private boolean ignoreAndLogMessage(DataObject data, MjMessageTypeEnum messageType) {
String channelId = data.getString(MjConstants.CHANNEL_ID);
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
return true;
}
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
return false;
}
}