diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/knowledge/AiKnowledgeCreateMyReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/knowledge/AiKnowledgeCreateMyReqVO.java
index 44a5e87ee..58a89caee 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/knowledge/AiKnowledgeCreateMyReqVO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/knowledge/AiKnowledgeCreateMyReqVO.java
@@ -25,4 +25,11 @@ public class AiKnowledgeCreateMyReqVO {
@NotNull(message = "嵌入模型不能为空")
private Long modelId;
+ @Schema(description = "相似性阈值", requiredMode = Schema.RequiredMode.REQUIRED, example = "0.5")
+ @NotNull(message = "相似性阈值不能为空")
+ private Double similarityThreshold;
+
+ @Schema(description = "topK", requiredMode = Schema.RequiredMode.REQUIRED, example = "3")
+ @NotNull(message = "topK 不能为空")
+ private Integer topK;
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/knowledge/AiKnowledgeDocumentCreateReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/knowledge/AiKnowledgeDocumentCreateReqVO.java
index 9cc5290ab..651bdc0f7 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/knowledge/AiKnowledgeDocumentCreateReqVO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/knowledge/AiKnowledgeDocumentCreateReqVO.java
@@ -23,4 +23,23 @@ public class AiKnowledgeDocumentCreateReqVO {
@URL(message = "文档 URL 格式不正确")
private String url;
+ @Schema(description = "每个文本块的目标 token 数", requiredMode = Schema.RequiredMode.REQUIRED, example = "800")
+ @NotNull(message = "每个文本块的目标 token 数不能为空")
+ private Integer defaultChunkSize;
+
+ @Schema(description = "每个文本块的最小字符数", requiredMode = Schema.RequiredMode.REQUIRED, example = "350")
+ @NotNull(message = "每个文本块的最小字符数不能为空")
+ private Integer minChunkSizeChars;
+
+ @Schema(description = "丢弃阈值", requiredMode = Schema.RequiredMode.REQUIRED, example = "5")
+ @NotNull(message = "丢弃阈值不能为空")
+ private Integer minChunkLengthToEmbed;
+
+ @Schema(description = "最大块数", requiredMode = Schema.RequiredMode.REQUIRED, example = "10000")
+ @NotNull(message = "最大块数不能为空")
+ private Integer maxNumChunks;
+
+ @Schema(description = "分块是否保留分隔符", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
+ @NotNull(message = "分块是否保留分隔符不能为空")
+ private Boolean keepSeparator;
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/segment/AiKnowledgeSegmentSearchReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/segment/AiKnowledgeSegmentSearchReqVO.java
new file mode 100644
index 000000000..75349df62
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/vo/segment/AiKnowledgeSegmentSearchReqVO.java
@@ -0,0 +1,17 @@
+package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.Data;
+
+
+@Schema(description = "管理后台 - AI 知识库段落召回 Request VO")
+@Data
+public class AiKnowledgeSegmentSearchReqVO {
+
+ @Schema(description = "知识库编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790")
+ private Long knowledgeId;
+
+ @Schema(description = "内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 学习路线")
+ private String content;
+
+}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java
index 756d8cdb3..5db631dd4 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java
@@ -52,6 +52,18 @@ public class AiKnowledgeDO extends BaseDO {
* 模型标识
*/
private String model;
+
+ /**
+ * topK
+ */
+ private Integer topK;
+
+ /**
+ * 相似度阈值
+ */
+ private Double similarityThreshold;
+
+
/**
* 状态
*
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDocumentDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDocumentDO.java
index c5e526cce..18fa46c3a 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDocumentDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDocumentDO.java
@@ -23,7 +23,7 @@ public class AiKnowledgeDocumentDO extends BaseDO {
private Long id;
/**
* 知识库编号
- *
+ *
* 关联 {@link AiKnowledgeDO#getId()}
*/
private Long knowledgeId;
@@ -47,6 +47,26 @@ public class AiKnowledgeDocumentDO extends BaseDO {
* 字符数
*/
private Integer wordCount;
+ /**
+ * 每个文本块的目标 token 数
+ */
+ private Integer defaultChunkSize;
+ /**
+ * 每个文本块的最小字符数
+ */
+ private Integer minChunkSizeChars;
+ /**
+ * 低于此值的块会被丢弃
+ */
+ private Integer minChunkLengthToEmbed;
+ /**
+ * 最大块数
+ */
+ private Integer maxNumChunks;
+ /**
+ * 分块是否保留分隔符
+ */
+ private Boolean keepSeparator;
/**
* 切片状态
*
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java
index 84f7de654..be57265e1 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java
@@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import com.baomidou.mybatisplus.annotation.FieldStrategy;
+import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
@@ -25,16 +27,17 @@ public class AiKnowledgeSegmentDO extends BaseDO {
/**
* 向量库的编号
*/
+ @TableField(updateStrategy = FieldStrategy.ALWAYS)
private String vectorId;
/**
* 知识库编号
- *
+ *
* 关联 {@link AiKnowledgeDO#getId()}
*/
private Long knowledgeId;
/**
* 文档编号
- *
+ *
* 关联 {@link AiKnowledgeDocumentDO#getId()}
*/
private Long documentId;
@@ -52,7 +55,7 @@ public class AiKnowledgeSegmentDO extends BaseDO {
private Integer tokens;
/**
* 状态
- *
+ *
* 枚举 {@link CommonStatusEnum}
*/
private Integer status;
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeSegmentMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeSegmentMapper.java
index 912d18cbc..ea238b826 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeSegmentMapper.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeSegmentMapper.java
@@ -7,6 +7,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowle
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import org.apache.ibatis.annotations.Mapper;
+import java.util.List;
+
/**
* AI 知识库-分片 Mapper
*
@@ -22,4 +24,10 @@ public interface AiKnowledgeSegmentMapper extends BaseMapperX selectList(List vectorIdList) {
+ return selectList(new LambdaQueryWrapperX()
+ .in(AiKnowledgeSegmentDO::getVectorId, vectorIdList)
+ .orderByDesc(AiKnowledgeSegmentDO::getId));
+ }
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java
index 99f0621c8..05a9dce22 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java
@@ -9,15 +9,11 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
@@ -48,24 +44,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
@Resource
private AiKnowledgeSegmentMapper segmentMapper;
- @Resource
- private TokenTextSplitter tokenTextSplitter;
@Resource
private TokenCountEstimator tokenCountEstimator;
-
- @Resource
- private AiApiKeyService apiKeyService;
@Resource
private AiKnowledgeService knowledgeService;
- @Resource
- private AiChatModelService chatModelService;
@Override
@Transactional(rollbackFor = Exception.class)
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
- // 0. 校验
- AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
- AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
+ // 0. 校验并获取向量存储实例
+ VectorStore vectorStore = knowledgeService.getVectorStoreById(createReqVO.getKnowledgeId());
// 1.1 下载文档
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(createReqVO.getUrl()));
@@ -82,6 +70,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
return documentId;
}
+ // 2 构造文本分段器
+ TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(createReqVO.getDefaultChunkSize(), createReqVO.getMinChunkSizeChars(), createReqVO.getMinChunkLengthToEmbed(),
+ createReqVO.getMaxNumChunks(), createReqVO.getKeepSeparator());
// 2.1 文档分段
List segments = tokenTextSplitter.apply(documents);
// 2.2 分段内容入库
@@ -92,8 +83,6 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
.setStatus(CommonStatusEnum.ENABLE.getStatus()));
segmentMapper.insertBatch(segmentDOList);
- // 3.1 获取向量存储实例
- VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
// 3.2 向量化并存储
segments.forEach(segment -> segment.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId()));
vectorStore.add(segments);
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java
index 8ecb2d24a..49ed67135 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java
@@ -2,10 +2,13 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
+import java.util.List;
+
/**
* AI 知识库段落 Service 接口
*
@@ -35,4 +38,13 @@ public interface AiKnowledgeSegmentService {
*/
void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO);
+
+ /**
+ * 段落召回
+ *
+ * @param reqVO 召回请求信息
+ * @return 召回的段落
+ */
+ List similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO);
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java
index 7f751b176..1813d0b48 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java
@@ -1,16 +1,34 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
+import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.collection.ListUtil;
+import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
+import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
+import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.stereotype.Service;
+import java.util.List;
+import java.util.Objects;
+
+import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
+
/**
* AI 知识库分片 Service 实现类
*
@@ -23,6 +41,13 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Resource
private AiKnowledgeSegmentMapper segmentMapper;
+ @Resource
+ private AiKnowledgeService knowledgeService;
+ @Resource
+ private AiChatModelService chatModelService;
+ @Resource
+ private AiApiKeyService apiKeyService;
+
@Override
public PageResult getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) {
return segmentMapper.selectPage(pageReqVO);
@@ -30,13 +55,80 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Override
public void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO) {
- segmentMapper.updateById(BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class));
- // TODO @xin 重新向量化
+ // 0 校验
+ AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
+ // 2.1 获取知识库向量实例
+ VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
+ // 2.2 删除原向量
+ vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
+
+ // 2.3 重新向量化
+ Document document = new Document(reqVO.getContent());
+ document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
+ vectorStore.add(List.of(document));
+
+ // 2.1 更新段落内容
+ AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
+ knowledgeSegment.setVectorId(document.getId());
+ segmentMapper.updateById(knowledgeSegment);
}
@Override
public void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO) {
- segmentMapper.updateById(BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class));
- // TODO @xin 1.禁用删除向量 2.启用重新向量化
+ // 0 校验
+ AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
+ // 1 获取知识库向量实例
+ VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
+ AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
+
+ if (Objects.equals(reqVO.getStatus(), CommonStatusEnum.ENABLE.getStatus())) {
+ // 2.1 启用重新向量化
+ Document document = new Document(oldKnowledgeSegment.getContent());
+ document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
+ vectorStore.add(List.of(document));
+ knowledgeSegment.setVectorId(document.getId());
+ } else {
+ // 2.2 禁用删除向量
+ vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
+ knowledgeSegment.setVectorId(null);
+ }
+ // 3 更新段落状态
+ segmentMapper.updateById(knowledgeSegment);
+ }
+
+ @Override
+ public List similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO) {
+ // 0. 校验
+ AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqVO.getKnowledgeId());
+ AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
+
+ // 1.1 获取向量存储实例
+ VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
+
+ // 1.2 向量检索
+ List documentList = vectorStore.similaritySearch(SearchRequest.query(reqVO.getContent())
+ .withTopK(knowledge.getTopK())
+ .withSimilarityThreshold(knowledge.getSimilarityThreshold())
+ .withFilterExpression(new FilterExpressionBuilder().eq(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, reqVO.getKnowledgeId()).build()));
+ if (CollUtil.isEmpty(documentList)) {
+ return ListUtil.empty();
+ }
+ // 2.1 段落召回
+ return segmentMapper.selectList(CollUtil.getFieldValues(documentList, "id", String.class));
+ }
+
+
+ /**
+ * 校验段落是否存在
+ *
+ * @param id 文档编号
+ * @return 段落信息
+ */
+ private AiKnowledgeSegmentDO validateKnowledgeSegmentExists(Long id) {
+ AiKnowledgeSegmentDO knowledgeSegment = segmentMapper.selectById(id);
+ if (knowledgeSegment == null) {
+ throw exception(KNOWLEDGE_SEGMENT_NOT_EXISTS);
+ }
+ return knowledgeSegment;
}
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java
index 9f43c5328..d9770f452 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java
@@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
+import org.springframework.ai.vectorstore.VectorStore;
/**
* AI 知识库-基础信息 Service 接口
@@ -47,4 +48,12 @@ public interface AiKnowledgeService {
* @return 知识库分页
*/
PageResult getKnowledgePageMy(Long userId, PageParam pageReqVO);
+
+ /**
+ * 根据知识库编号获取向量存储实例
+ *
+ * @param knowledgeId 知识库编号
+ * @return 向量存储实例
+ */
+ VectorStore getVectorStoreById(Long knowledgeId);
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java
index 1948bb00e..7a145d734 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java
@@ -10,9 +10,11 @@ import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnow
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
+import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@@ -32,6 +34,10 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
@Resource
private AiKnowledgeMapper knowledgeMapper;
+ @Resource
+ private AiChatModelService chatModelService;
+ @Resource
+ private AiApiKeyService apiKeyService;
@Override
public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) {
@@ -75,4 +81,11 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
return knowledgeMapper.selectPageByMy(userId, pageReqVO);
}
+ @Override
+ public VectorStore getVectorStoreById(Long knowledgeId) {
+ AiKnowledgeDO knowledge = validateKnowledgeExists(knowledgeId);
+ AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
+ return apiKeyService.getOrCreateVectorStore(model.getKeyId());
+ }
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java
index bf11ec218..b2807905a 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java
@@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
-import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@@ -39,8 +38,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource
private AiModelFactory modelFactory;
- @Resource
- private AiVectorStoreFactory vectorFactory;
@Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
@@ -149,7 +146,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
public VectorStore getOrCreateVectorStore(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
- return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
+ return modelFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
}
}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml b/yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml
index 85996cb82..27c762a72 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml
@@ -46,11 +46,13 @@
-
- ${spring-ai.groupId}
- spring-ai-transformers-spring-boot-starter
- ${spring-ai.version}
-
+
+
+
+
+
+
+
${spring-ai.groupId}
spring-ai-tika-document-reader
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java
index 79a1f345b..0d2620b0c 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java
@@ -2,8 +2,6 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
-import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactory;
-import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
@@ -38,11 +36,6 @@ public class YudaoAiAutoConfiguration {
return new AiModelFactoryImpl();
}
- @Bean
- public AiVectorStoreFactory aiVectorFactory() {
- return new AiVectorStoreFactoryImpl();
- }
-
// ========== 各种 AI Client 创建 ==========
@@ -89,7 +82,7 @@ public class YudaoAiAutoConfiguration {
// TODO @xin 免费版本
// @Bean
// @Lazy // TODO 芋艿:临时注释,避免无法启动」
-// public EmbeddingModel transformersEmbeddingClient() {
+// public TransformersEmbeddingModel transformersEmbeddingClient() {
// return new TransformersEmbeddingModel(MetadataMode.EMBED);
// }
@@ -98,23 +91,24 @@ public class YudaoAiAutoConfiguration {
*/
// @Bean
// @Lazy // TODO 芋艿:临时注释,避免无法启动
-// public RedisVectorStore vectorStore(TongYiTextEmbeddingModel tongYiTextEmbeddingModel, RedisVectorStoreProperties properties,
+// public RedisVectorStore vectorStore(TransformersEmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
// RedisProperties redisProperties) {
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(properties.getIndex())
// .withPrefix(properties.getPrefix())
+// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
// .build();
//
-// RedisVectorStore redisVectorStore = new RedisVectorStore(config, tongYiTextEmbeddingModel,
+// RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
// properties.isInitializeSchema());
// redisVectorStore.afterPropertiesSet();
// return redisVectorStore;
// }
-
@Bean
@Lazy // TODO 芋艿:临时注释,避免无法启动
public TokenTextSplitter tokenTextSplitter() {
+ //TODO @xin 配置提取
return new TokenTextSplitter(500, 100, 5, 10000, true);
}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java
index 7e8465375..243c4ae4b 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java
@@ -6,6 +6,7 @@ import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
+import org.springframework.ai.vectorstore.VectorStore;
/**
* AI Model 模型工厂的接口类
@@ -92,4 +93,17 @@ public interface AiModelFactory {
*/
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
+ /**
+ * 基于指定配置,获得 VectorStore 对象
+ *
+ * 如果不存在,则进行创建
+ *
+ * @param embeddingModel 嵌入模型
+ * @param platform 平台
+ * @param apiKey API KEY
+ * @param url API URL
+ * @return VectorStore 对象
+ */
+ VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
+
}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
index aa46c45f2..7acd24769 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
@@ -13,6 +13,7 @@ import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
+import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
@@ -54,13 +55,18 @@ import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.qianfan.api.QianFanImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
+import org.springframework.ai.vectorstore.RedisVectorStore;
+import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
+import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
+import redis.clients.jedis.JedisPooled;
+import redis.clients.jedis.search.Schema;
import java.util.List;
@@ -191,6 +197,25 @@ public class AiModelFactoryImpl implements AiModelFactory {
});
}
+ @Override
+ public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
+ String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
+ return Singleton.get(cacheKey, (Func0) () -> {
+ String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
+ var config = RedisVectorStore.RedisVectorStoreConfig.builder()
+ .withIndexName(cacheKey)
+ .withPrefix(prefix)
+ .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
+ .build();
+ RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
+ RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
+ new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
+ true);
+ redisVectorStore.afterPropertiesSet();
+ return redisVectorStore;
+ });
+ }
+
private static String buildClientCacheKey(Class> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) {
return clazz.getName();
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorStoreFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorStoreFactory.java
deleted file mode 100644
index dad58a2c0..000000000
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorStoreFactory.java
+++ /dev/null
@@ -1,28 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.factory;
-
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.vectorstore.VectorStore;
-
-// TODO @xin:也放到 AiModelFactory 里面好了,后续改成 AiFactory
-/**
- * AI Vector 模型工厂的接口类
- *
- * @author xiaoxin
- */
-public interface AiVectorStoreFactory {
-
- /**
- * 基于指定配置,获得 VectorStore 对象
- *
- * 如果不存在,则进行创建
- *
- * @param embeddingModel 嵌入模型
- * @param platform 平台
- * @param apiKey API KEY
- * @param url API URL
- * @return VectorStore 对象
- */
- VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
-
-}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorStoreFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorStoreFactoryImpl.java
deleted file mode 100644
index ec04c5e88..000000000
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorStoreFactoryImpl.java
+++ /dev/null
@@ -1,52 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.factory;
-
-import cn.hutool.core.lang.Singleton;
-import cn.hutool.core.lang.func.Func0;
-import cn.hutool.core.util.ArrayUtil;
-import cn.hutool.core.util.StrUtil;
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.vectorstore.RedisVectorStore;
-import org.springframework.ai.vectorstore.VectorStore;
-import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
-import redis.clients.jedis.JedisPooled;
-
-/**
- * AI Vector 模型工厂的实现类
- * 使用 redisVectorStore 实现 VectorStore
- *
- * @author xiaoxin
- */
-public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory {
-
- @Override
- public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
- String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
- return Singleton.get(cacheKey, (Func0) () -> {
- // TODO 芋艿 @xin 这两个配置取哪好呢
- // TODO 不同模型的向量维度可能会不一样,目前看貌似是以 index 来做区分的,维度不一样存不到一个 index 上
- // TODO 回复:好的哈
- String index = "default-index";
- String prefix = "default:";
- var config = RedisVectorStore.RedisVectorStoreConfig.builder()
- .withIndexName(index)
- .withPrefix(prefix)
- .build();
- RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
- RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
- new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
- true);
- redisVectorStore.afterPropertiesSet();
- return redisVectorStore;
- });
- }
-
- private static String buildClientCacheKey(Class> clazz, Object... params) {
- if (ArrayUtil.isEmpty(params)) {
- return clazz.getName();
- }
- return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
- }
-
-}