diff --git a/yudao-module-ai/pom.xml b/yudao-module-ai/pom.xml index 7135100d7..69a5e987f 100644 --- a/yudao-module-ai/pom.xml +++ b/yudao-module-ai/pom.xml @@ -18,7 +18,7 @@ ${project.artifactId} - ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维脑图等功能。 + ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维导图等功能。 目前已接入各种模型,不限于: 国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek 国外:OpenAI、Ollama、Midjourney、StableDiffusion、Suno diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java index 19cbc8f8f..811189efe 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java @@ -22,7 +22,7 @@ public enum AiChatRoleEnum implements IntArrayValuable { 除此之外不需要除了正文内容外的其他回复,如标题、开头、任何解释性语句或道歉。 """), - AI_MIND_MAP_ROLE(2, "脑图助手", """ + AI_MIND_MAP_ROLE(2, "导图助手", """ 你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子: # Geek-AI 助手 ## 完整的开源系统 diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java index ddfb489f3..50934e7a0 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java @@ -45,9 +45,11 @@ public interface ErrorCodeConstants { // ========== API 音乐 1-040-006-000 ========== ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!"); - // ========== API 写作 1-022-007-000 ========== ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!"); ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "写作生成异常!"); + // ========== API 思维导图 1-040-008-000 ========== + ErrorCode MIND_MAP_NOT_EXISTS = new ErrorCode(1_040_008_000, "思维导图不存在!"); + } diff --git a/yudao-module-ai/yudao-module-ai-biz/pom.xml b/yudao-module-ai/yudao-module-ai-biz/pom.xml index 7c529f118..ec6f8c762 100644 --- a/yudao-module-ai/yudao-module-ai-biz/pom.xml +++ b/yudao-module-ai/yudao-module-ai-biz/pom.xml @@ -12,7 +12,7 @@ ${project.artifactId} - ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维脑图等功能。 + ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维导图等功能。 目前已接入各种模型,不限于: 国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek 国外:OpenAI、Ollama、Midjourney、StableDiffusion、Suno diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java index fd3cdf786..4cd8e55c2 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java @@ -5,10 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; @@ -45,6 +42,13 @@ public class AiImageController { return success(BeanUtils.toBean(pageResult, AiImageRespVO.class)); } + @GetMapping("/public-page") + @Operation(summary = "获取公开的绘图分页") + public CommonResult> getImagePagePublic(AiImagePublicPageReqVO pageReqVO) { + PageResult pageResult = imageService.getImagePagePublic(pageReqVO); + return success(BeanUtils.toBean(pageResult, AiImageRespVO.class)); + } + @GetMapping("/get-my") @Operation(summary = "获取【我的】绘图记录") @Parameter(name = "id", required = true, description = "绘画编号", example = "1024") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImagePublicPageReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImagePublicPageReqVO.java new file mode 100644 index 000000000..e7ff80a98 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImagePublicPageReqVO.java @@ -0,0 +1,14 @@ +package cn.iocoder.yudao.module.ai.controller.admin.image.vo; + +import cn.iocoder.yudao.framework.common.pojo.PageParam; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +@Schema(description = "管理后台 - AI 绘画公开的分页 Request VO") +@Data +public class AiImagePublicPageReqVO extends PageParam { + + @Schema(description = "提示词") + private String prompt; + +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/AiMindMapController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/AiMindMapController.java index 015180265..f1c59b964 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/AiMindMapController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/AiMindMapController.java @@ -1,20 +1,25 @@ package cn.iocoder.yudao.module.ai.controller.admin.mindmap; import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapRespVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import cn.iocoder.yudao.module.ai.service.mindmap.AiMindMapService; import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.annotation.Resource; import jakarta.annotation.security.PermitAll; import jakarta.validation.Valid; import org.springframework.http.MediaType; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.web.bind.annotation.*; import reactor.core.publisher.Flux; +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; @Tag(name = "管理后台 - AI 思维导图") @@ -26,10 +31,29 @@ public class AiMindMapController { private AiMindMapService mindMapService; @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - @Operation(summary = "脑图生成(流式)", description = "流式返回,响应较快") + @Operation(summary = "导图生成(流式)", description = "流式返回,响应较快") @PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 public Flux> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) { return mindMapService.generateMindMap(generateReqVO, getLoginUserId()); } + // ================ 导图管理 ================ + + @DeleteMapping("/delete") + @Operation(summary = "删除思维导图") + @Parameter(name = "id", description = "编号", required = true) + @PreAuthorize("@ss.hasPermission('ai:mind-map:delete')") + public CommonResult deleteMindMap(@RequestParam("id") Long id) { + mindMapService.deleteMindMap(id); + return success(true); + } + + @GetMapping("/page") + @Operation(summary = "获得思维导图分页") + @PreAuthorize("@ss.hasPermission('ai:mind-map:query')") + public CommonResult> getMindMapPage(@Valid AiMindMapPageReqVO pageReqVO) { + PageResult pageResult = mindMapService.getMindMapPage(pageReqVO); + return success(BeanUtils.toBean(pageResult, AiMindMapRespVO.class)); + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapPageReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapPageReqVO.java new file mode 100644 index 000000000..c123ab70e --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapPageReqVO.java @@ -0,0 +1,30 @@ +package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo; + +import cn.iocoder.yudao.framework.common.pojo.PageParam; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.ToString; +import org.springframework.format.annotation.DateTimeFormat; + +import java.time.LocalDateTime; + +import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND; + +@Schema(description = "管理后台 - AI 思维导图分页 Request VO") +@Data +@EqualsAndHashCode(callSuper = true) +@ToString(callSuper = true) +public class AiMindMapPageReqVO extends PageParam { + + @Schema(description = "用户编号", example = "4325") + private Long userId; + + @Schema(description = "生成内容提示", example = "Java 学习路线") + private String prompt; + + @Schema(description = "创建时间") + @DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND) + private LocalDateTime[] createTime; + +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapRespVO.java new file mode 100644 index 000000000..f65e809e9 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapRespVO.java @@ -0,0 +1,36 @@ +package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.time.LocalDateTime; + +@Schema(description = "管理后台 - AI 思维导图 Response VO") +@Data +public class AiMindMapRespVO { + + @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "3373") + private Long id; + + @Schema(description = "用户编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "4325") + private Long userId; + + @Schema(description = "生成内容提示", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 学习路线") + private String prompt; + + @Schema(description = "生成的思维导图内容") + private String generatedContent; + + @Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI") + private String platform; + + @Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo-0125") + private String model; + + @Schema(description = "错误信息") + private String errorMessage; + + @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED) + private LocalDateTime createTime; + +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/image/AiImageMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/image/AiImageMapper.java index 847a5c204..f87c9472b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/image/AiImageMapper.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/image/AiImageMapper.java @@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import org.apache.ibatis.annotations.Mapper; @@ -41,6 +42,13 @@ public interface AiImageMapper extends BaseMapperX { .orderByDesc(AiImageDO::getId)); } + default PageResult selectPage(AiImagePublicPageReqVO pageReqVO) { + return selectPage(pageReqVO, new LambdaQueryWrapperX() + .eqIfPresent(AiImageDO::getPublicStatus, Boolean.TRUE) + .likeIfPresent(AiImageDO::getPrompt, pageReqVO.getPrompt()) + .orderByDesc(AiImageDO::getId)); + } + default List selectListByStatusAndPlatform(Integer status, String platform) { return selectList(AiImageDO::getStatus, status, AiImageDO::getPlatform, platform); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java index ff25e89ff..0292ef473 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java @@ -1,6 +1,9 @@ package cn.iocoder.yudao.module.ai.dal.mysql.mindmap; +import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; +import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; +import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import org.apache.ibatis.annotations.Mapper; @@ -11,4 +14,13 @@ import org.apache.ibatis.annotations.Mapper; */ @Mapper public interface AiMindMapMapper extends BaseMapperX { + + default PageResult selectPage(AiMindMapPageReqVO reqVO) { + return selectPage(reqVO, new LambdaQueryWrapperX() + .eqIfPresent(AiMindMapDO::getUserId, reqVO.getUserId()) + .eqIfPresent(AiMindMapDO::getPrompt, reqVO.getPrompt()) + .betweenIfPresent(AiMindMapDO::getCreateTime, reqVO.getCreateTime()) + .orderByDesc(AiMindMapDO::getId)); + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java index b4d76b315..fcc40657b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java @@ -2,9 +2,7 @@ package cn.iocoder.yudao.module.ai.service.image; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; @@ -28,6 +26,14 @@ public interface AiImageService { */ PageResult getImagePageMy(Long userId, AiImagePageReqVO pageReqVO); + /** + * 获取公开的绘图分页 + * + * @param pageReqVO 分页条件 + * @return 绘图分页 + */ + PageResult getImagePagePublic(AiImagePublicPageReqVO pageReqVO); + /** * 获得绘图记录 * diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 020546ae9..e8532a576 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -12,9 +12,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; @@ -70,6 +68,11 @@ public class AiImageServiceImpl implements AiImageService { return imageMapper.selectPageMy(userId, pageReqVO); } + @Override + public PageResult getImagePagePublic(AiImagePublicPageReqVO pageReqVO) { + return imageMapper.selectPage(pageReqVO); + } + @Override public AiImageDO getImage(Long id) { return imageMapper.selectById(id); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocService.java new file mode 100644 index 000000000..47905d4b1 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocService.java @@ -0,0 +1,15 @@ +package cn.iocoder.yudao.module.ai.service.knowledge; + +/** + * AI 知识库 Service 接口 + * + * @author xiaoxin + */ +public interface DocService { + + /** + * 向量化文档 + */ + void embeddingDoc(); + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocServiceImpl.java new file mode 100644 index 000000000..7ba5018bd --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocServiceImpl.java @@ -0,0 +1,42 @@ +package cn.iocoder.yudao.module.ai.service.knowledge; + +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.document.Document; +import org.springframework.ai.reader.tika.TikaDocumentReader; +import org.springframework.ai.transformer.splitter.TokenTextSplitter; +import org.springframework.ai.vectorstore.RedisVectorStore; +import org.springframework.beans.factory.annotation.Value; + +import java.util.List; + +/** + * AI 知识库 Service 实现类 + * + * @author xiaoxin + */ +//@Service // TODO 芋艿:临时注释,避免无法启动 +@Slf4j +public class DocServiceImpl implements DocService { + + @Resource + private RedisVectorStore vectorStore; + @Resource + private TokenTextSplitter tokenTextSplitter; + + // TODO @xin 临时测试用,后续删 + @Value("classpath:/webapp/test/Fel.pdf") + private org.springframework.core.io.Resource data; + + @Override + public void embeddingDoc() { + // 读取文件 + TikaDocumentReader loader = new TikaDocumentReader(data); + List documents = loader.get(); + // 文档分段 + List segments = tokenTextSplitter.apply(documents); + // 向量化并存储 + vectorStore.add(segments); + } + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapService.java index 2eb1f1b1a..65a5aaf3a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapService.java @@ -1,7 +1,10 @@ package cn.iocoder.yudao.module.ai.service.mindmap; import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import reactor.core.publisher.Flux; /** @@ -20,4 +23,19 @@ public interface AiMindMapService { */ Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId); + /** + * 删除思维导图 + * + * @param id 编号 + */ + void deleteMindMap(Long id); + + /** + * 获得思维导图分页 + * + * @param pageReqVO 分页查询 + * @return 思维导图分页 + */ + PageResult getMindMapPage(AiMindMapPageReqVO pageReqVO); + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java index df5296df2..b34bd6348 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java @@ -6,9 +6,11 @@ import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; @@ -33,8 +35,10 @@ import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; +import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; +import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_EXISTS; /** * AI 思维导图 Service 实现类 @@ -57,10 +61,10 @@ public class AiMindMapServiceImpl implements AiMindMapService { @Override public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { - // 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型 + // 1. 获取导图模型。尝试获取思维导图助手角色,如果没有则使用默认模型 AiChatRoleDO role = CollUtil.getFirst( chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); - // 1.1 获取脑图执行模型 + // 1.1 获取导图执行模型 AiChatModelDO model = getModel(role); // 1.2 获取角色设定消息 String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage()) @@ -131,4 +135,23 @@ public class AiMindMapServiceImpl implements AiMindMapService { return model; } + @Override + public void deleteMindMap(Long id) { + // 校验存在 + validateMindMapExists(id); + // 删除 + mindMapMapper.deleteById(id); + } + + private void validateMindMapExists(Long id) { + if (mindMapMapper.selectById(id) == null) { + throw exception(MIND_MAP_NOT_EXISTS); + } + } + + @Override + public PageResult getMindMapPage(AiMindMapPageReqVO pageReqVO) { + return mindMapMapper.selectPage(pageReqVO); + } + } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml b/yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml index 7354f5008..5bc2e383a 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml @@ -23,12 +23,16 @@ spring-ai-zhipuai-spring-boot-starter ${spring-ai.version} - org.springframework.ai spring-ai-openai-spring-boot-starter ${spring-ai.version} + + org.springframework.ai + spring-ai-azure-openai-spring-boot-starter + ${spring-ai.version} + org.springframework.ai spring-ai-ollama-spring-boot-starter @@ -40,6 +44,30 @@ ${spring-ai.version} + + + org.springframework.ai + spring-ai-transformers-spring-boot-starter + ${spring-ai.version} + + + org.springframework.ai + spring-ai-tika-document-reader + ${spring-ai.version} + + + org.springframework.ai + spring-ai-redis-store + ${spring-ai.version} + + + + + org.springframework.data + spring-data-redis + true + + cn.iocoder.boot yudao-common diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index 05a317294..543444fdd 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -10,11 +10,20 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration; import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.transformer.splitter.TokenTextSplitter; +import org.springframework.ai.transformers.TransformersEmbeddingModel; +import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.data.redis.RedisProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.context.annotation.Lazy; +import redis.clients.jedis.JedisPooled; /** * 芋道 AI 自动配置 @@ -73,4 +82,36 @@ public class YudaoAiAutoConfiguration { return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl()); } + // ========== rag 相关 ========== + @Bean + @Lazy // TODO 芋艿:临时注释,避免无法启动 + public EmbeddingModel transformersEmbeddingClient() { + return new TransformersEmbeddingModel(MetadataMode.EMBED); + } + + /** + * 我们启动有加载很多 Embedding 模型,不晓得取哪个好,先 new 个 TransformersEmbeddingModel 跑 + */ + @Bean + @Lazy // TODO 芋艿:临时注释,避免无法启动 + public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties, + RedisProperties redisProperties) { + var config = RedisVectorStore.RedisVectorStoreConfig.builder() + .withIndexName(properties.getIndex()) + .withPrefix(properties.getPrefix()) + .build(); + + RedisVectorStore redisVectorStore = new RedisVectorStore(config, transformersEmbeddingModel, + new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), + properties.isInitializeSchema()); + redisVectorStore.afterPropertiesSet(); + return redisVectorStore; + } + + @Bean + @Lazy // TODO 芋艿:临时注释,避免无法启动 + public TokenTextSplitter tokenTextSplitter() { + return new TokenTextSplitter(500, 100, 5, 10000, true); + } + } \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java index 596118168..1922e9a2c 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java @@ -22,7 +22,8 @@ public enum AiPlatformEnum { // ========== 国外平台 ========== - OPENAI("OpenAI", "OpenAI"), + OPENAI("OpenAI", "OpenAI"), // OpenAI 官方 + AZURE_OPENAI("AzureOpenAI", "AzureOpenAI"), // OpenAI 微软 OLLAMA("Ollama", "Ollama"), STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index a5df28246..c9b04dc1e 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -21,6 +21,10 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; import com.alibaba.dashscope.aigc.generation.Generation; import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; +import com.azure.ai.openai.OpenAIClient; +import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; +import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties; +import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; @@ -31,6 +35,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties; +import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -82,6 +87,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return buildXingHuoChatModel(apiKey); case OPENAI: return buildOpenAiChatModel(apiKey, url); + case AZURE_OPENAI: + return buildAzureOpenAiChatModel(apiKey, url); case OLLAMA: return buildOllamaChatModel(url); default: @@ -106,6 +113,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return SpringUtil.getBean(XingHuoChatModel.class); case OPENAI: return SpringUtil.getBean(OpenAiChatModel.class); + case AZURE_OPENAI: + return SpringUtil.getBean(AzureOpenAiChatModel.class); case OLLAMA: return SpringUtil.getBean(OllamaChatModel.class); default: @@ -268,6 +277,21 @@ public class AiModelFactoryImpl implements AiModelFactory { return new OpenAiChatModel(openAiApi); } + /** + * 可参考 {@link AzureOpenAiAutoConfiguration} + */ + private static AzureOpenAiChatModel buildAzureOpenAiChatModel(String apiKey, String url) { + AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration(); + // 创建 OpenAIClient 对象 + AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties(); + connectionProperties.setApiKey(apiKey); + connectionProperties.setEndpoint(url); + OpenAIClient openAIClient = azureOpenAiAutoConfiguration.openAIClient(connectionProperties); + // 获取 AzureOpenAiChatProperties 对象 + AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class); + return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties, null, null); + } + /** * 可参考 {@link OpenAiAutoConfiguration} */ diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java index b25658c67..e18f10015 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java @@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; +import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.ollama.api.OllamaOptions; @@ -35,6 +36,9 @@ public class AiUtils { return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); case OPENAI: return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + case AZURE_OPENAI: + // TODO 芋艿:貌似没 model 字段???! + return AzureOpenAiChatOptions.builder().withDeploymentName(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); case OLLAMA: return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); default: diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java new file mode 100644 index 000000000..61c38dd1d --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java @@ -0,0 +1,59 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vectorstore.redis; + +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.RedisVectorStore; +import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; +import redis.clients.jedis.JedisPooled; + +/** + * TODO @xin 先拿 spring-ai 最新代码覆盖,1.0.0-M1 跟 redis 自动配置会冲突 + * + * TODO 这个官方,有说啥时候 fix 哇? + * + * @author Christian Tzolov + * @author Eddú Meléndez + */ +@AutoConfiguration(after = RedisAutoConfiguration.class) +@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class}) +//@ConditionalOnBean(JedisConnectionFactory.class) +@EnableConfigurationProperties(RedisVectorStoreProperties.class) +public class RedisVectorStoreAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties, + JedisConnectionFactory jedisConnectionFactory) { + + var config = RedisVectorStoreConfig.builder() + .withIndexName(properties.getIndex()) + .withPrefix(properties.getPrefix()) + .build(); + + return new RedisVectorStore(config, embeddingModel, + new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), + properties.isInitializeSchema()); + } + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java new file mode 100644 index 000000000..de80401ed --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java @@ -0,0 +1,456 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vectorstore; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.Pipeline; +import redis.clients.jedis.json.Path2; +import redis.clients.jedis.search.*; +import redis.clients.jedis.search.Schema.FieldType; +import redis.clients.jedis.search.schemafields.*; +import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; + +import java.text.MessageFormat; +import java.util.*; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +/** + * The RedisVectorStore is for managing and querying vector data in a Redis database. It + * offers functionalities like adding, deleting, and performing similarity searches on + * documents. + * + * The store utilizes RedisJSON and RedisSearch to handle JSON documents and to index and + * search vector data. It supports various vector algorithms (e.g., FLAT, HSNW) for + * efficient similarity searches. Additionally, it allows for custom metadata fields in + * the documents to be stored alongside the vector and content data. + * + * This class requires a RedisVectorStoreConfig configuration object for initialization, + * which includes settings like Redis URI, index name, field names, and vector algorithms. + * It also requires an EmbeddingModel to convert documents into embeddings before storing + * them. + * + * @author Julien Ruaux + * @author Christian Tzolov + * @author Eddú Meléndez + * @see VectorStore + * @see RedisVectorStoreConfig + * @see EmbeddingModel + */ +public class RedisVectorStore implements VectorStore, InitializingBean { + + public enum Algorithm { + + FLAT, HSNW + + } + + public record MetadataField(String name, FieldType fieldType) { + + public static MetadataField text(String name) { + return new MetadataField(name, FieldType.TEXT); + } + + public static MetadataField numeric(String name) { + return new MetadataField(name, FieldType.NUMERIC); + } + + public static MetadataField tag(String name) { + return new MetadataField(name, FieldType.TAG); + } + + } + + /** + * Configuration for the Redis vector store. + */ + public static final class RedisVectorStoreConfig { + + private final String indexName; + + private final String prefix; + + private final String contentFieldName; + + private final String embeddingFieldName; + + private final Algorithm vectorAlgorithm; + + private final List metadataFields; + + private RedisVectorStoreConfig() { + this(builder()); + } + + private RedisVectorStoreConfig(Builder builder) { + this.indexName = builder.indexName; + this.prefix = builder.prefix; + this.contentFieldName = builder.contentFieldName; + this.embeddingFieldName = builder.embeddingFieldName; + this.vectorAlgorithm = builder.vectorAlgorithm; + this.metadataFields = builder.metadataFields; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + + return new Builder(); + } + + /** + * {@return the default config} + */ + public static RedisVectorStoreConfig defaultConfig() { + + return builder().build(); + } + + public static class Builder { + + private String indexName = DEFAULT_INDEX_NAME; + + private String prefix = DEFAULT_PREFIX; + + private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME; + + private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME; + + private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; + + private List metadataFields = new ArrayList<>(); + + private Builder() { + } + + /** + * Configures the Redis index name to use. + * @param name the index name to use + * @return this builder + */ + public Builder withIndexName(String name) { + this.indexName = name; + return this; + } + + /** + * Configures the Redis key prefix to use (default: "embedding:"). + * @param prefix the prefix to use + * @return this builder + */ + public Builder withPrefix(String prefix) { + this.prefix = prefix; + return this; + } + + /** + * Configures the Redis content field name to use. + * @param name the content field name to use + * @return this builder + */ + public Builder withContentFieldName(String name) { + this.contentFieldName = name; + return this; + } + + /** + * Configures the Redis embedding field name to use. + * @param name the embedding field name to use + * @return this builder + */ + public Builder withEmbeddingFieldName(String name) { + this.embeddingFieldName = name; + return this; + } + + /** + * Configures the Redis vector algorithmto use. + * @param algorithm the vector algorithm to use + * @return this builder + */ + public Builder withVectorAlgorithm(Algorithm algorithm) { + this.vectorAlgorithm = algorithm; + return this; + } + + public Builder withMetadataFields(MetadataField... fields) { + return withMetadataFields(Arrays.asList(fields)); + } + + public Builder withMetadataFields(List fields) { + this.metadataFields = fields; + return this; + } + + /** + * {@return the immutable configuration} + */ + public RedisVectorStoreConfig build() { + + return new RedisVectorStoreConfig(this); + } + + } + + } + + private final boolean initializeSchema; + + public static final String DEFAULT_INDEX_NAME = "spring-ai-index"; + + public static final String DEFAULT_CONTENT_FIELD_NAME = "content"; + + public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding"; + + public static final String DEFAULT_PREFIX = "embedding:"; + + public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW; + + private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]"; + + private static final Path2 JSON_SET_PATH = Path2.of("$"); + + private static final String JSON_PATH_PREFIX = "$."; + + private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class); + + private static final Predicate RESPONSE_OK = Predicate.isEqual("OK"); + + private static final Predicate RESPONSE_DEL_OK = Predicate.isEqual(1l); + + private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32"; + + private static final String EMBEDDING_PARAM_NAME = "BLOB"; + + public static final String DISTANCE_FIELD_NAME = "vector_score"; + + private static final String DEFAULT_DISTANCE_METRIC = "COSINE"; + + private final JedisPooled jedis; + + private final EmbeddingModel embeddingModel; + + private final RedisVectorStoreConfig config; + + private FilterExpressionConverter filterExpressionConverter; + + public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, + boolean initializeSchema) { + + Assert.notNull(config, "Config must not be null"); + Assert.notNull(embeddingModel, "Embedding model must not be null"); + this.initializeSchema = initializeSchema; + + this.jedis = jedis; + this.embeddingModel = embeddingModel; + this.config = config; + this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields); + } + + public JedisPooled getJedis() { + return this.jedis; + } + + @Override + public void add(List documents) { + try (Pipeline pipeline = this.jedis.pipelined()) { + for (Document document : documents) { + var embedding = this.embeddingModel.embed(document); + document.setEmbedding(embedding); + + var fields = new HashMap(); + fields.put(this.config.embeddingFieldName, embedding); + fields.put(this.config.contentFieldName, document.getContent()); + fields.putAll(document.getMetadata()); + pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields); + } + List responses = pipeline.syncAndReturnAll(); + Optional errResponse = responses.stream().filter(Predicate.not(RESPONSE_OK)).findAny(); + if (errResponse.isPresent()) { + String message = MessageFormat.format("Could not add document: {0}", errResponse.get()); + if (logger.isErrorEnabled()) { + logger.error(message); + } + throw new RuntimeException(message); + } + } + } + + private String key(String id) { + return this.config.prefix + id; + } + + @Override + public Optional delete(List idList) { + try (Pipeline pipeline = this.jedis.pipelined()) { + for (String id : idList) { + pipeline.jsonDel(key(id)); + } + List responses = pipeline.syncAndReturnAll(); + Optional errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny(); + if (errResponse.isPresent()) { + if (logger.isErrorEnabled()) { + logger.error("Could not delete document: {}", errResponse.get()); + } + return Optional.of(false); + } + return Optional.of(true); + } + } + + @Override + public List similaritySearch(SearchRequest request) { + + Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero"); + Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1, + "The similarity score is bounded between 0 and 1; least to most similar respectively."); + + String filter = nativeExpressionFilter(request); + + String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName, + EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME); + + List returnFields = new ArrayList<>(); + this.config.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add); + returnFields.add(this.config.embeddingFieldName); + returnFields.add(this.config.contentFieldName); + returnFields.add(DISTANCE_FIELD_NAME); + var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery())); + Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding)) + .returnFields(returnFields.toArray(new String[0])) + .setSortBy(DISTANCE_FIELD_NAME, true) + .dialect(2); + + SearchResult result = this.jedis.ftSearch(this.config.indexName, query); + return result.getDocuments() + .stream() + .filter(d -> similarityScore(d) >= request.getSimilarityThreshold()) + .map(this::toDocument) + .toList(); + } + + private Document toDocument(redis.clients.jedis.search.Document doc) { + var id = doc.getId().substring(this.config.prefix.length()); + var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) + : null; + Map metadata = this.config.metadataFields.stream() + .map(MetadataField::name) + .filter(doc::hasProperty) + .collect(Collectors.toMap(Function.identity(), doc::getString)); + metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc)); + return new Document(id, content, metadata); + } + + private float similarityScore(redis.clients.jedis.search.Document doc) { + return (2 - Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME))) / 2; + } + + private String nativeExpressionFilter(SearchRequest request) { + if (request.getFilterExpression() == null) { + return "*"; + } + return "(" + this.filterExpressionConverter.convertExpression(request.getFilterExpression()) + ")"; + } + + @Override + public void afterPropertiesSet() { + + if (!this.initializeSchema) { + return; + } + + // If index already exists don't do anything + if (this.jedis.ftList().contains(this.config.indexName)) { + return; + } + + String response = this.jedis.ftCreate(this.config.indexName, + FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.config.prefix), schemaFields()); + if (!RESPONSE_OK.test(response)) { + String message = MessageFormat.format("Could not create index: {0}", response); + throw new RuntimeException(message); + } + } + + private Iterable schemaFields() { + Map vectorAttrs = new HashMap<>(); + vectorAttrs.put("DIM", this.embeddingModel.dimensions()); + vectorAttrs.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC); + vectorAttrs.put("TYPE", VECTOR_TYPE_FLOAT32); + List fields = new ArrayList<>(); + fields.add(TextField.of(jsonPath(this.config.contentFieldName)).as(this.config.contentFieldName).weight(1.0)); + fields.add(VectorField.builder() + .fieldName(jsonPath(this.config.embeddingFieldName)) + .algorithm(vectorAlgorithm()) + .attributes(vectorAttrs) + .as(this.config.embeddingFieldName) + .build()); + + if (!CollectionUtils.isEmpty(this.config.metadataFields)) { + for (MetadataField field : this.config.metadataFields) { + fields.add(schemaField(field)); + } + } + return fields; + } + + private SchemaField schemaField(MetadataField field) { + String fieldName = jsonPath(field.name); + switch (field.fieldType) { + case NUMERIC: + return NumericField.of(fieldName).as(field.name); + case TAG: + return TagField.of(fieldName).as(field.name); + case TEXT: + return TextField.of(fieldName).as(field.name); + default: + throw new IllegalArgumentException( + MessageFormat.format("Field {0} has unsupported type {1}", field.name, field.fieldType)); + } + } + + private VectorAlgorithm vectorAlgorithm() { + if (config.vectorAlgorithm == Algorithm.HSNW) { + return VectorAlgorithm.HNSW; + } + return VectorAlgorithm.FLAT; + } + + private String jsonPath(String field) { + return JSON_PATH_PREFIX + field; + } + + private static float[] toFloatArray(List embeddingDouble) { + float[] embeddingFloat = new float[embeddingDouble.size()]; + int i = 0; + for (Double d : embeddingDouble) { + embeddingFloat[i++] = d.floatValue(); + } + return embeddingFloat; + } + +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/resources/webapp/test/Fel.pdf b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/resources/webapp/test/Fel.pdf new file mode 100755 index 000000000..405b67fed Binary files /dev/null and b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/resources/webapp/test/Fel.pdf differ diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/AzureOpenAIChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/AzureOpenAIChatModelTests.java new file mode 100644 index 000000000..c85958779 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/AzureOpenAIChatModelTests.java @@ -0,0 +1,70 @@ +package cn.iocoder.yudao.framework.ai.chat; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.util.ClientOptions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.azure.openai.AzureOpenAiChatModel; +import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; + +import static org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties.DEFAULT_DEPLOYMENT_NAME; + +/** + * {@link AzureOpenAiChatModel} 集成测试 + * + * @author 芋道源码 + */ +public class AzureOpenAIChatModelTests { + + private final OpenAIClient openAiApi = (new OpenAIClientBuilder()) + .endpoint("https://eastusprejade.openai.azure.com") + .credential(new AzureKeyCredential("xxx")) + .clientOptions((new ClientOptions()).setApplicationId("spring-ai")) + .buildClient(); + private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi, + AzureOpenAiChatOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build()); + + @Test + @Disabled + public void testCall() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + ChatResponse response = chatModel.call(new Prompt(messages)); + // 打印结果 + System.out.println(response); + System.out.println(response.getResult().getOutput()); + } + + @Test + @Disabled + public void testStream() { + // 准备参数 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("1 + 1 = ?")); + + // 调用 + Flux flux = chatModel.stream(new Prompt(messages)); + // 打印结果 + flux.doOnNext(response -> { +// System.out.println(response); + System.out.println(response.getResult().getOutput()); + }).then().block(); + } + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/OpenAIChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/OpenAIChatModelTests.java index 0d956e5b3..676832546 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/OpenAIChatModelTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/OpenAIChatModelTests.java @@ -1,6 +1,5 @@ package cn.iocoder.yudao.framework.ai.chat; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.messages.Message; @@ -17,7 +16,7 @@ import java.util.ArrayList; import java.util.List; /** - * {@link XingHuoChatModel} 集成测试 + * {@link OpenAiChatModel} 集成测试 * * @author 芋道源码 */ diff --git a/yudao-server/src/main/resources/application.yaml b/yudao-server/src/main/resources/application.yaml index 80568517f..609f85b82 100644 --- a/yudao-server/src/main/resources/application.yaml +++ b/yudao-server/src/main/resources/application.yaml @@ -147,14 +147,22 @@ spring: spring: ai: + vectorstore: # 向量存储 + redis: + index: default-index + prefix: "default:" qianfan: # 文心一言 api-key: x0cuLZ7XsaTCU08vuJWO87Lg secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK zhipuai: # 智谱 AI api-key: 32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs - openai: + openai: # OpenAI 官方 api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z base-url: https://api.gptsapi.net + azure: # OpenAI 微软 + openai: + endpoint: https://eastusprejade.openai.azure.com + api-key: xxx ollama: base-url: http://127.0.0.1:11434 chat: