Merge remote-tracking branch 'origin/master-jdk21-ai' into master-jdk21-ai

This commit is contained in:
cherishsince 2024-07-05 09:26:20 +08:00
commit c98495dfcb
29 changed files with 189 additions and 470 deletions

View File

@ -0,0 +1,16 @@
package cn.iocoder.yudao.module.ai.enums;
/**
* AI 字典类型的枚举类
*
* @author xiaoxin
*/
public interface DictTypeConstants {
// ========== AI Write ==========
String AI_WRITE_FORMAT = "ai_write_format"; // 写作格式
String AI_WRITE_LENGTH = "ai_write_length"; // 写作长度
String AI_WRITE_LANGUAGE = "ai_write_language"; // 写作语言
String AI_WRITE_TONE = "ai_write_tone"; // 写作语气
}

View File

@ -1,73 +0,0 @@
package cn.iocoder.yudao.module.ai.enums.model;
import lombok.AllArgsConstructor;
import lombok.Getter;
// TODO @芋艿可以考虑清理掉
/**
* ai 模型
*
* @author: fansili
* @time: 2024/3/4 12:36
*/
@Getter
@AllArgsConstructor
public enum AiModelEnum {
// open ai
OPEN_AI_GPT_3_5( "GPT3.5", "gpt-3.5-turbo",null),
OPEN_AI_GPT_4("GPT4", "gpt-4-turbo",null),
// 千问付费模型
QWEN_TURBO("通义千问超大规模语言模型", "qwen-turbo", null),
QWEN_PLUS("通义千问超大规模语言模型增强版", "qwen-plus", null),
QWEN_MAX("通义千问千亿级别超大规模语言模型", "qwen-max", null),
QWEN_MAX_0403("通义千问千亿级别超大规模语言模型-0403", "qwen-max-0403", null),
QWEN_MAX_0107("通义千问千亿级别超大规模语言模型-0107", "qwen-max-0107", null),
QWEN_MAX_1201("通义千问千亿级别超大规模语言模型-1201", "qwen-max-1201", null),
QWEN_MAX_LONGCONTEXT("通义千问千亿级别超大规模语言模型-28k tokens", "qwen-max-longcontext", null),
// 千问开源模型
// https://help.aliyun.com/document_detail/2666503.html?spm=a2c4g.2701795.0.0.26eb34dfKzcWN4
QWEN_72B_CHAT("通义千问1.5对外开源的72B规模参数量的经过人类指令对齐的chat模型", "qwen-72b-chat", null),
// 一言模型
ERNIE4_0("ERNIE 4.0", "ERNIE 4.0", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"),
ERNIE4_3_5_8K("ERNIE-3.5-8K", "ERNIE-3.5-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"),
ERNIE4_3_5_8K_0205("ERNIE-3.5-8K-0205", "ERNIE-3.5-8K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205"),
ERNIE4_3_5_8K_1222("ERNIE-3.5-8K-1222", "ERNIE-3.5-8K-1222", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222"),
ERNIE4_BOT_8K("ERNIE-Bot-8K", "ERNIE-Bot-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"),
ERNIE4_3_5_4K_0205("ERNIE-3.5-4K-0205", "ERNIE-3.5-4K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205"),
// 文档地址https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
// general指向V1.5版本;
// generalv2指向V2版本;
// generalv3指向V3版本;
// generalv3.5指向V3.5版本;
XING_HUO_1_5("星火大模型1.5", "general", "/v1.1/chat"),
XING_HUO_2_0("星火大模型2.0", "generalv2", "/v2.1/chat"),
XING_HUO_3_0("星火大模型3.0", "generalv3", "/v3.1/chat"),
XING_HUO_3_5("星火大模型3.5", "generalv3.5", "/v3.5/chat"),
// Suno 模型
SUNO_2( "SUNO-2", "chirp-v2-xxl-alpha",null),
SUNO_3_0( "SUNO-3.0", "chirp-v3-0",null),
SUNO_3_5( "SUNO-3.5", "chirp-v3.5",null),
;
/**
* 模型名字 - 用于展示
*/
private final String name;
/**
* 模型标志 - 用于参数传递
*/
private final String model;
/**
* uri地址
*/
private final String uri;
}

View File

@ -1,45 +0,0 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
// TODO @xin写作的几个不用枚举类哈直接搞字段就好了AiWriteTypeEnum 还是需要的哈
@AllArgsConstructor
@Getter
public enum AiLanguageEnum implements IntArrayValuable {
AUTO(1, "自动"),
CHINESE(2, "中文"),
ENGLISH(3, "英文"),
KOREAN(4, "韩语"),
JAPANESE(5, "日语");
/**
* Language code
*/
private final Integer language;
/**
* Language name
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiLanguageEnum::getLanguage).toArray();
@Override
public int[] array() {
return ARRAYS;
}
public static AiLanguageEnum valueOfLanguage(Integer language) {
for (AiLanguageEnum languageEnum : AiLanguageEnum.values()) {
if (languageEnum.getLanguage().equals(language)) {
return languageEnum;
}
}
throw new IllegalArgumentException("未知语言: " + language);
}
}

View File

@ -1,53 +0,0 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 写作类型的枚举
*
* @author xiaoxin
*/
@AllArgsConstructor
@Getter
public enum AiWriteFormatEnum implements IntArrayValuable {
AUTO(1, "自动"),
EMAIL(2, "电子邮件"),
MESSAGE(3, "消息"),
COMMENT(4, "评论"),
PARAGRAPH(5, "段落"),
ARTICLE(6, "文章"),
BLOG_POST(7, "博客文章"),
IDEA(8, "想法"),
OUTLINE(9, "大纲");
/**
* 格式
*/
private final Integer format;
/**
* 格式名
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteFormatEnum::getFormat).toArray();
@Override
public int[] array() {
return ARRAYS;
}
public static AiWriteFormatEnum valueOfFormat(Integer format) {
for (AiWriteFormatEnum formatEnum : AiWriteFormatEnum.values()) {
if (formatEnum.getFormat().equals(format)) {
return formatEnum;
}
}
throw new IllegalArgumentException("未知格式: " + format);
}
}

View File

@ -1,47 +0,0 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 写作类型的枚举
*
* @author xiaoxin
*/
@AllArgsConstructor
@Getter
public enum AiWriteLengthEnum implements IntArrayValuable {
AUTO(1, "自动"),
SHORT(2, ""),
MEDIUM(3, ""),
LONG(4, "");
/**
* 长度
*/
private final Integer length;
/**
* 长度名
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteLengthEnum::getLength).toArray();
@Override
public int[] array() {
return ARRAYS;
}
public static AiWriteLengthEnum valueOfLength(Integer length) {
for (AiWriteLengthEnum lengthEnum : AiWriteLengthEnum.values()) {
if (lengthEnum.getLength().equals(length)) {
return lengthEnum;
}
}
throw new IllegalArgumentException("未知长度: " + length);
}
}

View File

@ -1,46 +0,0 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
@AllArgsConstructor
@Getter
public enum AiWriteToneEnum implements IntArrayValuable {
AUTO(1, "自动"),
FRIENDLY(2, "友善"),
CASUAL(3, "随意"),
KIND(4, "友好"),
PROFESSIONAL(5, "专业"),
HUMOROUS(6, "谈谐"),
INTERESTING(7, "有趣"),
FORMAL(8, "正式");
/**
* 语气
*/
private final Integer tone;
/**
* 语气名
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteToneEnum::getTone).toArray();
@Override
public int[] array() {
return ARRAYS;
}
public static AiWriteToneEnum valueOfTone(Integer tone) {
for (AiWriteToneEnum toneEnum : AiWriteToneEnum.values()) {
if (toneEnum.getTone().equals(tone)) {
return toneEnum;
}
}
throw new IllegalArgumentException("未知语气: " + tone);
}
}

View File

@ -23,7 +23,6 @@ import org.springframework.web.bind.annotation.*;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@ -79,9 +78,8 @@ public class AiChatConversationController {
return success(true); return success(true);
} }
// TODO 芋艿这个 url 可以改下 @DeleteMapping("/delete-by-unpinned")
@DeleteMapping("/delete-my-all-except-pinned") @Operation(summary = "删除未置顶的聊天对话")
@Operation(summary = "删除所有对话(置顶除外)")
public CommonResult<Boolean> deleteChatConversationMyByUnpinned() { public CommonResult<Boolean> deleteChatConversationMyByUnpinned() {
chatConversationService.deleteChatConversationMyByUnpinned(getLoginUserId()); chatConversationService.deleteChatConversationMyByUnpinned(getLoginUserId());
return success(true); return success(true);

View File

@ -1,15 +1,13 @@
### 发送消息(段式)
### chat call POST {{baseUrl}}/ai/chat/message/send
POST {{baseUrl}}/admin-api/ai/chat/message/send
Content-Type: application/json Content-Type: application/json
Authorization: {{token}} Authorization: {{token}}
{ {
"conversationId": "1781604279872581649", "conversationId": "1781604279872581724",
"content": "你是 OpenAI 么?" "content": "你是 OpenAI 么?"
} }
### 发送消息(流式) ### 发送消息(流式)
POST {{baseUrl}}/ai/chat/message/send-stream POST {{baseUrl}}/ai/chat/message/send-stream
Content-Type: application/json Content-Type: application/json
@ -20,11 +18,10 @@ Authorization: {{token}}
"content": "1+1=?" "content": "1+1=?"
} }
### message list ### 获得指定对话的消息列表
GET {{baseUrl}}/admin-api/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649 GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649
Authorization: {{token}} Authorization: {{token}}
### 删除消息
### message delete DELETE {{baseUrl}}/ai/chat/message/delete?id=50
DELETE {{baseUrl}}/admin-api/ai/chat/message/delete?id=50
Authorization: {{token}} Authorization: {{token}}

View File

@ -21,10 +21,10 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import jakarta.annotation.security.PermitAll; import jakarta.annotation.security.PermitAll;
import jakarta.validation.Valid;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -51,14 +51,14 @@ public class AiChatMessageController {
@Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢") @Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢")
@PostMapping("/send") @PostMapping("/send")
public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { public CommonResult<AiChatMessageSendRespVO> sendMessage(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
return success(chatMessageService.sendMessage(sendReqVO)); return success(chatMessageService.sendMessage(sendReqVO, getLoginUserId()));
} }
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题 @PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId()); return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
} }

View File

@ -28,7 +28,7 @@ public class AiChatMessageRespVO {
private Long roleId; private Long roleId;
@Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo") @Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo")
private String model; // 参见 AiOpenAiModelEnum 枚举类 private String model;
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "123") @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "123")
private Long modelId; private Long modelId;

View File

@ -31,14 +31,6 @@ public class AiChatMessageSendRespVO {
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED) @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime createTime; private LocalDateTime createTime;
// ========== 扩展字段 ==========
@Schema(description = "用户头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/1.png")
private String userAvatar;
@Schema(description = "角色头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/2.png")
private String roleAvatar;
} }
} }

View File

@ -1,4 +1,3 @@
### 生成图片OpenAIDALL ### 生成图片OpenAIDALL
POST {{baseUrl}}/ai/image/draw POST {{baseUrl}}/ai/image/draw
Content-Type: application/json Content-Type: application/json
@ -29,8 +28,7 @@ Authorization: {{token}}
"style": "vivid" "style": "vivid"
} }
### 生成图片:生成图片 ### 生成图片生成图片Midjourney
POST {{baseUrl}}/ai/image/midjourney/imagine POST {{baseUrl}}/ai/image/midjourney/imagine
Content-Type: application/json Content-Type: application/json
Authorization: {{token}} Authorization: {{token}}
@ -40,6 +38,5 @@ Authorization: {{token}}
"model": "midjourney", "model": "midjourney",
"width": "1", "width": "1",
"height": "1", "height": "1",
"version": "6.0", "version": "6.0"
"base64Array": []
} }

View File

@ -68,7 +68,7 @@ public class AiImageController {
@Operation(summary = "生成图片") @Operation(summary = "生成图片")
@PostMapping("/draw") @PostMapping("/draw")
public CommonResult<Long> drawImage(@Validated @RequestBody AiImageDrawReqVO drawReqVO) { public CommonResult<Long> drawImage(@Valid @RequestBody AiImageDrawReqVO drawReqVO) {
return success(imageService.drawImage(getLoginUserId(), drawReqVO)); return success(imageService.drawImage(getLoginUserId(), drawReqVO));
} }
@ -84,7 +84,7 @@ public class AiImageController {
@Operation(summary = "【Midjourney】生成图片") @Operation(summary = "【Midjourney】生成图片")
@PostMapping("/midjourney/imagine") @PostMapping("/midjourney/imagine")
public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiMidjourneyImagineReqVO reqVO) { public CommonResult<Long> midjourneyImagine(@Valid @RequestBody AiMidjourneyImagineReqVO reqVO) {
Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO); Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO);
return success(imageId); return success(imageId);
} }
@ -92,14 +92,14 @@ public class AiImageController {
@Operation(summary = "【Midjourney】通知图片进展", description = "由 Midjourney Proxy 回调") @Operation(summary = "【Midjourney】通知图片进展", description = "由 Midjourney Proxy 回调")
@PostMapping("/midjourney/notify") // 必须是 POST 方法否则会报错 @PostMapping("/midjourney/notify") // 必须是 POST 方法否则会报错
@PermitAll @PermitAll
public CommonResult<Boolean> midjourneyNotify(@Validated @RequestBody MidjourneyApi.Notify notify) { public CommonResult<Boolean> midjourneyNotify(@Valid @RequestBody MidjourneyApi.Notify notify) {
imageService.midjourneyNotify(notify); imageService.midjourneyNotify(notify);
return success(true); return success(true);
} }
@Operation(summary = "【Midjourney】Action 操作(二次生成图片)", description = "例如说放大、缩小、U1、U2 等") @Operation(summary = "【Midjourney】Action 操作(二次生成图片)", description = "例如说放大、缩小、U1、U2 等")
@PostMapping("/midjourney/action") @PostMapping("/midjourney/action")
public CommonResult<Long> midjourneyAction(@Validated @RequestBody AiMidjourneyActionReqVO reqVO) { public CommonResult<Long> midjourneyAction(@Valid @RequestBody AiMidjourneyActionReqVO reqVO) {
Long imageId = imageService.midjourneyAction(getLoginUserId(), reqVO); Long imageId = imageService.midjourneyAction(getLoginUserId(), reqVO);
return success(imageId); return success(imageId);
} }

View File

@ -6,7 +6,7 @@ Authorization: {{token}}
{ {
"platform": "Suno", "platform": "Suno",
"generateMode": 2, "generateMode": 2,
"prompt": "周末啦!", "prompt": "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。",
"model": "chirp-v3.5", "model": "chirp-v3.5",
"tags": ["Happy"], "tags": ["Happy"],
"title": "Happy Song" "title": "Happy Song"
@ -21,6 +21,6 @@ Authorization: {{token}}
"platform": "Suno", "platform": "Suno",
"generateMode": 1, "generateMode": 1,
"model": "chirp-v3.5", "model": "chirp-v3.5",
"gptDescriptionPrompt": "今天是星球六,结果是个下雨天,希望心情很美丽", "prompt": "happy music",
"makeInstrumental": false "makeInstrumental": false
} }

View File

@ -46,7 +46,7 @@ public class AiSunoGenerateReqVO {
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "chirp-v3.5") @Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "chirp-v3.5")
@NotEmpty(message = "模型不能为空") @NotEmpty(message = "模型不能为空")
private String model; // 参见 AiModelEnum 枚举 private String model;
@Schema(description = "音乐风格", example = "[\"pop\",\"jazz\",\"punk\"]") @Schema(description = "音乐风格", example = "[\"pop\",\"jazz\",\"punk\"]")
private List<String> tags; private List<String> tags;

View File

@ -26,9 +26,10 @@ public class AiWriteController {
private AiWriteService writeService; private AiWriteService writeService;
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@PermitAll
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快") @Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
@PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题
public Flux<CommonResult<String>> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) { public Flux<CommonResult<String>> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
return writeService.generateWriteContent(generateReqVO, getLoginUserId()); return writeService.generateWriteContent(generateReqVO, getLoginUserId());
} }
} }

View File

@ -14,11 +14,10 @@ public class AiWriteGenerateReqVO {
@InEnum(AiWriteTypeEnum.class) @InEnum(AiWriteTypeEnum.class)
private Integer type; private Integer type;
// TODO @xin如果非必填可以不用写 requiredMode @Schema(description = "写作内容提示", example = "1.撰写田忌赛马2.回复:不批")
@Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1.撰写田忌赛马2.回复:不批")
private String prompt; private String prompt;
@Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职") @Schema(description = "原文", example = "领导我要辞职")
private String originalContent; private String originalContent;
@Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")

View File

@ -3,7 +3,6 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.model.AiModelEnum;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
@ -73,8 +72,6 @@ public class AiChatConversationDO extends BaseDO {
private Long modelId; private Long modelId;
/** /**
* 模型标志 * 模型标志
*
* 枚举 {@link AiModelEnum}
*/ */
private String model; private String model;

View File

@ -5,7 +5,6 @@ import org.springframework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.model.AiModelEnum;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*; import lombok.*;
@ -69,8 +68,6 @@ public class AiChatMessageDO extends BaseDO {
/** /**
* 模型标志 * 模型标志
*
* 枚举 {@link AiModelEnum}
*/ */
private String model; private String model;
/** /**

View File

@ -2,8 +2,9 @@ package cn.iocoder.yudao.module.ai.service.chat;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -22,9 +23,10 @@ public interface AiChatMessageService {
* 发送消息 * 发送消息
* *
* @param sendReqVO 发送信息 * @param sendReqVO 发送信息
* @param userId 用户编号
* @return 发送结果 * @return 发送结果
*/ */
AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO sendReqVO); AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId);
/** /**
* 发送消息 * 发送消息

View File

@ -4,35 +4,28 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -64,47 +57,37 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Resource @Resource
private AiChatModelService chatModalService; private AiChatModelService chatModalService;
@Resource @Resource
private AiChatRoleService chatRoleService;
@Resource
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO req) { public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
return null; // TODO 芋艿一起改 // 1.1 校验对话存在
// Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
// // 查询对话 if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
// AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); throw exception(CHAT_CONVERSATION_NOT_EXISTS);
// // 获取对话模型 }
// AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// // 获取角色信息 // 1.2 校验模型
// AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
// // 获取 client 类型 ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
// AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
// // 保存 chat message // 2. 插入 user 发送消息
// createChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
// chatModel.getModel(), chatModel.getId(), req.getContent()); userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
// String content = null;
// int tokens = 0; // 3.1 插入 assistant 接收消息
// try { AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
// // 创建 chat 需要的 Prompt userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// Prompt prompt = new Prompt(req.getContent());
// // TODO @芋艿 @范 看要不要支持这些 // 3.2 创建 chat 需要的 Prompt
//// req.setTopK(req.getTopK()); Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
//// req.setTopP(req.getTopP()); ChatResponse chatResponse = chatClient.call(prompt);
//// req.setTemperature(req.getTemperature());
// // 发送 call 调用 // 3.3 段式返回
// ChatClient chatClient = chatClientFactory.getChatClient(platformEnum); String newContent = chatResponse.getResult().getOutput().getContent();
// ChatResponse call = chatClient.call(prompt); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
// content = call.getResult().getOutput().getContent(); return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
// // 更新 conversation .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
// } catch (Exception e) {
// content = ExceptionUtil.getMessage(e);
// } finally {
// // 保存 chat message
// createChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
// chatModel.getModel(), chatModel.getId(), content);
// }
// return new AiChatMessageRespVO().setContent(content);
} }
@Override @Override
@ -112,14 +95,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 1.1 校验对话存在 // 1.1 校验对话存在
AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId()); AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
if (ObjUtil.notEqual(conversation.getUserId(), userId)) { if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿异常情况的对接 throw exception(CHAT_CONVERSATION_NOT_EXISTS);
} }
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
// 1.3 获取用户头像角色头像
AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null;
// 2. 插入 user 发送消息 // 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -149,9 +130,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// TODO @芋艿失败的情况下要不要删除消息 // TODO @芋艿失败的情况下要不要删除消息
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable); log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()));
}).onErrorResume(error -> { }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
return Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR));
});
} }
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
@ -164,46 +143,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
// 1.2 history message 历史消息 // 1.2 history message 历史消息
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO); List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> { contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
// TODO @芋艿看看有没优化空间
if (MessageType.USER.getValue().equals(message.getType())) {
chatMessages.add(new UserMessage(message.getContent()));
} else {
chatMessages.add(new AssistantMessage(message.getContent()));
}
});
// 1.3 user message 新发送消息 // 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象 // 2. 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
conversation.getTemperature(), conversation.getMaxTokens()); conversation.getTemperature(), conversation.getMaxTokens());
return new Prompt(chatMessages, chatOptions); return new Prompt(chatMessages, chatOptions);
} }
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO 芋艿貌似 model 只要一设置就报错
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
/** /**
* 从历史消息中获得倒序的 n 组消息作为消息上下文 * 从历史消息中获得倒序的 n 组消息作为消息上下文
* *

View File

@ -8,7 +8,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import java.util.List; import java.util.List;
@ -76,12 +76,12 @@ public interface AiApiKeyService {
// ========== spring-ai 集成 ========== // ========== spring-ai 集成 ==========
/** /**
* 获得 StreamingChatClient 对象 * 获得 ChatModel 对象
* *
* @param id 编号 * @param id 编号
* @return StreamingChatClient 对象 * @return ChatModel 对象
*/ */
StreamingChatModel getStreamingChatClient(Long id); ChatModel getChatClient(Long id);
/** /**
* 获得 ImageClient 对象 * 获得 ImageClient 对象

View File

@ -12,7 +12,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
@ -98,10 +98,10 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
// ========== spring-ai 集成 ========== // ========== spring-ai 集成 ==========
@Override @Override
public StreamingChatModel getStreamingChatClient(Long id) { public ChatModel getChatClient(Long id) {
AiApiKeyDO apiKey = validateApiKey(id); AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); return clientFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
} }
@Override @Override

View File

@ -21,6 +21,7 @@ import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.*; import java.util.*;
@ -49,10 +50,10 @@ public class AiMusicServiceImpl implements AiMusicService {
private FileApi fileApi; private FileApi fileApi;
@Override @Override
@Transactional(rollbackFor = Exception.class)
public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) { public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
// 1. 调用 Suno 生成音乐 // 1. 调用 Suno 生成音乐
SunoApi sunoApi = apiKeyService.getSunoApi(); SunoApi sunoApi = apiKeyService.getSunoApi();
// TODO 芋艿这两个貌似一直没跑成功你那可以么用的请求是 AiMusicController.http --xin大部分ok的补充了error_message
List<SunoApi.MusicData> musicDataList; List<SunoApi.MusicData> musicDataList;
if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) { if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
// 1.1 描述模式 // 1.1 描述模式
@ -164,14 +165,9 @@ public class AiMusicServiceImpl implements AiMusicService {
*/ */
private List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicList) { private List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicList) {
return convertList(musicList, musicData -> { return convertList(musicList, musicData -> {
Integer status; Integer status = Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.SUCCESS.getStatus()
if (Objects.equals("complete", musicData.status())) { : Objects.equals("error", musicData.status()) ? AiMusicStatusEnum.FAIL.getStatus()
status = AiMusicStatusEnum.SUCCESS.getStatus(); : AiMusicStatusEnum.IN_PROGRESS.getStatus();
} else if (Objects.equals("error", musicData.status())) {
status = AiMusicStatusEnum.FAIL.getStatus();
} else {
status = AiMusicStatusEnum.IN_PROGRESS.getStatus();
}
return new AiMusicDO() return new AiMusicDO()
.setTaskId(musicData.id()).setModel(musicData.modelName()) .setTaskId(musicData.id()).setModel(musicData.modelName())
.setDescription(musicData.gptDescriptionPrompt()) .setDescription(musicData.gptDescriptionPrompt())

View File

@ -2,28 +2,25 @@ package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.enums.write.*; import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -45,25 +42,30 @@ public class AiWriteServiceImpl implements AiWriteService {
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Resource @Resource
private AiChatModelService chatModalService; private AiChatModelService chatModalService;
@Resource @Resource
private AiWriteMapper writeMapper; // TODO @xin上面空一行因为同类之间不要空行非同类空行 private DictDataApi dictDataApi;
@Resource
private AiWriteMapper writeMapper;
@Override @Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 校验模型 // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok那可以有限拿 chatRole 的角色如果没有则获取默认的
// TODO @xin可以约定大于配置先查询某个名字例如说写作助手然后写作助手上面是有个 model 可以使用它 AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
AiChatModelDO model = chatModalService.validateChatModel(14L); StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 1.2 插入写作信息 // 1.2 插入写作信息
// TODO @xin建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())写在 toBean consumer 原因是让这个 set 保持完整性
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class); AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
// 2.1 构建提示词 // 2.1 构建提示词
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt); Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 2.2 流式返回 // 2.2 流式返回
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
@ -83,11 +85,13 @@ public class AiWriteServiceImpl implements AiWriteService {
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
String template; String template;
Integer writeType = generateReqVO.getType(); Integer writeType = generateReqVO.getType();
String format = AiWriteFormatEnum.valueOfFormat(generateReqVO.getFormat()).getName(); String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = AiWriteToneEnum.valueOfTone(generateReqVO.getTone()).getName(); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
String language = AiLanguageEnum.valueOfLanguage(generateReqVO.getLanguage()).getName(); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
String length = AiWriteLengthEnum.valueOfLength(generateReqVO.getLength()).getName(); String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
// TODO @xin建议改成 if return 更简洁
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) { if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
// TODO @xin写成静态枚举哈
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length); return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) { } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
@ -98,27 +102,4 @@ public class AiWriteServiceImpl implements AiWriteService {
} }
} }
// TODO 芋艿复用
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO 芋艿貌似 model 只要一设置就报错
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
} }

View File

@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
/** /**
@ -14,26 +14,26 @@ import org.springframework.ai.image.ImageModel;
public interface AiClientFactory { public interface AiClientFactory {
/** /**
* 基于指定配置获得 StreamingChatClient 对象 * 基于指定配置获得 ChatModel 对象
* *
* 如果不存在则进行创建 * 如果不存在则进行创建
* *
* @param platform 平台 * @param platform 平台
* @param apiKey API KEY * @param apiKey API KEY
* @param url API URL * @param url API URL
* @return StreamingChatClient 对象 * @return ChatModel 对象
*/ */
StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url); ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url);
/** /**
* 基于默认配置获得 StreamingChatClient 对象 * 基于默认配置获得 ChatModel 对象
* *
* 默认配置指的是在 application.yaml 配置文件中的 spring.ai 相关的配置 * 默认配置指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
* *
* @param platform 平台 * @param platform 平台
* @return StreamingChatClient 对象 * @return ChatModel 对象
*/ */
StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform); ChatModel getDefaultChatClient(AiPlatformEnum platform);
/** /**
* 基于默认配置获得 ImageClient 对象 * 基于默认配置获得 ImageClient 对象

View File

@ -23,7 +23,7 @@ import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties; import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties; import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaApi;
@ -50,9 +50,9 @@ import java.util.List;
public class AiClientFactoryImpl implements AiClientFactory { public class AiClientFactoryImpl implements AiClientFactory {
@Override @Override
public StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) { public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(StreamingChatModel.class, platform, apiKey, url); String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<StreamingChatModel>) () -> { return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI: case OPENAI:
@ -74,7 +74,7 @@ public class AiClientFactoryImpl implements AiClientFactory {
} }
@Override @Override
public StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform) { public ChatModel getDefaultChatClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case OPENAI: case OPENAI:

View File

@ -0,0 +1,59 @@
package cn.iocoder.yudao.framework.ai.core.util;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
/**
* Spring AI 工具类
*
* @author 芋道源码
*/
public class AiUtils {
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO @xin貌似 model 只要一设置就报错可以排查下
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
public static Message buildMessage(String type, String content) {
if (MessageType.USER.getValue().equals(type)) {
return new UserMessage(content);
}
if (MessageType.ASSISTANT.getValue().equals(type)) {
return new AssistantMessage(content);
}
if (MessageType.SYSTEM.getValue().equals(type)) {
return new SystemMessage(content);
}
if (MessageType.FUNCTION.getValue().equals(type)) {
return new FunctionMessage(content);
}
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
}
}

View File

@ -236,6 +236,7 @@ public class TongYiChatModel extends
.model(Generation.Models.QWEN_TURBO) .model(Generation.Models.QWEN_TURBO)
// {@link GenerationOutput} // {@link GenerationOutput}
.resultFormat(ConversationParam.ResultFormat.MESSAGE) .resultFormat(ConversationParam.ResultFormat.MESSAGE)
.incrementalOutput(true)
.build(); .build();