【代码优化】AI:音乐生成

This commit is contained in:
YunaiV 2024-06-27 12:48:15 +08:00
parent c77fc954d2
commit 23baaff84d
6 changed files with 62 additions and 62 deletions

View File

@ -12,9 +12,8 @@ import lombok.Getter;
@Getter @Getter
public enum AiMusicStatusEnum { public enum AiMusicStatusEnum {
// @xin 文档中无失败这个返回值 IN_PROGRESS(10, "进行中"),
STREAMING(10, "进行中"), SUCCESS(20, "已完成");
COMPLETE(20, "完成");
/** /**
* 状态 * 状态

View File

@ -7,11 +7,20 @@ import lombok.Data;
import java.util.List; import java.util.List;
@Schema(description = "管理后台 - 音乐生成 Request VO") @Schema(description = "管理后台 - AI 音乐生成 Request VO")
@Data @Data
public class AiSunoGenerateReqVO { public class AiSunoGenerateReqVO {
@Schema(description = "用于生成音乐音频的提示", requiredMode = Schema.RequiredMode.REQUIRED, example = "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。") @Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "Suno")
@NotBlank(message = "平台不能为空")
private String platform; // 参见 AiPlatformEnum 枚举
@Schema(description = "生成模式", requiredMode = Schema.RequiredMode.REQUIRED, example = "2")
@NotNull(message = "生成模式不能为空")
private Integer generateMode; // 参见 AiMusicGenerateEnum 枚举
@Schema(description = "用于生成音乐音频的提示", requiredMode = Schema.RequiredMode.REQUIRED,
example = "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。")
private String prompt; private String prompt;
@Schema(description = "是否纯音乐", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "true") @Schema(description = "是否纯音乐", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "true")
@ -26,12 +35,4 @@ public class AiSunoGenerateReqVO {
@Schema(description = "音乐/歌曲名称", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "夜空中最亮的星") @Schema(description = "音乐/歌曲名称", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "夜空中最亮的星")
private String title; private String title;
@Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "Suno")
@NotBlank(message = "平台不能为空")
private String platform; // 参见 AiPlatformEnum 枚举
@Schema(description = "生成模式", requiredMode = Schema.RequiredMode.REQUIRED, example = "2")
@NotNull(message = "生成模式不能为空")
private Integer generateMode; // 参见 AiMusicGenerateEnum 枚举
} }

View File

@ -1,6 +1,8 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.music; package cn.iocoder.yudao.module.ai.dal.dataobject.music;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum; import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableField;
@ -38,21 +40,19 @@ public class AiMusicDO extends BaseDO {
*/ */
private String title; private String title;
/**
* 图片地址
*/
private String imageUrl;
/** /**
* 歌词 * 歌词
*/ */
private String lyric; private String lyric;
/**
* 图片地址
*/
private String imageUrl;
/** /**
* 音频地址 * 音频地址
*/ */
private String audioUrl; private String audioUrl;
/** /**
* 视频地址 * 视频地址
*/ */
@ -65,6 +65,13 @@ public class AiMusicDO extends BaseDO {
*/ */
private Integer status; private Integer status;
/**
* 生成模式
*
* 枚举 {@link AiMusicGenerateModeEnum}
*/
private Integer generateMode;
/** /**
* 描述词 * 描述词
*/ */
@ -74,28 +81,17 @@ public class AiMusicDO extends BaseDO {
*/ */
private String prompt; private String prompt;
/**
* 生成模式
*/
private Integer generateMode;
/** /**
* 平台 * 平台
* <p> * <p>
* 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum} * 枚举 {@link AiPlatformEnum}
*/ */
private String platform; private String platform;
/** /**
* 模型 * 模型
*/ */
private String model; private String model;
/**
* 错误信息
*/
private String errorMessage;
/** /**
* 音乐风格标签 * 音乐风格标签
*/ */
@ -107,4 +103,9 @@ public class AiMusicDO extends BaseDO {
*/ */
private String taskId; private String taskId;
/**
* 错误信息
*/
private String errorMessage;
} }

View File

@ -1,7 +1,6 @@
package cn.iocoder.yudao.module.ai.dal.mysql.music; package cn.iocoder.yudao.module.ai.dal.mysql.music;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO; import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@ -16,8 +15,7 @@ import java.util.List;
public interface AiMusicMapper extends BaseMapperX<AiMusicDO> { public interface AiMusicMapper extends BaseMapperX<AiMusicDO> {
default List<AiMusicDO> selectListByStatus(Integer status) { default List<AiMusicDO> selectListByStatus(Integer status) {
return selectList(new LambdaQueryWrapperX<AiMusicDO>() return selectList(AiMusicDO::getStatus, status);
.eq(AiMusicDO::getStatus, status));
} }
} }

View File

@ -14,10 +14,11 @@ public interface AiMusicService {
/** /**
* 音乐生成 * 音乐生成
* *
* @param userId 用户编号
* @param reqVO 请求参数 * @param reqVO 请求参数
* @return 生成的音乐ID * @return 生成的音乐ID
*/ */
List<Long> generateMusic(AiSunoGenerateReqVO reqVO); List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO);
/** /**
* 同步音乐任务 * 同步音乐任务
@ -25,4 +26,5 @@ public interface AiMusicService {
* @return 同步数量 * @return 同步数量
*/ */
Integer syncMusic(); Integer syncMusic();
} }

View File

@ -4,7 +4,6 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.StrPool; import cn.hutool.core.text.StrPool;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiSunoGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiSunoGenerateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO; import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper; import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
@ -16,7 +15,8 @@ import org.springframework.stereotype.Service;
import java.util.*; import java.util.*;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
/** /**
* AI 音乐 Service 实现类 * AI 音乐 Service 实现类
@ -34,54 +34,53 @@ public class AiMusicServiceImpl implements AiMusicService {
private AiMusicMapper musicMapper; private AiMusicMapper musicMapper;
@Override @Override
public List<Long> generateMusic(AiSunoGenerateReqVO reqVO) { public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
// 1. 调用 Suno 生成音乐
List<SunoApi.MusicData> musicDataList; List<SunoApi.MusicData> musicDataList;
if (Objects.equals(AiMusicGenerateModeEnum.LYRIC.getMode(), reqVO.getGenerateMode())) { if (Objects.equals(AiMusicGenerateModeEnum.LYRIC.getMode(), reqVO.getGenerateMode())) {
// 1.1 歌词模式 // 1.1 歌词模式
SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest( SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
reqVO.getPrompt(), reqVO.getModelVersion(), CollUtil.join(reqVO.getTags(), StrPool.COMMA), reqVO.getTitle()); reqVO.getPrompt(), reqVO.getModelVersion(), CollUtil.join(reqVO.getTags(), StrPool.COMMA), reqVO.getTitle());
musicDataList = sunoApi.customGenerate(sunoReq); musicDataList = sunoApi.customGenerate(generateRequest);
} else if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) { } else if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
// 1.2 描述模式 // 1.2 描述模式
SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest( SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
reqVO.getPrompt(), reqVO.getModelVersion(), reqVO.getMakeInstrumental()); reqVO.getPrompt(), reqVO.getModelVersion(), reqVO.getMakeInstrumental());
musicDataList = sunoApi.generate(sunoReq); musicDataList = sunoApi.generate(generateRequest);
} else { } else {
throw new IllegalArgumentException(StrUtil.format("未知生成模式({})", reqVO)); throw new IllegalArgumentException(StrUtil.format("未知生成模式({})", reqVO));
} }
// 2. 插入数据库 // 2. 插入数据库
if (CollUtil.isEmpty(musicDataList)) { if (CollUtil.isEmpty(musicDataList)) {
return Collections.emptyList(); return Collections.emptyList();
} }
List<AiMusicDO> aiMusicDOList = CollectionUtils.convertList(buildMusicDOList(musicDataList), musicDO -> List<AiMusicDO> musicList = buildMusicDOList(musicDataList);
musicDO.setUserId(getLoginUserId()) musicList.forEach(music -> music.setUserId(userId).setPlatform(music.getPlatform()).setGenerateMode(reqVO.getGenerateMode()));
.setGenerateMode(reqVO.getGenerateMode()) musicMapper.insertBatch(musicList);
.setPlatform(reqVO.getPlatform() return convertList(musicList, AiMusicDO::getId);
));
musicMapper.insertBatch(aiMusicDOList);
return CollectionUtils.convertList(aiMusicDOList, AiMusicDO::getId);
} }
@Override @Override
public Integer syncMusic() { public Integer syncMusic() {
List<AiMusicDO> streamingTask = musicMapper.selectListByStatus(AiMusicStatusEnum.STREAMING.getStatus()); List<AiMusicDO> streamingTask = musicMapper.selectListByStatus(AiMusicStatusEnum.IN_PROGRESS.getStatus());
if (CollUtil.isEmpty(streamingTask)) { if (CollUtil.isEmpty(streamingTask)) {
return 0; return 0;
} }
log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size()); log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size());
// GET 请求为避免参数过长分批次处理 // GET 请求为避免参数过长分批次处理
CollUtil.split(streamingTask, 36).forEach(chunk -> { CollUtil.split(streamingTask, 36).forEach(chunkList -> {
Map<String, Long> taskIdMap = CollectionUtils.convertMap(chunk, AiMusicDO::getTaskId, AiMusicDO::getId); Map<String, Long> taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId);
List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet())); List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));
if (CollUtil.isEmpty(musicTaskList)) { if (CollUtil.isEmpty(musicTaskList)) {
log.warn("Suno 任务同步失败, 任务ID: [{}]", taskIdMap.keySet()); log.warn("Suno 任务同步失败, 任务ID: [{}]", taskIdMap.keySet());
return; return;
} }
List<AiMusicDO> aiMusicDOS = buildMusicDOList(musicTaskList); // 更新进度
//回填id List<AiMusicDO> updateMusicList = buildMusicDOList(musicTaskList);
aiMusicDOS.forEach(aiMusicDO -> aiMusicDO.setId(taskIdMap.get(aiMusicDO.getTaskId()))); updateMusicList.forEach(music -> music.setId(taskIdMap.get(music.getTaskId())));
musicMapper.updateBatch(aiMusicDOS); musicMapper.updateBatch(updateMusicList);
}); });
return streamingTask.size(); return streamingTask.size();
} }
@ -89,16 +88,16 @@ public class AiMusicServiceImpl implements AiMusicService {
/** /**
* 构建 AiMusicDO 集合 * 构建 AiMusicDO 集合
* *
* @param musicTaskList suno 音乐任务列表 * @param musicList suno 音乐任务列表
* @return AiMusicDO 集合 * @return AiMusicDO 集合
*/ */
private static List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicTaskList) { private static List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicList) {
return CollectionUtils.convertList(musicTaskList, musicData -> new AiMusicDO() return convertList(musicList, musicData -> new AiMusicDO()
.setTaskId(musicData.id()) .setTaskId(musicData.id()).setModel(musicData.modelName())
.setPrompt(musicData.prompt()).setGptDescriptionPrompt(musicData.gptDescriptionPrompt()) .setPrompt(musicData.prompt()).setGptDescriptionPrompt(musicData.gptDescriptionPrompt())
.setAudioUrl(musicData.audioUrl()).setVideoUrl(musicData.videoUrl()).setImageUrl(musicData.imageUrl()) .setAudioUrl(musicData.audioUrl()).setVideoUrl(musicData.videoUrl()).setImageUrl(musicData.imageUrl())
.setTitle(musicData.title()).setLyric(musicData.lyric()).setTags(StrUtil.split(musicData.tags(), StrPool.COMMA)) .setTitle(musicData.title()).setLyric(musicData.lyric()).setTags(StrUtil.split(musicData.tags(), StrPool.COMMA))
.setModel(musicData.modelName()).setStatus(Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.COMPLETE.getStatus() : AiMusicStatusEnum.STREAMING.getStatus())); .setStatus(Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.SUCCESS.getStatus() : AiMusicStatusEnum.IN_PROGRESS.getStatus()));
} }
} }