diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/DictTypeConstants.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/DictTypeConstants.java new file mode 100644 index 000000000..73782a2cb --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/DictTypeConstants.java @@ -0,0 +1,16 @@ +package cn.iocoder.yudao.module.ai.enums; + +/** + * AI 字典类型的枚举类 + * + * @author xiaoxin + */ +public interface DictTypeConstants { + + // ========== AI Write ========== + String AI_WRITE_FORMAT = "ai_write_format"; // 写作格式 + String AI_WRITE_LENGTH = "ai_write_length"; // 写作长度 + String AI_WRITE_LANGUAGE = "ai_write_language"; // 写作语言 + String AI_WRITE_TONE = "ai_write_tone"; // 写作语气 + +} diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/model/AiModelEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/model/AiModelEnum.java deleted file mode 100644 index 9e584e18f..000000000 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/model/AiModelEnum.java +++ /dev/null @@ -1,73 +0,0 @@ -package cn.iocoder.yudao.module.ai.enums.model; - -import lombok.AllArgsConstructor; -import lombok.Getter; - -// TODO @芋艿:可以考虑清理掉 -/** - * ai 模型 - * - * @author: fansili - * @time: 2024/3/4 12:36 - */ -@Getter -@AllArgsConstructor -public enum AiModelEnum { - - // open ai - OPEN_AI_GPT_3_5( "GPT3.5", "gpt-3.5-turbo",null), - OPEN_AI_GPT_4("GPT4", "gpt-4-turbo",null), - - // 千问付费模型 - QWEN_TURBO("通义千问超大规模语言模型", "qwen-turbo", null), - QWEN_PLUS("通义千问超大规模语言模型增强版", "qwen-plus", null), - QWEN_MAX("通义千问千亿级别超大规模语言模型", "qwen-max", null), - QWEN_MAX_0403("通义千问千亿级别超大规模语言模型-0403", "qwen-max-0403", null), - QWEN_MAX_0107("通义千问千亿级别超大规模语言模型-0107", "qwen-max-0107", null), - QWEN_MAX_1201("通义千问千亿级别超大规模语言模型-1201", "qwen-max-1201", null), - QWEN_MAX_LONGCONTEXT("通义千问千亿级别超大规模语言模型-28k tokens", "qwen-max-longcontext", null), - - // 千问开源模型 - // https://help.aliyun.com/document_detail/2666503.html?spm=a2c4g.2701795.0.0.26eb34dfKzcWN4 - QWEN_72B_CHAT("通义千问1.5对外开源的72B规模参数量的经过人类指令对齐的chat模型", "qwen-72b-chat", null), - - // 一言模型 - ERNIE4_0("ERNIE 4.0", "ERNIE 4.0", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"), - ERNIE4_3_5_8K("ERNIE-3.5-8K", "ERNIE-3.5-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"), - ERNIE4_3_5_8K_0205("ERNIE-3.5-8K-0205", "ERNIE-3.5-8K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205"), - - ERNIE4_3_5_8K_1222("ERNIE-3.5-8K-1222", "ERNIE-3.5-8K-1222", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222"), - ERNIE4_BOT_8K("ERNIE-Bot-8K", "ERNIE-Bot-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"), - ERNIE4_3_5_4K_0205("ERNIE-3.5-4K-0205", "ERNIE-3.5-4K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205"), - -// 文档地址:https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E -// general指向V1.5版本; -// generalv2指向V2版本; -// generalv3指向V3版本; -// generalv3.5指向V3.5版本; - - XING_HUO_1_5("星火大模型1.5", "general", "/v1.1/chat"), - XING_HUO_2_0("星火大模型2.0", "generalv2", "/v2.1/chat"), - XING_HUO_3_0("星火大模型3.0", "generalv3", "/v3.1/chat"), - XING_HUO_3_5("星火大模型3.5", "generalv3.5", "/v3.5/chat"), - - // Suno 模型 - SUNO_2( "SUNO-2", "chirp-v2-xxl-alpha",null), - SUNO_3_0( "SUNO-3.0", "chirp-v3-0",null), - SUNO_3_5( "SUNO-3.5", "chirp-v3.5",null), - ; - - /** - * 模型名字 - 用于展示 - */ - private final String name; - /** - * 模型标志 - 用于参数传递 - */ - private final String model; - /** - * uri地址 - */ - private final String uri; - -} diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiLanguageEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiLanguageEnum.java deleted file mode 100644 index 00d88359b..000000000 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiLanguageEnum.java +++ /dev/null @@ -1,45 +0,0 @@ -package cn.iocoder.yudao.module.ai.enums.write; - -import cn.iocoder.yudao.framework.common.core.IntArrayValuable; -import lombok.AllArgsConstructor; -import lombok.Getter; - -import java.util.Arrays; - -// TODO @xin:写作的几个,不用枚举类哈。直接搞字段就好了。AiWriteTypeEnum 还是需要的哈 -@AllArgsConstructor -@Getter -public enum AiLanguageEnum implements IntArrayValuable { - - AUTO(1, "自动"), - CHINESE(2, "中文"), - ENGLISH(3, "英文"), - KOREAN(4, "韩语"), - JAPANESE(5, "日语"); - - /** - * Language code - */ - private final Integer language; - /** - * Language name - */ - private final String name; - - public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiLanguageEnum::getLanguage).toArray(); - - @Override - public int[] array() { - return ARRAYS; - } - - public static AiLanguageEnum valueOfLanguage(Integer language) { - for (AiLanguageEnum languageEnum : AiLanguageEnum.values()) { - if (languageEnum.getLanguage().equals(language)) { - return languageEnum; - } - } - throw new IllegalArgumentException("未知语言: " + language); - } - -} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteFormatEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteFormatEnum.java deleted file mode 100644 index d77e08fcc..000000000 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteFormatEnum.java +++ /dev/null @@ -1,53 +0,0 @@ -package cn.iocoder.yudao.module.ai.enums.write; - -import cn.iocoder.yudao.framework.common.core.IntArrayValuable; -import lombok.AllArgsConstructor; -import lombok.Getter; - -import java.util.Arrays; - -/** - * AI 写作类型的枚举 - * - * @author xiaoxin - */ -@AllArgsConstructor -@Getter -public enum AiWriteFormatEnum implements IntArrayValuable { - - AUTO(1, "自动"), - EMAIL(2, "电子邮件"), - MESSAGE(3, "消息"), - COMMENT(4, "评论"), - PARAGRAPH(5, "段落"), - ARTICLE(6, "文章"), - BLOG_POST(7, "博客文章"), - IDEA(8, "想法"), - OUTLINE(9, "大纲"); - - /** - * 格式 - */ - private final Integer format; - /** - * 格式名 - */ - private final String name; - - public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteFormatEnum::getFormat).toArray(); - - @Override - public int[] array() { - return ARRAYS; - } - - public static AiWriteFormatEnum valueOfFormat(Integer format) { - for (AiWriteFormatEnum formatEnum : AiWriteFormatEnum.values()) { - if (formatEnum.getFormat().equals(format)) { - return formatEnum; - } - } - throw new IllegalArgumentException("未知格式: " + format); - } - -} diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteLengthEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteLengthEnum.java deleted file mode 100644 index 2c6a9c5c1..000000000 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteLengthEnum.java +++ /dev/null @@ -1,47 +0,0 @@ -package cn.iocoder.yudao.module.ai.enums.write; - -import cn.iocoder.yudao.framework.common.core.IntArrayValuable; -import lombok.AllArgsConstructor; -import lombok.Getter; - -import java.util.Arrays; - -/** - * AI 写作类型的枚举 - * - * @author xiaoxin - */ -@AllArgsConstructor -@Getter -public enum AiWriteLengthEnum implements IntArrayValuable { - - AUTO(1, "自动"), - SHORT(2, "短"), - MEDIUM(3, "中"), - LONG(4, "长"); - - /** - * 长度 - */ - private final Integer length; - /** - * 长度名 - */ - private final String name; - - public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteLengthEnum::getLength).toArray(); - - @Override - public int[] array() { - return ARRAYS; - } - - public static AiWriteLengthEnum valueOfLength(Integer length) { - for (AiWriteLengthEnum lengthEnum : AiWriteLengthEnum.values()) { - if (lengthEnum.getLength().equals(length)) { - return lengthEnum; - } - } - throw new IllegalArgumentException("未知长度: " + length); - } -} diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteToneEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteToneEnum.java deleted file mode 100644 index 181682fd9..000000000 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteToneEnum.java +++ /dev/null @@ -1,46 +0,0 @@ -package cn.iocoder.yudao.module.ai.enums.write; - -import cn.iocoder.yudao.framework.common.core.IntArrayValuable; -import lombok.AllArgsConstructor; -import lombok.Getter; - -import java.util.Arrays; - -@AllArgsConstructor -@Getter -public enum AiWriteToneEnum implements IntArrayValuable { - - AUTO(1, "自动"), - FRIENDLY(2, "友善"), - CASUAL(3, "随意"), - KIND(4, "友好"), - PROFESSIONAL(5, "专业"), - HUMOROUS(6, "谈谐"), - INTERESTING(7, "有趣"), - FORMAL(8, "正式"); - - /** - * 语气 - */ - private final Integer tone; - /** - * 语气名 - */ - private final String name; - - public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteToneEnum::getTone).toArray(); - - @Override - public int[] array() { - return ARRAYS; - } - - public static AiWriteToneEnum valueOfTone(Integer tone) { - for (AiWriteToneEnum toneEnum : AiWriteToneEnum.values()) { - if (toneEnum.getTone().equals(tone)) { - return toneEnum; - } - } - throw new IllegalArgumentException("未知语气: " + tone); - } -} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatConversationController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatConversationController.java index ed22acf69..5142cde44 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatConversationController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatConversationController.java @@ -23,7 +23,6 @@ import org.springframework.web.bind.annotation.*; import java.util.List; import java.util.Map; -import java.util.function.Consumer; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; @@ -79,9 +78,8 @@ public class AiChatConversationController { return success(true); } - // TODO 芋艿:这个 url 可以改下 - @DeleteMapping("/delete-my-all-except-pinned") - @Operation(summary = "删除所有对话(置顶除外)") + @DeleteMapping("/delete-by-unpinned") + @Operation(summary = "删除未置顶的聊天对话") public CommonResult deleteChatConversationMyByUnpinned() { chatConversationService.deleteChatConversationMyByUnpinned(getLoginUserId()); return success(true); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http index 463e61207..924f21866 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http @@ -1,15 +1,13 @@ - -### chat call -POST {{baseUrl}}/admin-api/ai/chat/message/send +### 发送消息(段式) +POST {{baseUrl}}/ai/chat/message/send Content-Type: application/json Authorization: {{token}} { - "conversationId": "1781604279872581649", + "conversationId": "1781604279872581724", "content": "你是 OpenAI 么?" } - ### 发送消息(流式) POST {{baseUrl}}/ai/chat/message/send-stream Content-Type: application/json @@ -20,11 +18,10 @@ Authorization: {{token}} "content": "1+1=?" } -### message list -GET {{baseUrl}}/admin-api/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649 +### 获得指定对话的消息列表 +GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649 Authorization: {{token}} - -### message delete -DELETE {{baseUrl}}/admin-api/ai/chat/message/delete?id=50 -Authorization: {{token}} +### 删除消息 +DELETE {{baseUrl}}/ai/chat/message/delete?id=50 +Authorization: {{token}} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java index c0a116216..357dbec5e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java @@ -21,10 +21,10 @@ import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.annotation.Resource; import jakarta.annotation.security.PermitAll; +import jakarta.validation.Valid; import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; import org.springframework.security.access.prepost.PreAuthorize; -import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.*; import reactor.core.publisher.Flux; @@ -51,14 +51,14 @@ public class AiChatMessageController { @Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢") @PostMapping("/send") - public CommonResult sendMessage(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { - return success(chatMessageService.sendMessage(sendReqVO)); + public CommonResult sendMessage(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) { + return success(chatMessageService.sendMessage(sendReqVO, getLoginUserId())); } @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 - public Flux> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { + public Flux> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) { return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId()); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java index 1f7e33fc3..9b358df6f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java @@ -28,7 +28,7 @@ public class AiChatMessageRespVO { private Long roleId; @Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo") - private String model; // 参见 AiOpenAiModelEnum 枚举类 + private String model; @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "123") private Long modelId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java index fbc31eea5..58ba05659 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java @@ -31,14 +31,6 @@ public class AiChatMessageSendRespVO { @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED) private LocalDateTime createTime; - // ========== 扩展字段 ========== - - @Schema(description = "用户头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/1.png") - private String userAvatar; - - @Schema(description = "角色头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/2.png") - private String roleAvatar; - } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.http b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.http index 10fa24b5f..9047610c0 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.http +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.http @@ -1,4 +1,3 @@ - ### 生成图片:OpenAI(DALL) POST {{baseUrl}}/ai/image/draw Content-Type: application/json @@ -29,8 +28,7 @@ Authorization: {{token}} "style": "vivid" } -### 生成图片:生成图片 - +### 生成图片:生成图片(Midjourney) POST {{baseUrl}}/ai/image/midjourney/imagine Content-Type: application/json Authorization: {{token}} @@ -40,6 +38,5 @@ Authorization: {{token}} "model": "midjourney", "width": "1", "height": "1", - "version": "6.0", - "base64Array": [] + "version": "6.0" } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java index 064af1608..de12ee1e0 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java @@ -68,7 +68,7 @@ public class AiImageController { @Operation(summary = "生成图片") @PostMapping("/draw") - public CommonResult drawImage(@Validated @RequestBody AiImageDrawReqVO drawReqVO) { + public CommonResult drawImage(@Valid @RequestBody AiImageDrawReqVO drawReqVO) { return success(imageService.drawImage(getLoginUserId(), drawReqVO)); } @@ -84,7 +84,7 @@ public class AiImageController { @Operation(summary = "【Midjourney】生成图片") @PostMapping("/midjourney/imagine") - public CommonResult midjourneyImagine(@Validated @RequestBody AiMidjourneyImagineReqVO reqVO) { + public CommonResult midjourneyImagine(@Valid @RequestBody AiMidjourneyImagineReqVO reqVO) { Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO); return success(imageId); } @@ -92,14 +92,14 @@ public class AiImageController { @Operation(summary = "【Midjourney】通知图片进展", description = "由 Midjourney Proxy 回调") @PostMapping("/midjourney/notify") // 必须是 POST 方法,否则会报错 @PermitAll - public CommonResult midjourneyNotify(@Validated @RequestBody MidjourneyApi.Notify notify) { + public CommonResult midjourneyNotify(@Valid @RequestBody MidjourneyApi.Notify notify) { imageService.midjourneyNotify(notify); return success(true); } @Operation(summary = "【Midjourney】Action 操作(二次生成图片)", description = "例如说:放大、缩小、U1、U2 等") @PostMapping("/midjourney/action") - public CommonResult midjourneyAction(@Validated @RequestBody AiMidjourneyActionReqVO reqVO) { + public CommonResult midjourneyAction(@Valid @RequestBody AiMidjourneyActionReqVO reqVO) { Long imageId = imageService.midjourneyAction(getLoginUserId(), reqVO); return success(imageId); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http index 20df8d132..ae68c82ea 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http @@ -6,7 +6,7 @@ Authorization: {{token}} { "platform": "Suno", "generateMode": 2, - "prompt": "周末啦!", + "prompt": "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。", "model": "chirp-v3.5", "tags": ["Happy"], "title": "Happy Song" @@ -21,6 +21,6 @@ Authorization: {{token}} "platform": "Suno", "generateMode": 1, "model": "chirp-v3.5", - "gptDescriptionPrompt": "今天是星球六,结果是个下雨天,希望心情很美丽", + "prompt": "happy music", "makeInstrumental": false } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/vo/AiSunoGenerateReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/vo/AiSunoGenerateReqVO.java index 0222abbbc..f72d2b54a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/vo/AiSunoGenerateReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/vo/AiSunoGenerateReqVO.java @@ -46,7 +46,7 @@ public class AiSunoGenerateReqVO { @Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "chirp-v3.5") @NotEmpty(message = "模型不能为空") - private String model; // 参见 AiModelEnum 枚举 + private String model; @Schema(description = "音乐风格", example = "[\"pop\",\"jazz\",\"punk\"]") private List tags; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java index a032998ed..c9023552e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java @@ -26,9 +26,10 @@ public class AiWriteController { private AiWriteService writeService; @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - @PermitAll @Operation(summary = "写作生成(流式)", description = "流式返回,响应较快") + @PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 public Flux> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) { return writeService.generateWriteContent(generateReqVO, getLoginUserId()); } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java index 12172acd3..065165150 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java @@ -14,11 +14,10 @@ public class AiWriteGenerateReqVO { @InEnum(AiWriteTypeEnum.class) private Integer type; - // TODO @xin:如果非必填,可以不用写 requiredMode - @Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1.撰写:田忌赛马;2.回复:不批") + @Schema(description = "写作内容提示", example = "1.撰写:田忌赛马;2.回复:不批") private String prompt; - @Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职") + @Schema(description = "原文", example = "领导我要辞职") private String originalContent; @Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java index 67cd490e5..0b7eb0233 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java @@ -3,7 +3,6 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; -import cn.iocoder.yudao.module.ai.enums.model.AiModelEnum; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; @@ -73,8 +72,6 @@ public class AiChatConversationDO extends BaseDO { private Long modelId; /** * 模型标志 - * - * 枚举 {@link AiModelEnum} */ private String model; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java index 61608b8cc..973c593ce 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java @@ -5,7 +5,6 @@ import org.springframework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; -import cn.iocoder.yudao.module.ai.enums.model.AiModelEnum; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableName; import lombok.*; @@ -69,8 +68,6 @@ public class AiChatMessageDO extends BaseDO { /** * 模型标志 - * - * 枚举 {@link AiModelEnum} */ private String model; /** diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageService.java index a529a9716..f572bddd9 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageService.java @@ -2,8 +2,9 @@ package cn.iocoder.yudao.module.ai.service.chat; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import reactor.core.publisher.Flux; @@ -22,9 +23,10 @@ public interface AiChatMessageService { * 发送消息 * * @param sendReqVO 发送信息 + * @param userId 用户编号 * @return 发送结果 */ - AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO sendReqVO); + AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId); /** * 发送消息 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 658eaa8fc..3a9b34fc8 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -4,35 +4,28 @@ import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; -import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; -import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.ollama.api.OllamaOptions; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.qianfan.QianFanChatOptions; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; @@ -64,47 +57,37 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { @Resource private AiChatModelService chatModalService; @Resource - private AiChatRoleService chatRoleService; - @Resource private AiApiKeyService apiKeyService; @Transactional(rollbackFor = Exception.class) - public AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO req) { - return null; // TODO 芋艿:一起改 -// Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); -// // 查询对话 -// AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); -// // 获取对话模型 -// AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId()); -// // 获取角色信息 -// AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; -// // 获取 client 类型 -// AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); -// // 保存 chat message -// createChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), -// chatModel.getModel(), chatModel.getId(), req.getContent()); -// String content = null; -// int tokens = 0; -// try { -// // 创建 chat 需要的 Prompt -// Prompt prompt = new Prompt(req.getContent()); -// // TODO @芋艿 @范 看要不要支持这些 -//// req.setTopK(req.getTopK()); -//// req.setTopP(req.getTopP()); -//// req.setTemperature(req.getTemperature()); -// // 发送 call 调用 -// ChatClient chatClient = chatClientFactory.getChatClient(platformEnum); -// ChatResponse call = chatClient.call(prompt); -// content = call.getResult().getOutput().getContent(); -// // 更新 conversation -// } catch (Exception e) { -// content = ExceptionUtil.getMessage(e); -// } finally { -// // 保存 chat message -// createChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), -// chatModel.getModel(), chatModel.getId(), content); -// } -// return new AiChatMessageRespVO().setContent(content); + public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { + // 1.1 校验对话存在 + AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId()); + if (ObjUtil.notEqual(conversation.getUserId(), userId)) { + throw exception(CHAT_CONVERSATION_NOT_EXISTS); + } + List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); + // 1.2 校验模型 + AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); + ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); + + // 2. 插入 user 发送消息 + AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, + userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext()); + + // 3.1 插入 assistant 接收消息 + AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, + userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); + + // 3.2 创建 chat 需要的 Prompt + Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); + ChatResponse chatResponse = chatClient.call(prompt); + + // 3.3 段式返回 + String newContent = chatResponse.getResult().getOutput().getContent(); + chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent)); + return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) + .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)); } @Override @@ -112,14 +95,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // 1.1 校验对话存在 AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId()); if (ObjUtil.notEqual(conversation.getUserId(), userId)) { - throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿:异常情况的对接; + throw exception(CHAT_CONVERSATION_NOT_EXISTS); } List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); // 1.2 校验模型 AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); - StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); - // 1.3 获取用户头像、角色头像 - AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null; + StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); // 2. 插入 user 发送消息 AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, @@ -149,9 +130,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // TODO @芋艿:失败的情况下,要不要删除消息 log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())); - }).onErrorResume(error -> { - return Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)); - }); + }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR))); } private Prompt buildPrompt(AiChatConversationDO conversation, List messages, @@ -164,46 +143,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } // 1.2 history message 历史消息 List contextMessages = filterContextMessages(messages, conversation, sendReqVO); - contextMessages.forEach(message -> { - // TODO @芋艿:看看有没优化空间 - if (MessageType.USER.getValue().equals(message.getType())) { - chatMessages.add(new UserMessage(message.getContent())); - } else { - chatMessages.add(new AssistantMessage(message.getContent())); - } - }); + contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); // 1.3 user message 新发送消息 chatMessages.add(new UserMessage(sendReqVO.getContent())); // 2. 构建 ChatOptions 对象 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); - ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), + ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), conversation.getTemperature(), conversation.getMaxTokens()); return new Prompt(chatMessages, chatOptions); } - private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { - Float temperatureF = temperature != null ? temperature.floatValue() : null; - //noinspection EnhancedSwitchMigration - switch (platform) { - case OPENAI: - return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case OLLAMA: - return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); - case YI_YAN: - // TODO 芋艿:貌似 model 只要一设置,就报错 -// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case XING_HUO: - return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) - .setMaxTokens(maxTokens); - case QIAN_WEN: - return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); - default: - throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); - } - } - /** * 从历史消息中,获得倒序的 n 组消息作为消息上下文 * diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java index bf4305b67..a5ba60867 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java @@ -8,7 +8,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import jakarta.validation.Valid; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; import java.util.List; @@ -76,12 +76,12 @@ public interface AiApiKeyService { // ========== 与 spring-ai 集成 ========== /** - * 获得 StreamingChatClient 对象 + * 获得 ChatModel 对象 * * @param id 编号 - * @return StreamingChatClient 对象 + * @return ChatModel 对象 */ - StreamingChatModel getStreamingChatClient(Long id); + ChatModel getChatClient(Long id); /** * 获得 ImageClient 对象 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java index 1ef235068..8db777f3f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java @@ -12,7 +12,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import jakarta.annotation.Resource; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; @@ -98,10 +98,10 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { // ========== 与 spring-ai 集成 ========== @Override - public StreamingChatModel getStreamingChatClient(Long id) { + public ChatModel getChatClient(Long id) { AiApiKeyDO apiKey = validateApiKey(id); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); - return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); + return clientFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); } @Override diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java index 591778a55..3f10ec840 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java @@ -21,6 +21,7 @@ import cn.iocoder.yudao.module.infra.api.file.FileApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.util.*; @@ -49,10 +50,10 @@ public class AiMusicServiceImpl implements AiMusicService { private FileApi fileApi; @Override + @Transactional(rollbackFor = Exception.class) public List generateMusic(Long userId, AiSunoGenerateReqVO reqVO) { // 1. 调用 Suno 生成音乐 SunoApi sunoApi = apiKeyService.getSunoApi(); - // TODO 芋艿:这两个貌似一直没跑成功,你那可以么?用的请求是 AiMusicController.http 的 --xin:大部分ok的,补充了error_message List musicDataList; if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) { // 1.1 描述模式 @@ -164,14 +165,9 @@ public class AiMusicServiceImpl implements AiMusicService { */ private List buildMusicDOList(List musicList) { return convertList(musicList, musicData -> { - Integer status; - if (Objects.equals("complete", musicData.status())) { - status = AiMusicStatusEnum.SUCCESS.getStatus(); - } else if (Objects.equals("error", musicData.status())) { - status = AiMusicStatusEnum.FAIL.getStatus(); - } else { - status = AiMusicStatusEnum.IN_PROGRESS.getStatus(); - } + Integer status = Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.SUCCESS.getStatus() + : Objects.equals("error", musicData.status()) ? AiMusicStatusEnum.FAIL.getStatus() + : AiMusicStatusEnum.IN_PROGRESS.getStatus(); return new AiMusicDO() .setTaskId(musicData.id()).setModel(musicData.modelName()) .setDescription(musicData.gptDescriptionPrompt()) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index 2130485a1..9e5c0a7ff 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -2,28 +2,25 @@ package cn.iocoder.yudao.module.ai.service.write; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; +import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; -import cn.iocoder.yudao.module.ai.enums.write.*; +import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; -import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; +import cn.iocoder.yudao.module.system.api.dict.DictDataApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.ollama.api.OllamaOptions; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.qianfan.QianFanChatOptions; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; @@ -45,25 +42,30 @@ public class AiWriteServiceImpl implements AiWriteService { private AiApiKeyService apiKeyService; @Resource private AiChatModelService chatModalService; + @Resource - private AiWriteMapper writeMapper; // TODO @xin:上面空一行;因为同类之间不要空行,非同类空行; + private DictDataApi dictDataApi; + + @Resource + private AiWriteMapper writeMapper; @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { - // 1.1 校验模型 - // TODO @xin:可以约定大于配置先,查询某个名字。例如说,写作助手!然后写作助手,上面是有个 model 的,可以使用它。 - AiChatModelDO model = chatModalService.validateChatModel(14L); - StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); + // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的; + AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); + StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); - ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); // 1.2 插入写作信息 + // TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性 AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class); writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); // 2.1 构建提示词 + ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); Flux streamResponse = chatClient.stream(prompt); + // 2.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { @@ -83,11 +85,13 @@ public class AiWriteServiceImpl implements AiWriteService { private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { String template; Integer writeType = generateReqVO.getType(); - String format = AiWriteFormatEnum.valueOfFormat(generateReqVO.getFormat()).getName(); - String tone = AiWriteToneEnum.valueOfTone(generateReqVO.getTone()).getName(); - String language = AiLanguageEnum.valueOfLanguage(generateReqVO.getLanguage()).getName(); - String length = AiWriteLengthEnum.valueOfLength(generateReqVO.getLength()).getName(); + String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); + String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat()); + String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat()); + String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat()); + // TODO @xin:建议改成 if return 哈;更简洁; if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) { + // TODO @xin:写成静态枚举哈 template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length); } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) { @@ -98,27 +102,4 @@ public class AiWriteServiceImpl implements AiWriteService { } } - // TODO 芋艿:复用 - private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { - Float temperatureF = temperature != null ? temperature.floatValue() : null; - //noinspection EnhancedSwitchMigration - switch (platform) { - case OPENAI: - return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case OLLAMA: - return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); - case YI_YAN: - // TODO 芋艿:貌似 model 只要一设置,就报错 -// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case XING_HUO: - return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) - .setMaxTokens(maxTokens); - case QIAN_WEN: - return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); - default: - throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); - } - } - } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java index 44ad08294..e37afc41d 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java @@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.core.factory; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; /** @@ -14,26 +14,26 @@ import org.springframework.ai.image.ImageModel; public interface AiClientFactory { /** - * 基于指定配置,获得 StreamingChatClient 对象 + * 基于指定配置,获得 ChatModel 对象 * * 如果不存在,则进行创建 * * @param platform 平台 * @param apiKey API KEY * @param url API URL - * @return StreamingChatClient 对象 + * @return ChatModel 对象 */ - StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url); + ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url); /** - * 基于默认配置,获得 StreamingChatClient 对象 + * 基于默认配置,获得 ChatModel 对象 * * 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置 * * @param platform 平台 - * @return StreamingChatClient 对象 + * @return ChatModel 对象 */ - StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform); + ChatModel getDefaultChatClient(AiPlatformEnum platform); /** * 基于默认配置,获得 ImageClient 对象 diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java index a22d2fc72..10f4b7503 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java @@ -23,7 +23,7 @@ import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties; import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaApi; @@ -50,9 +50,9 @@ import java.util.List; public class AiClientFactoryImpl implements AiClientFactory { @Override - public StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) { - String cacheKey = buildClientCacheKey(StreamingChatModel.class, platform, apiKey, url); - return Singleton.get(cacheKey, (Func0) () -> { + public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) { + String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url); + return Singleton.get(cacheKey, (Func0) () -> { //noinspection EnhancedSwitchMigration switch (platform) { case OPENAI: @@ -74,7 +74,7 @@ public class AiClientFactoryImpl implements AiClientFactory { } @Override - public StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform) { + public ChatModel getDefaultChatClient(AiPlatformEnum platform) { //noinspection EnhancedSwitchMigration switch (platform) { case OPENAI: diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java new file mode 100644 index 000000000..306b216a3 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java @@ -0,0 +1,59 @@ +package cn.iocoder.yudao.framework.ai.core.util; + +import cn.hutool.core.util.StrUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.qianfan.QianFanChatOptions; + +/** + * Spring AI 工具类 + * + * @author 芋道源码 + */ +public class AiUtils { + + public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { + Float temperatureF = temperature != null ? temperature.floatValue() : null; + //noinspection EnhancedSwitchMigration + switch (platform) { + case OPENAI: + return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + case OLLAMA: + return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); + case YI_YAN: + // TODO @xin:貌似 model 只要一设置,就报错;可以排查下 +// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + case XING_HUO: + return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) + .setMaxTokens(maxTokens); + case QIAN_WEN: + return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + } + + public static Message buildMessage(String type, String content) { + if (MessageType.USER.getValue().equals(type)) { + return new UserMessage(content); + } + if (MessageType.ASSISTANT.getValue().equals(type)) { + return new AssistantMessage(content); + } + if (MessageType.SYSTEM.getValue().equals(type)) { + return new SystemMessage(content); + } + if (MessageType.FUNCTION.getValue().equals(type)) { + return new FunctionMessage(content); + } + throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type)); + } + +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/com/alibaba/cloud/ai/tongyi/chat/TongYiChatModel.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/com/alibaba/cloud/ai/tongyi/chat/TongYiChatModel.java index 2b3e290e8..c29ffbdfb 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/com/alibaba/cloud/ai/tongyi/chat/TongYiChatModel.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/com/alibaba/cloud/ai/tongyi/chat/TongYiChatModel.java @@ -236,6 +236,7 @@ public class TongYiChatModel extends .model(Generation.Models.QWEN_TURBO) // {@link GenerationOutput} .resultFormat(ConversationParam.ResultFormat.MESSAGE) + .incrementalOutput(true) .build();