Merge remote-tracking branch 'origin/master-jdk21-ai' into master-jdk21-ai

This commit is contained in:
cherishsince 2024-06-17 21:38:32 +08:00
commit 4c89342d5b
11 changed files with 67 additions and 48 deletions

View File

@ -3,6 +3,8 @@ package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
// TODO @xin这个类挪到 enums/music 包下
// TODO @xin1@author 这个是标准的 javadoc2@date 可以不要哈3可以加下枚举类的注释
/** /**
* @Author xiaoxin * @Author xiaoxin
* @Date 2024/6/5 * @Date 2024/6/5
@ -11,6 +13,8 @@ import lombok.Getter;
@Getter @Getter
public enum AiMusicStatusEnum { public enum AiMusicStatusEnum {
// TODO @xin是不是收敛成只有 3 进行中成功失败类似 AiImageStatusEnum
SUBMITTED("submitted", "已提交"), SUBMITTED("submitted", "已提交"),
QUEUED("queued", "排队中"), QUEUED("queued", "排队中"),
STREAMING("streaming", "进行中"), STREAMING("streaming", "进行中"),

View File

@ -50,6 +50,7 @@ public enum AiModelEnum {
XING_HUO_3_0("星火大模型3.0", "generalv3", "/v3.1/chat"), XING_HUO_3_0("星火大模型3.0", "generalv3", "/v3.1/chat"),
XING_HUO_3_5("星火大模型3.5", "generalv3.5", "/v3.5/chat"), XING_HUO_3_5("星火大模型3.5", "generalv3.5", "/v3.5/chat"),
// TODO @xin// Suno中间加个空格会更清晰一点一般来说不同类型的单词之间最好有空格例如说// 新增一个再例如说// 这是 1 create 逻辑
//Suno //Suno
SUNO_2( "SUNO-2", "chirp-v2-xxl-alpha",null), SUNO_2( "SUNO-2", "chirp-v2-xxl-alpha",null),
SUNO_3_0( "SUNO-3.0", "chirp-v3-0",null), SUNO_3_0( "SUNO-3.0", "chirp-v3-0",null),

View File

@ -17,6 +17,7 @@ import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
// TODO @xinAI 前缀都要加下哈
@Tag(name = "管理后台 - AI 音乐生成") @Tag(name = "管理后台 - AI 音乐生成")
@RestController @RestController
@RequestMapping("/ai/music") @RequestMapping("/ai/music")

View File

@ -4,12 +4,13 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.Data; import lombok.Data;
@Data @Data
@JsonInclude(value = JsonInclude.Include.NON_NULL) @JsonInclude(value = JsonInclude.Include.NON_NULL) // TODO @xin不用加这个哈
public class SunoReqVO { public class SunoReqVO {
/** /**
* 用于生成音乐音频的提示 * 用于生成音乐音频的提示
*/ */
private String prompt; private String prompt;
// TODO @xinBoolean不使用基本类型
/** /**
* 是否纯音乐 * 是否纯音乐
*/ */

View File

@ -18,6 +18,8 @@ import java.util.stream.Collectors;
@TableName("ai_music") @TableName("ai_music")
@Data @Data
public class AiMusicDO extends BaseDO { public class AiMusicDO extends BaseDO {
// TODO @xin@Schema 只在 VO 里使用这里还是使用标准的注释哈
@TableId(type = IdType.AUTO) @TableId(type = IdType.AUTO)
@Schema(description = "编号") @Schema(description = "编号")
private Long id; private Long id;
@ -40,6 +42,7 @@ public class AiMusicDO extends BaseDO {
@Schema(description = "视频地址") @Schema(description = "视频地址")
private String videoUrl; private String videoUrl;
// TODO @xin需要关联下对应的枚举
@Schema(description = "音乐状态") @Schema(description = "音乐状态")
private String status; private String status;
@ -49,19 +52,24 @@ public class AiMusicDO extends BaseDO {
@Schema(description = "提示词") @Schema(description = "提示词")
private String prompt; private String prompt;
// TODO @xin生成模式需要记录下歌词描述
// TODO @xin多存储一个平台platform考虑未来可能有别的音乐接口
@Schema(description = "模型") @Schema(description = "模型")
private String model; private String model;
@Schema(description = "错误信息") @Schema(description = "错误信息")
private String errorMessage; private String errorMessage;
// TODO @xintags 要不要使用 List<String>
@Schema(description = "音乐风格标签") @Schema(description = "音乐风格标签")
private String tags; private String tags;
@Schema(description = "任务id") @Schema(description = "任务编号")
private String taskId; private String taskId;
// TODO @xin转换不放在 DO 里面哈
public static AiMusicDO convertFrom(SunoApi.MusicData musicData) { public static AiMusicDO convertFrom(SunoApi.MusicData musicData) {
return new AiMusicDO() return new AiMusicDO()
@ -84,5 +92,4 @@ public class AiMusicDO extends BaseDO {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
} }

View File

@ -5,10 +5,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
/** /**
* @Author xiaoxin * AI 音乐 Mapper
* @Date 2024/6/5 * @author xiaoxin
*/ */
@Mapper @Mapper
public interface AiMusicMapper extends BaseMapperX<AiMusicDO> { public interface AiMusicMapper extends BaseMapperX<AiMusicDO> {
} }

View File

@ -31,26 +31,29 @@ import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUti
@Slf4j @Slf4j
public class MusicServiceImpl implements MusicService { public class MusicServiceImpl implements MusicService {
// TODO @xin使用 @Resource 注入整个项目保持统一哈
private final SunoApi sunoApi; private final SunoApi sunoApi;
private final AiMusicMapper musicMapper; private final AiMusicMapper musicMapper;
private final Queue<String> taskQueue = new ConcurrentLinkedQueue<>(); private final Queue<String> taskQueue = new ConcurrentLinkedQueue<>();
// TODO @xin要不把 descriptionModelyricMode 合并同一个 generateMusic 方法然后根据传入的 mode 模式歌词描述来区分
@Override @Override
public List<Long> descriptionMode(SunoReqVO reqVO) { public List<Long> descriptionMode(SunoReqVO reqVO) {
SunoApi.SunoReq sunoReq = new SunoApi.SunoReq(reqVO.getPrompt(), reqVO.getMv(), reqVO.isMakeInstrumental()); // 1. 异步生成
//默认异步 SunoApi.SunoRequest sunoReq = new SunoApi.SunoRequest(reqVO.getPrompt(), reqVO.getMv(), reqVO.isMakeInstrumental());
List<SunoApi.MusicData> musicDataList = sunoApi.generate(sunoReq); List<SunoApi.MusicData> musicDataList = sunoApi.generate(sunoReq);
// 2. 插入数据库
return insertMusicData(musicDataList); return insertMusicData(musicDataList);
} }
@Override @Override
public List<Long> lyricMode(SunoLyricModeVO reqVO) { public List<Long> lyricMode(SunoLyricModeVO reqVO) {
SunoApi.SunoReq sunoReq = new SunoApi.SunoReq(reqVO.getPrompt(), reqVO.getMv(), reqVO.getTags(), reqVO.getTitle()); // 1. 异步生成
//默认异步 SunoApi.SunoRequest sunoReq = new SunoApi.SunoRequest(reqVO.getPrompt(), reqVO.getMv(), reqVO.getTags(), reqVO.getTitle());
List<SunoApi.MusicData> musicDataList = sunoApi.customGenerate(sunoReq); List<SunoApi.MusicData> musicDataList = sunoApi.customGenerate(sunoReq);
// 2. 插入数据库
return insertMusicData(musicDataList); return insertMusicData(musicDataList);
} }
@ -64,6 +67,7 @@ public class MusicServiceImpl implements MusicService {
if (CollUtil.isEmpty(musicDataList)) { if (CollUtil.isEmpty(musicDataList)) {
return Collections.emptyList(); return Collections.emptyList();
} }
// TODO @xin建议使用 insertBatch 方法批量插入
return AiMusicDO.convertFrom(musicDataList).stream() return AiMusicDO.convertFrom(musicDataList).stream()
.peek(musicDO -> musicMapper.insert(musicDO.setUserId(getLoginUserId()))) .peek(musicDO -> musicMapper.insert(musicDO.setUserId(getLoginUserId())))
.peek(e -> Optional.of(e.getTaskId()).ifPresent(taskQueue::add)) .peek(e -> Optional.of(e.getTaskId()).ifPresent(taskQueue::add))
@ -71,6 +75,7 @@ public class MusicServiceImpl implements MusicService {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
// TODO @xin这个改成标准的 job 来实现哈从数据库加载任务然后执行
@Scheduled(fixedDelay = 5, timeUnit = TimeUnit.SECONDS) @Scheduled(fixedDelay = 5, timeUnit = TimeUnit.SECONDS)
@Transactional @Transactional
public void flushSunoTask() { public void flushSunoTask() {

View File

@ -118,8 +118,9 @@ public class YudaoAiProperties {
public static class SunoProperties { public static class SunoProperties {
private boolean enable = false; private boolean enable = false;
/** /**
* suno-api 服务的基本地址 * API 服务的基本地址
*/ */
private String baseUrl; private String baseUrl;

View File

@ -4,16 +4,20 @@ import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
// TODO @xin不需要这个类哈直接 SunoApi 传入 baseUrl 参数即可
/** /**
* @Author xiaoxin * Suno 配置类
* @Date 2024/5/29 *
* @author xiaoxin
*/ */
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public class SunoConfig { public class SunoConfig {
/** /**
* suno-api服务的基本路径 * suno-api服务的基本路径
*/ */
private String baseUrl; private String baseUrl;
} }

View File

@ -27,14 +27,15 @@ import java.util.function.Predicate;
public class SunoApi { public class SunoApi {
private final WebClient webClient; private final WebClient webClient;
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful(); private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
private final Function<ClientResponse, Mono<? extends Throwable>> EXCEPTION_FUNCTION = response -> response.bodyToMono(String.class) private final Function<ClientResponse, Mono<? extends Throwable>> EXCEPTION_FUNCTION = response -> response.bodyToMono(String.class)
.handle((respBody, sink) -> { .handle((respBody, sink) -> {
// TODO @xin最好是 requestresponse 都有哈
log.error("【suno-api】调用失败resp: 【{}】", respBody); log.error("【suno-api】调用失败resp: 【{}】", respBody);
sink.error(new IllegalStateException("【suno-api】调用失败")); sink.error(new IllegalStateException("【suno-api】调用失败"));
}); });
public SunoApi(SunoConfig config) { public SunoApi(SunoConfig config) {
this.webClient = WebClient.builder() this.webClient = WebClient.builder()
.baseUrl(config.getBaseUrl()) .baseUrl(config.getBaseUrl())
@ -42,50 +43,49 @@ public class SunoApi {
.build(); .build();
} }
public List<MusicData> generate(SunoApi.SunoReq sunReq) { public List<MusicData> generate(SunoRequest request) {
return this.webClient.post() return this.webClient.post()
.uri("/api/generate") .uri("/api/generate")
.body(Mono.just(sunReq), SunoApi.SunoReq.class) .body(Mono.just(request), SunoRequest.class)
.retrieve() .retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION) .onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { .bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { })
})
.block(); .block();
} }
public List<MusicData> customGenerate(SunoApi.SunoReq sunReq) { public List<MusicData> customGenerate(SunoRequest request) {
return this.webClient.post() return this.webClient.post()
.uri("/api/custom_generate") .uri("/api/custom_generate")
.body(Mono.just(sunReq), SunoApi.SunoReq.class) .body(Mono.just(request), SunoRequest.class)
.retrieve() .retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION) .onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { .bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { })
})
.block(); .block();
} }
// TODO @xin: 是不是叫 chatCompletion
public List<MusicData> doChatCompletion(String prompt) { public List<MusicData> doChatCompletion(String prompt) {
return this.webClient.post() return this.webClient.post()
.uri("/v1/chat/completions") .uri("/v1/chat/completions")
.body(Mono.just(new SunoReq(prompt)), SunoApi.SunoReq.class) .body(Mono.just(new SunoRequest(prompt)), SunoRequest.class)
.retrieve() .retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION) .onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { .bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { })
})
.block(); .block();
} }
public LyricsData generateLyrics(String prompt) { public LyricsData generateLyrics(String prompt) {
return this.webClient.post() return this.webClient.post()
.uri("/api/generate_lyrics") .uri("/api/generate_lyrics")
.body(Mono.just(new SunoReq(prompt)), SunoApi.SunoReq.class) .body(Mono.just(new SunoRequest(prompt)), SunoRequest.class)
.retrieve() .retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION) .onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(LyricsData.class) .bodyToMono(LyricsData.class)
.block(); .block();
} }
// TODO @xin:应该传入 List<String> ids
// TODO @xin:方法名建议使用 getMusicList
public List<MusicData> selectById(String ids) { public List<MusicData> selectById(String ids) {
return this.webClient.get() return this.webClient.get()
.uri(uriBuilder -> uriBuilder .uri(uriBuilder -> uriBuilder
@ -94,12 +94,11 @@ public class SunoApi {
.build()) .build())
.retrieve() .retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION) .onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { .bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { })
})
.block(); .block();
} }
// TODO @xin:方法名建议使用 getLimitUsage
public LimitData selectLimit() { public LimitData selectLimit() {
return this.webClient.get() return this.webClient.get()
.uri("/api/get_limit") .uri("/api/get_limit")
@ -109,7 +108,7 @@ public class SunoApi {
.block(); .block();
} }
// TODO @xin可以改成 MusicGenerateRequest
/** /**
* 根据提示生成音频 * 根据提示生成音频
* *
@ -122,7 +121,7 @@ public class SunoApi {
* @param makeInstrumental 指示音乐音频是否为定制如果为 true则从歌词生成否则从提示生成 * @param makeInstrumental 指示音乐音频是否为定制如果为 true则从歌词生成否则从提示生成
*/ */
@JsonInclude(value = JsonInclude.Include.NON_NULL) @JsonInclude(value = JsonInclude.Include.NON_NULL)
public record SunoReq( public record SunoRequest(
String prompt, String prompt,
String tags, String tags,
String title, String title,
@ -130,23 +129,23 @@ public class SunoApi {
@JsonProperty("wait_audio") boolean waitAudio, @JsonProperty("wait_audio") boolean waitAudio,
@JsonProperty("make_instrumental") boolean makeInstrumental @JsonProperty("make_instrumental") boolean makeInstrumental
) { ) {
public SunoReq(String prompt) {
public SunoRequest(String prompt) {
this(prompt, null, null, null, false, false); this(prompt, null, null, null, false, false);
} }
public SunoReq(String prompt, String mv, boolean makeInstrumental) { public SunoRequest(String prompt, String mv, boolean makeInstrumental) {
this(prompt, null, null, mv, false, makeInstrumental); this(prompt, null, null, mv, false, makeInstrumental);
} }
public SunoRequest(String prompt, String mv, String tags, String title) {
public SunoReq(String prompt, String mv, String tags, String title) {
this(prompt, tags, title, mv, false, false); this(prompt, tags, title, mv, false, false);
} }
} }
/** /**
* SunoAPI 响应的音频数据 * Suno API 响应的音频数据
* *
* @param id 音乐数据的 ID * @param id 音乐数据的 ID
* @param title 音乐音频的标题 * @param title 音乐音频的标题
@ -179,7 +178,6 @@ public class SunoApi {
) { ) {
} }
/** /**
* Suno API 响应的歌词数据 * Suno API 响应的歌词数据
* *
@ -194,7 +192,6 @@ public class SunoApi {
) { ) {
} }
/** /**
* Suno API 响应的限额数据目前每日免费50 * Suno API 响应的限额数据目前每日免费50
*/ */
@ -206,5 +203,4 @@ public class SunoApi {
) { ) {
} }
} }

View File

@ -29,7 +29,7 @@ public class SunoTests {
@Test @Test
public void generate() { public void generate() {
List<SunoApi.MusicData> generate = sunoApi.generate(new SunoApi.SunoReq("创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。")); List<SunoApi.MusicData> generate = sunoApi.generate(new SunoApi.SunoRequest("创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。"));
System.out.println(generate); System.out.println(generate);
} }