diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http index 463e61207..924f21866 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http @@ -1,15 +1,13 @@ - -### chat call -POST {{baseUrl}}/admin-api/ai/chat/message/send +### 发送消息(段式) +POST {{baseUrl}}/ai/chat/message/send Content-Type: application/json Authorization: {{token}} { - "conversationId": "1781604279872581649", + "conversationId": "1781604279872581724", "content": "你是 OpenAI 么?" } - ### 发送消息(流式) POST {{baseUrl}}/ai/chat/message/send-stream Content-Type: application/json @@ -20,11 +18,10 @@ Authorization: {{token}} "content": "1+1=?" } -### message list -GET {{baseUrl}}/admin-api/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649 +### 获得指定对话的消息列表 +GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649 Authorization: {{token}} - -### message delete -DELETE {{baseUrl}}/admin-api/ai/chat/message/delete?id=50 -Authorization: {{token}} +### 删除消息 +DELETE {{baseUrl}}/ai/chat/message/delete?id=50 +Authorization: {{token}} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java index c0a116216..357dbec5e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java @@ -21,10 +21,10 @@ import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.annotation.Resource; import jakarta.annotation.security.PermitAll; +import jakarta.validation.Valid; import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; import org.springframework.security.access.prepost.PreAuthorize; -import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.*; import reactor.core.publisher.Flux; @@ -51,14 +51,14 @@ public class AiChatMessageController { @Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢") @PostMapping("/send") - public CommonResult sendMessage(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { - return success(chatMessageService.sendMessage(sendReqVO)); + public CommonResult sendMessage(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) { + return success(chatMessageService.sendMessage(sendReqVO, getLoginUserId())); } @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 - public Flux> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { + public Flux> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) { return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId()); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java index fbc31eea5..58ba05659 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java @@ -31,14 +31,6 @@ public class AiChatMessageSendRespVO { @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED) private LocalDateTime createTime; - // ========== 扩展字段 ========== - - @Schema(description = "用户头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/1.png") - private String userAvatar; - - @Schema(description = "角色头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/2.png") - private String roleAvatar; - } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java index 064af1608..de12ee1e0 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java @@ -68,7 +68,7 @@ public class AiImageController { @Operation(summary = "生成图片") @PostMapping("/draw") - public CommonResult drawImage(@Validated @RequestBody AiImageDrawReqVO drawReqVO) { + public CommonResult drawImage(@Valid @RequestBody AiImageDrawReqVO drawReqVO) { return success(imageService.drawImage(getLoginUserId(), drawReqVO)); } @@ -84,7 +84,7 @@ public class AiImageController { @Operation(summary = "【Midjourney】生成图片") @PostMapping("/midjourney/imagine") - public CommonResult midjourneyImagine(@Validated @RequestBody AiMidjourneyImagineReqVO reqVO) { + public CommonResult midjourneyImagine(@Valid @RequestBody AiMidjourneyImagineReqVO reqVO) { Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO); return success(imageId); } @@ -92,14 +92,14 @@ public class AiImageController { @Operation(summary = "【Midjourney】通知图片进展", description = "由 Midjourney Proxy 回调") @PostMapping("/midjourney/notify") // 必须是 POST 方法,否则会报错 @PermitAll - public CommonResult midjourneyNotify(@Validated @RequestBody MidjourneyApi.Notify notify) { + public CommonResult midjourneyNotify(@Valid @RequestBody MidjourneyApi.Notify notify) { imageService.midjourneyNotify(notify); return success(true); } @Operation(summary = "【Midjourney】Action 操作(二次生成图片)", description = "例如说:放大、缩小、U1、U2 等") @PostMapping("/midjourney/action") - public CommonResult midjourneyAction(@Validated @RequestBody AiMidjourneyActionReqVO reqVO) { + public CommonResult midjourneyAction(@Valid @RequestBody AiMidjourneyActionReqVO reqVO) { Long imageId = imageService.midjourneyAction(getLoginUserId(), reqVO); return success(imageId); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http index 5c8b05bcf..ae68c82ea 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http @@ -6,8 +6,8 @@ Authorization: {{token}} { "platform": "Suno", "generateMode": 2, - "prompt": "周末啦!", - "model": "chirp-v3-5", + "prompt": "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。", + "model": "chirp-v3.5", "tags": ["Happy"], "title": "Happy Song" } @@ -20,7 +20,7 @@ Authorization: {{token}} { "platform": "Suno", "generateMode": 1, - "model": "chirp-v3-5", - "prompt": "今天是星球六,结果是个下雨天,希望心情很美丽", + "model": "chirp-v3.5", + "prompt": "happy music", "makeInstrumental": false } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.java index 6c09e4b30..b079f9a3a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.java @@ -38,6 +38,10 @@ public class AiMusicController { @PostMapping("/generate") @Operation(summary = "音乐生成") public CommonResult> generateMusic(@RequestBody @Valid AiSunoGenerateReqVO reqVO) { +// if (true) { +// musicService.syncMusic(); +// return null; +// } return success(musicService.generateMusic(getLoginUserId(), reqVO)); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageService.java index a529a9716..f572bddd9 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageService.java @@ -2,8 +2,9 @@ package cn.iocoder.yudao.module.ai.service.chat; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import reactor.core.publisher.Flux; @@ -22,9 +23,10 @@ public interface AiChatMessageService { * 发送消息 * * @param sendReqVO 发送信息 + * @param userId 用户编号 * @return 发送结果 */ - AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO sendReqVO); + AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId); /** * 发送消息 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 658eaa8fc..48b1e482d 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -10,22 +10,20 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; -import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; @@ -64,47 +62,37 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { @Resource private AiChatModelService chatModalService; @Resource - private AiChatRoleService chatRoleService; - @Resource private AiApiKeyService apiKeyService; @Transactional(rollbackFor = Exception.class) - public AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO req) { - return null; // TODO 芋艿:一起改 -// Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); -// // 查询对话 -// AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); -// // 获取对话模型 -// AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId()); -// // 获取角色信息 -// AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; -// // 获取 client 类型 -// AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); -// // 保存 chat message -// createChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), -// chatModel.getModel(), chatModel.getId(), req.getContent()); -// String content = null; -// int tokens = 0; -// try { -// // 创建 chat 需要的 Prompt -// Prompt prompt = new Prompt(req.getContent()); -// // TODO @芋艿 @范 看要不要支持这些 -//// req.setTopK(req.getTopK()); -//// req.setTopP(req.getTopP()); -//// req.setTemperature(req.getTemperature()); -// // 发送 call 调用 -// ChatClient chatClient = chatClientFactory.getChatClient(platformEnum); -// ChatResponse call = chatClient.call(prompt); -// content = call.getResult().getOutput().getContent(); -// // 更新 conversation -// } catch (Exception e) { -// content = ExceptionUtil.getMessage(e); -// } finally { -// // 保存 chat message -// createChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), -// chatModel.getModel(), chatModel.getId(), content); -// } -// return new AiChatMessageRespVO().setContent(content); + public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { + // 1.1 校验对话存在 + AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId()); + if (ObjUtil.notEqual(conversation.getUserId(), userId)) { + throw exception(CHAT_CONVERSATION_NOT_EXISTS); + } + List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); + // 1.2 校验模型 + AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); + ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); + + // 2. 插入 user 发送消息 + AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, + userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext()); + + // 3.1 插入 assistant 接收消息 + AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, + userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); + + // 3.2 创建 chat 需要的 Prompt + Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); + ChatResponse chatResponse = chatClient.call(prompt); + + // 3.3 段式返回 + String newContent = chatResponse.getResult().getOutput().getContent(); + chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent)); + return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) + .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)); } @Override @@ -112,14 +100,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // 1.1 校验对话存在 AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId()); if (ObjUtil.notEqual(conversation.getUserId(), userId)) { - throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿:异常情况的对接; + throw exception(CHAT_CONVERSATION_NOT_EXISTS); } List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); // 1.2 校验模型 AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); - StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); - // 1.3 获取用户头像、角色头像 - AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null; + StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); // 2. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, @@ -149,9 +135,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // TODO @芋艿:失败的情况下,要不要删除消息 log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())); - }).onErrorResume(error -> { - return Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)); - }); + }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR))); } private Prompt buildPrompt(AiChatConversationDO conversation, List messages, diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java index bf4305b67..a5ba60867 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java @@ -8,7 +8,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import jakarta.validation.Valid; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; import java.util.List; @@ -76,12 +76,12 @@ public interface AiApiKeyService { // ========== 与 spring-ai 集成 ========== /** - * 获得 StreamingChatClient 对象 + * 获得 ChatModel 对象 * * @param id 编号 - * @return StreamingChatClient 对象 + * @return ChatModel 对象 */ - StreamingChatModel getStreamingChatClient(Long id); + ChatModel getChatClient(Long id); /** * 获得 ImageClient 对象 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java index 1ef235068..8db777f3f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java @@ -12,7 +12,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import jakarta.annotation.Resource; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; @@ -98,10 +98,10 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { // ========== 与 spring-ai 集成 ========== @Override - public StreamingChatModel getStreamingChatClient(Long id) { + public ChatModel getChatClient(Long id) { AiApiKeyDO apiKey = validateApiKey(id); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); - return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); + return clientFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); } @Override diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java index 591778a55..63470dc50 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java @@ -21,6 +21,7 @@ import cn.iocoder.yudao.module.infra.api.file.FileApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.util.*; @@ -49,6 +50,7 @@ public class AiMusicServiceImpl implements AiMusicService { private FileApi fileApi; @Override + @Transactional(rollbackFor = Exception.class) public List generateMusic(Long userId, AiSunoGenerateReqVO reqVO) { // 1. 调用 Suno 生成音乐 SunoApi sunoApi = apiKeyService.getSunoApi(); @@ -164,14 +166,9 @@ public class AiMusicServiceImpl implements AiMusicService { */ private List buildMusicDOList(List musicList) { return convertList(musicList, musicData -> { - Integer status; - if (Objects.equals("complete", musicData.status())) { - status = AiMusicStatusEnum.SUCCESS.getStatus(); - } else if (Objects.equals("error", musicData.status())) { - status = AiMusicStatusEnum.FAIL.getStatus(); - } else { - status = AiMusicStatusEnum.IN_PROGRESS.getStatus(); - } + Integer status = Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.SUCCESS.getStatus() + : Objects.equals("error", musicData.status()) ? AiMusicStatusEnum.FAIL.getStatus() + : AiMusicStatusEnum.IN_PROGRESS.getStatus(); return new AiMusicDO() .setTaskId(musicData.id()).setModel(musicData.modelName()) .setDescription(musicData.gptDescriptionPrompt()) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index 2130485a1..690aa444a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -53,7 +53,7 @@ public class AiWriteServiceImpl implements AiWriteService { // 1.1 校验模型 // TODO @xin:可以约定大于配置先,查询某个名字。例如说,写作助手!然后写作助手,上面是有个 model 的,可以使用它。 AiChatModelDO model = chatModalService.validateChatModel(14L); - StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); + StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java index 44ad08294..e37afc41d 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java @@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.core.factory; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; /** @@ -14,26 +14,26 @@ import org.springframework.ai.image.ImageModel; public interface AiClientFactory { /** - * 基于指定配置,获得 StreamingChatClient 对象 + * 基于指定配置,获得 ChatModel 对象 * * 如果不存在,则进行创建 * * @param platform 平台 * @param apiKey API KEY * @param url API URL - * @return StreamingChatClient 对象 + * @return ChatModel 对象 */ - StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url); + ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url); /** - * 基于默认配置,获得 StreamingChatClient 对象 + * 基于默认配置,获得 ChatModel 对象 * * 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置 * * @param platform 平台 - * @return StreamingChatClient 对象 + * @return ChatModel 对象 */ - StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform); + ChatModel getDefaultChatClient(AiPlatformEnum platform); /** * 基于默认配置,获得 ImageClient 对象 diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java index a22d2fc72..10f4b7503 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java @@ -23,7 +23,7 @@ import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties; import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaApi; @@ -50,9 +50,9 @@ import java.util.List; public class AiClientFactoryImpl implements AiClientFactory { @Override - public StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) { - String cacheKey = buildClientCacheKey(StreamingChatModel.class, platform, apiKey, url); - return Singleton.get(cacheKey, (Func0) () -> { + public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) { + String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url); + return Singleton.get(cacheKey, (Func0) () -> { //noinspection EnhancedSwitchMigration switch (platform) { case OPENAI: @@ -74,7 +74,7 @@ public class AiClientFactoryImpl implements AiClientFactory { } @Override - public StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform) { + public ChatModel getDefaultChatClient(AiPlatformEnum platform) { //noinspection EnhancedSwitchMigration switch (platform) { case OPENAI: