【代码评审】AI:音乐接入

This commit is contained in:
YunaiV 2024-06-25 21:45:26 +08:00
parent 5c73e5e1f4
commit ec1376f4cb
9 changed files with 60 additions and 44 deletions

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.enums.music;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
// TODO @xiaoxin这个也叫 AiMusicGenerateModeEnum 虽然长但是和项目统一一点
/** /**
* AI 音乐状态的枚举 * AI 音乐状态的枚举
* *

View File

@ -12,7 +12,7 @@ import lombok.Getter;
@Getter @Getter
public enum AiMusicStatusEnum { public enum AiMusicStatusEnum {
// @xin 文档中无失败这个返回值 // @xin 文档中无失败这个返回值 TODO @xin Integer 另外个枚举类也是
STREAMING("10", "进行中"), STREAMING("10", "进行中"),
COMPLETE("20", "完成"); COMPLETE("20", "完成");

View File

@ -6,13 +6,11 @@ import lombok.Data;
import java.util.List; import java.util.List;
/**
* @author xiaoxin
*/
@Schema(description = "管理后台 - 音乐生成 Request VO") @Schema(description = "管理后台 - 音乐生成 Request VO")
@Data @Data
public class AiSunoGenerateReqVO { public class AiSunoGenerateReqVO {
// TODO @xin每个参数必要的是否必填校验
@Schema(description = "用于生成音乐音频的提示", example = "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。") @Schema(description = "用于生成音乐音频的提示", example = "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。")
private String prompt; private String prompt;
@ -20,7 +18,7 @@ public class AiSunoGenerateReqVO {
private Boolean makeInstrumental; private Boolean makeInstrumental;
@Schema(description = "模型版本, 默认 chirp-v3.5", example = "chirp-v3.5") @Schema(description = "模型版本, 默认 chirp-v3.5", example = "chirp-v3.5")
private String modelVersion;// 参见 AiModelEnum 枚举 private String modelVersion; // 参见 AiModelEnum 枚举
@Schema(description = "音乐风格", example = "[\"pop\",\"jazz\",\"punk\"]") @Schema(description = "音乐风格", example = "[\"pop\",\"jazz\",\"punk\"]")
private List<String> tags; private List<String> tags;
@ -30,10 +28,10 @@ public class AiSunoGenerateReqVO {
@Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "Suno") @Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "Suno")
@NotBlank(message = "平台不能为空") @NotBlank(message = "平台不能为空")
private String platform;// 参见 AiPlatformEnum 枚举 private String platform; // 参见 AiPlatformEnum 枚举
@Schema(description = "生成模式 1(歌词模式), 2(描述模式)", requiredMode = Schema.RequiredMode.REQUIRED, example = "2") @Schema(description = "生成模式", requiredMode = Schema.RequiredMode.REQUIRED, example = "2")
@NotBlank(message = "生成模式不能为空") @NotBlank(message = "生成模式不能为空")
private String generateMode;// 参见 AiMusicGenerateEnum 枚举 private String generateMode; // 参见 AiMusicGenerateEnum 枚举
} }

View File

@ -68,7 +68,6 @@ public class AiImageDO extends BaseDO {
*/ */
private Integer height; private Integer height;
// TODO @fan这种就注释绘画状态然后枚举类关联下就好啦
/** /**
* 生成状态 * 生成状态
* *
@ -76,6 +75,11 @@ public class AiImageDO extends BaseDO {
*/ */
private Integer status; private Integer status;
/**
* 绘画错误信息
*/
private String errorMessage;
/** /**
* 图片地址 * 图片地址
*/ */
@ -101,15 +105,12 @@ public class AiImageDO extends BaseDO {
private List<MidjourneyApi.Button> buttons; private List<MidjourneyApi.Button> buttons;
/** /**
* midjourney proxy 关联的 task id * 任务编号
*
* 1. midjourney proxy关联的 task id
*/ */
private String taskId; private String taskId;
/**
* 绘画错误信息
*/
private String errorMessage;
public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> { public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> {
@Override @Override

View File

@ -11,7 +11,6 @@ import lombok.Data;
import java.util.List; import java.util.List;
/** /**
* AI 音乐 DO * AI 音乐 DO
* *
@ -29,6 +28,8 @@ public class AiMusicDO extends BaseDO {
/** /**
* 用户编号 * 用户编号
*
* 关联 AdminUserDO userId 字段
*/ */
private Long userId; private Long userId;
@ -105,4 +106,5 @@ public class AiMusicDO extends BaseDO {
* 任务编号 * 任务编号
*/ */
private String taskId; private String taskId;
} }

View File

@ -13,7 +13,7 @@ import org.springframework.stereotype.Component;
*/ */
@Component @Component
@Slf4j @Slf4j
public class MidjourneySyncJob implements JobHandler { public class AiMidjourneySyncJob implements JobHandler {
@Resource @Resource
private AiImageService imageService; private AiImageService imageService;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.job; package cn.iocoder.yudao.module.ai.job.sun;
import cn.iocoder.yudao.framework.quartz.core.handler.JobHandler; import cn.iocoder.yudao.framework.quartz.core.handler.JobHandler;
import cn.iocoder.yudao.module.ai.service.music.AiMusicService; import cn.iocoder.yudao.module.ai.service.music.AiMusicService;
@ -14,15 +14,16 @@ import org.springframework.stereotype.Component;
*/ */
@Component @Component
@Slf4j @Slf4j
public class SunoJob implements JobHandler { public class AiSunoSyncJob implements JobHandler {
@Resource @Resource
private AiMusicService musicService; private AiMusicService musicService;
@Override @Override
public String execute(String param) { public String execute(String param) {
Integer count = musicService.syncMusicTask(); Integer count = musicService.syncMusic();
log.info("[execute][Suno 同步任务数量 [{}] 个]", count); log.info("[execute][同步 Suno ({}) 个]", count);
return String.format("Suno 同步 - [%s]任务", count); return String.format("同步 Suno %s 个", count);
} }
} }

View File

@ -27,13 +27,12 @@ public interface AiMusicService {
*/ */
List<AiMusicDO> getUnCompletedTask(); List<AiMusicDO> getUnCompletedTask();
/** /**
* 同步音乐任务 * 同步音乐任务
* *
* @return 同步数量 * @return 同步数量
*/ */
Integer syncMusicTask(); Integer syncMusic();
/** /**
* 批量更新音乐信息 * 批量更新音乐信息

View File

@ -39,51 +39,61 @@ public class AiMusicServiceImpl implements AiMusicService {
public List<Long> generateMusic(AiSunoGenerateReqVO reqVO) { public List<Long> generateMusic(AiSunoGenerateReqVO reqVO) {
List<SunoApi.MusicData> musicDataList; List<SunoApi.MusicData> musicDataList;
if (Objects.equals(AiMusicGenerateEnum.LYRIC.getMode(), reqVO.getGenerateMode())) { if (Objects.equals(AiMusicGenerateEnum.LYRIC.getMode(), reqVO.getGenerateMode())) {
//歌词模式 // 1.1 歌词模式
SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest(reqVO.getPrompt(), reqVO.getModelVersion(), CollUtil.join(reqVO.getTags(), StrPool.COMMA), reqVO.getTitle()); SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest(
reqVO.getPrompt(), reqVO.getModelVersion(), CollUtil.join(reqVO.getTags(), StrPool.COMMA), reqVO.getTitle());
musicDataList = sunoApi.customGenerate(sunoReq); musicDataList = sunoApi.customGenerate(sunoReq);
} else if (Objects.equals(AiMusicGenerateEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) { } else if (Objects.equals(AiMusicGenerateEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
//描述模式 // 1.2 描述模式
SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest(reqVO.getPrompt(), reqVO.getModelVersion(), reqVO.getMakeInstrumental()); SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest(
reqVO.getPrompt(), reqVO.getModelVersion(), reqVO.getMakeInstrumental());
musicDataList = sunoApi.generate(sunoReq); musicDataList = sunoApi.generate(sunoReq);
} else { } else {
// TODO @xin不用 log error直接抛异常 reqVO 呆进去有全局处理的哈
log.error("未知的生成模式:{}", reqVO.getGenerateMode()); log.error("未知的生成模式:{}", reqVO.getGenerateMode());
throw new IllegalArgumentException("未知的生成模式"); throw new IllegalArgumentException("未知的生成模式");
} }
// 2. 插入数据库 // 2. 插入数据库
// TODO @xin因为 insertMusicData 复用的比较少所以不用愁单独的方法直接写在这里就好啦
return insertMusicData(musicDataList, reqVO.getGenerateMode(), reqVO.getPlatform()); return insertMusicData(musicDataList, reqVO.getGenerateMode(), reqVO.getPlatform());
} }
// TODO @xin1service 里面不要直接查询 db2不要用 ne STREAMING
@Override @Override
public List<AiMusicDO> getUnCompletedTask() { public List<AiMusicDO> getUnCompletedTask() {
return musicMapper.selectList(new LambdaQueryWrapper<AiMusicDO>().ne(AiMusicDO::getStatus, AiMusicStatusEnum.COMPLETE.getStatus())); return musicMapper.selectList(new LambdaQueryWrapper<AiMusicDO>().ne(AiMusicDO::getStatus, AiMusicStatusEnum.COMPLETE.getStatus()));
} }
@Override @Override
public Integer syncMusicTask() { public Integer syncMusic() {
List<AiMusicDO> unCompletedTask = this.getUnCompletedTask(); List<AiMusicDO> unCompletedTask = this.getUnCompletedTask();
if (CollUtil.isEmpty(unCompletedTask)) { if (CollUtil.isEmpty(unCompletedTask)) {
// TODO @xin这里不用打反正 Job 也打了
log.info("Suno 无进行中任务需要更新!"); log.info("Suno 无进行中任务需要更新!");
return 0; return 0;
} }
log.info("Suno 开始同步, 共 [{}] 个任务!", unCompletedTask.size()); log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", unCompletedTask.size());
//GET 请求为避免参数过长分批次处理 // GET 请求为避免参数过长分批次处理
CollUtil.split(unCompletedTask, 4) // TODO @xin建议批量更大一些
.forEach(chunk -> { CollUtil.split(unCompletedTask, 4).forEach(chunk -> {
Map<String, Long> taskIdMap = CollUtil.toMap(chunk, new HashMap<>(), AiMusicDO::getTaskId, AiMusicDO::getId); // TODO @xin可以使用 CollectionUtils 里的 map 转换
List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet())); Map<String, Long> taskIdMap = CollUtil.toMap(chunk, new HashMap<>(), AiMusicDO::getTaskId, AiMusicDO::getId);
if (CollUtil.isNotEmpty(musicTaskList)) { List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));
List<AiMusicDO> aiMusicDOS = buildMusicDOList(musicTaskList); // TODO @xin查询不到直接 return这样真正逻辑的 85 - 87 就不用多一层括号
//回填id if (CollUtil.isNotEmpty(musicTaskList)) {
aiMusicDOS.forEach(aiMusicDO -> aiMusicDO.setId(taskIdMap.get(aiMusicDO.getTaskId()))); List<AiMusicDO> aiMusicDOS = buildMusicDOList(musicTaskList);
this.updateBatch(aiMusicDOS); //回填id
} else { aiMusicDOS.forEach(aiMusicDO -> aiMusicDO.setId(taskIdMap.get(aiMusicDO.getTaskId())));
log.warn("Suno 任务同步失败, 任务ID: [{}]", taskIdMap.keySet()); this.updateBatch(aiMusicDOS);
} } else {
}); log.warn("Suno 任务同步失败, 任务ID: [{}]", taskIdMap.keySet());
}
});
return unCompletedTask.size(); return unCompletedTask.size();
} }
// TODO @xin这个方法看着不用啦
@Override @Override
public Boolean updateBatch(List<AiMusicDO> musicDOS) { public Boolean updateBatch(List<AiMusicDO> musicDOS) {
return musicMapper.updateBatch(musicDOS); return musicMapper.updateBatch(musicDOS);
@ -105,6 +115,7 @@ public class AiMusicServiceImpl implements AiMusicService {
.setPlatform(platform)) .setPlatform(platform))
.toList(); .toList();
musicMapper.insertBatch(aiMusicDOList); musicMapper.insertBatch(aiMusicDOList);
// TODO @xin CollectionUtils 简化操作
return aiMusicDOList.stream() return aiMusicDOList.stream()
.map(AiMusicDO::getId) .map(AiMusicDO::getId)
.collect(Collectors.toList()); .collect(Collectors.toList());
@ -117,6 +128,7 @@ public class AiMusicServiceImpl implements AiMusicService {
* @return AiMusicDO 集合 * @return AiMusicDO 集合
*/ */
private static List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicTaskList) { private static List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicTaskList) {
// TODO @xin想通的变量放在同一行避免过长
return CollectionUtils.convertList(musicTaskList, musicData -> new AiMusicDO() return CollectionUtils.convertList(musicTaskList, musicData -> new AiMusicDO()
.setTaskId(musicData.id()) .setTaskId(musicData.id())
.setPrompt(musicData.prompt()) .setPrompt(musicData.prompt())
@ -128,6 +140,8 @@ public class AiMusicServiceImpl implements AiMusicService {
.setTitle(musicData.title()) .setTitle(musicData.title())
.setStatus(Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.COMPLETE.getStatus() : AiMusicStatusEnum.STREAMING.getStatus()) .setStatus(Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.COMPLETE.getStatus() : AiMusicStatusEnum.STREAMING.getStatus())
.setModel(musicData.modelName()) .setModel(musicData.modelName())
// TODO @xin可以用 hutool StrUtil split 之类的
.setTags(StrUtil.isNotBlank(musicData.tags()) ? List.of(musicData.tags().split(StrPool.COMMA)) : null)); .setTags(StrUtil.isNotBlank(musicData.tags()) ? List.of(musicData.tags().split(StrPool.COMMA)) : null));
} }
} }