增加mj 图片生成消息转换

This commit is contained in:
cherishsince 2024-04-06 21:20:10 +08:00
parent b09fc5579c
commit f2b9c14819
3 changed files with 219 additions and 27 deletions

View File

@ -0,0 +1,124 @@
package cn.iocoder.yudao.framework.ai.midjourney;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
@Data
@Accessors(chain = true)
public class MjMessage {
/**
* id是一个重要的字段在同时生成多个的时候可以区分生成信息
*/
private String id;
/**
* 现在已知
* 0我们发送的消息和指令
* 20: mj生成图片发送过程中
*/
private Integer type;
/**
* content
*/
private Content content;
/**
* 图片生成完成才有
*/
private List<ComponentType> components;
/**
* 生成过程中如果有预展示图片这里会有
*/
private List<Attachment> attachments;
/**
* 原始数据(discard 返回的原始数据)
*/
private String rawData;
/**
* 生成状态(用于区分生成状态)
* 1等待
* 2进行中
* 3完成
* {@link cn.iocoder.yudao.framework.ai.midjourney.constants.MjGennerateStatusEnum}
*/
private String generateStatus;
@Data
@Accessors(chain = true)
public static class ComponentType {
private int type;
private List<Component> components;
}
@Data
@Accessors(chain = true)
public static class Component {
/**
* 自定义ID用于唯一标识特定交互动作及其上下文信息
*/
private String customId;
/**
* 样式编号用于确定按钮的样式外观
* 在某些应用中例如Discord2可能表示一种特定的颜色或形状的按钮
*/
private int style;
/**
* 按钮的标签文本用户可见的内容
*/
private String label;
/**
* 组件类型此处为2可能表示这是一种特定类型的交互组件
* 如在Discord API中类型2对应的是一个可点击的按钮组件
*/
private int type;
}
@Data
@Accessors(chain = true)
public static class Attachment {
// 文件名
private String filename;
// 附件大小字节
private int size;
// 内容类型例如image/webp
private String contentType;
// 图像宽度像素
private int width;
// 占位符版本号
private int placeholderVersion;
// 代理URL用于访问附件资源
private String proxyUrl;
// 占位符标识符
private String placeholder;
// 附件ID
private String id;
// 直接访问附件资源的URL
private String url;
// 图像高度像素
private int height;
}
@Data
@Accessors(chain = true)
public static class Content {
private String prompt;
private String progress;
private String status;
}
}

View File

@ -0,0 +1,29 @@
package cn.iocoder.yudao.framework.ai.midjourney.constants;
import lombok.Getter;
/**
* mj 生成状态
*
* author: fansili
* time: 2024/4/6 21:07
*/
@Getter
public enum MjGennerateStatusEnum {
WAITING("waiting", "等待..."),
IN_PROGRESS("in_progress", "进行中"),
COMPLETED("completed", "完成"),
;
MjGennerateStatusEnum(String value, String message) {
this.value = value;
this.message = message;
}
private String value;
private String message;
}

View File

@ -1,44 +1,83 @@
package cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener; package cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.CharSequenceUtil; import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
import cn.iocoder.yudao.framework.ai.midjourney.MjMessage;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjConstants; import cn.iocoder.yudao.framework.ai.midjourney.constants.MjConstants;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjGennerateStatusEnum;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjMessageTypeEnum; import cn.iocoder.yudao.framework.ai.midjourney.constants.MjMessageTypeEnum;
import cn.iocoder.yudao.framework.ai.midjourney.util.MjUtil;
import com.alibaba.fastjson.JSON;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject; import net.dv8tion.jda.api.utils.data.DataObject;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.List;
@Slf4j @Slf4j
public class MjMessageListener { public class MjMessageListener {
private MidjourneyConfig midjourneyConfig; private MidjourneyConfig midjourneyConfig;
public MjMessageListener(MidjourneyConfig midjourneyConfig) { public MjMessageListener(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig; this.midjourneyConfig = midjourneyConfig;
} }
public void onMessage(DataObject raw) { public void onMessage(DataObject raw) {
MjMessageTypeEnum messageType = MjMessageTypeEnum.of(raw.getString("t")); MjMessageTypeEnum messageType = MjMessageTypeEnum.of(raw.getString("t"));
if (messageType == null || MjMessageTypeEnum.DELETE == messageType) { if (messageType == null || MjMessageTypeEnum.DELETE == messageType) {
return; return;
} }
DataObject data = raw.getObject("d"); DataObject data = raw.getObject("d");
if (ignoreAndLogMessage(data, messageType)) { if (ignoreAndLogMessage(data, messageType)) {
return; return;
} }
System.err.println(data);
// if (data.getBoolean(Constants.MJ_MESSAGE_HANDLED, false)) {
// return;
// }
}
private boolean ignoreAndLogMessage(DataObject data, MjMessageTypeEnum messageType) { MjMessage mjMessage = new MjMessage();
String channelId = data.getString(MjConstants.CHANNEL_ID); mjMessage.setId(data.getString("id"));
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) { mjMessage.setType(data.getInt("type"));
return true; mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
} mjMessage.setContent(MjUtil.parseContent(data.getString("content")));
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse("")); if (!data.getArray("components").isEmpty()) {
return false; String componentsJson = StrUtil.str(data.getArray("components").toJson(), "UTF-8");
} List<MjMessage.ComponentType> components = JSON.parseArray(componentsJson, MjMessage.ComponentType.class);
mjMessage.setComponents(components);
}
if (!data.getArray("attachments").isEmpty()) {
String attachmentsJson = StrUtil.str(data.getArray("attachments").toJson(), "UTF-8");
List<MjMessage.Attachment> attachments = JSON.parseArray(attachmentsJson, MjMessage.Attachment.class);
mjMessage.setAttachments(attachments);
}
// 转换状态
convertGenerateStatus(mjMessage);
System.err.println(JSONUtil.toJsonPrettyStr(mjMessage));
}
private void convertGenerateStatus(MjMessage mjMessage) {
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
mjMessage.setGenerateStatus(MjGennerateStatusEnum.WAITING.getValue());
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
mjMessage.setGenerateStatus(MjGennerateStatusEnum.IN_PROGRESS.getValue());
} else if (mjMessage.getType() == 0 && !CollUtil.isEmpty(mjMessage.getComponents())) {
mjMessage.setGenerateStatus(MjGennerateStatusEnum.COMPLETED.getValue());
}
}
private boolean ignoreAndLogMessage(DataObject data, MjMessageTypeEnum messageType) {
String channelId = data.getString(MjConstants.CHANNEL_ID);
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
return true;
}
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
return false;
}
} }