【新增】AI 知识库: 段落召回

This commit is contained in:
xiaoxin 2024-09-05 13:32:16 +08:00
parent c26cecaed6
commit 92d32b652e
6 changed files with 88 additions and 10 deletions

View File

@ -0,0 +1,17 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@Schema(description = "管理后台 - AI 知识库段落召回 Request VO")
@Data
public class AiKnowledgeSegmentSearchReqVO {
@Schema(description = "知识库编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790")
private Long knowledgeId;
@Schema(description = "内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 学习路线")
private String content;
}

View File

@ -7,6 +7,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowle
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List;
/** /**
* AI 知识库-分片 Mapper * AI 知识库-分片 Mapper
* *
@ -22,4 +24,10 @@ public interface AiKnowledgeSegmentMapper extends BaseMapperX<AiKnowledgeSegment
.likeIfPresent(AiKnowledgeSegmentDO::getContent, reqVO.getKeyword()) .likeIfPresent(AiKnowledgeSegmentDO::getContent, reqVO.getKeyword())
.orderByDesc(AiKnowledgeSegmentDO::getId)); .orderByDesc(AiKnowledgeSegmentDO::getId));
} }
default List<AiKnowledgeSegmentDO> selectList(List<String> vectorIdList) {
return selectList(new LambdaQueryWrapperX<AiKnowledgeSegmentDO>()
.in(AiKnowledgeSegmentDO::getVectorId, vectorIdList)
.orderByDesc(AiKnowledgeSegmentDO::getId));
}
} }

View File

@ -2,10 +2,13 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import java.util.List;
/** /**
* AI 知识库段落 Service 接口 * AI 知识库段落 Service 接口
* *
@ -35,4 +38,13 @@ public interface AiKnowledgeSegmentService {
*/ */
void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO); void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO);
/**
* 段落召回
*
* @param reqVO 召回请求信息
* @return 召回的段落
*/
List<AiKnowledgeSegmentDO> similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO);
} }

View File

@ -1,16 +1,29 @@
package cn.iocoder.yudao.module.ai.service.knowledge; package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List;
/** /**
* AI 知识库分片 Service 实现类 * AI 知识库分片 Service 实现类
* *
@ -23,6 +36,13 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Resource @Resource
private AiKnowledgeSegmentMapper segmentMapper; private AiKnowledgeSegmentMapper segmentMapper;
@Resource
private AiKnowledgeService knowledgeService;
@Resource
private AiChatModelService chatModelService;
@Resource
private AiApiKeyService apiKeyService;
@Override @Override
public PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) { public PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) {
return segmentMapper.selectPage(pageReqVO); return segmentMapper.selectPage(pageReqVO);
@ -39,4 +59,26 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
segmentMapper.updateById(BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class)); segmentMapper.updateById(BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class));
// TODO @xin 1.禁用删除向量 2.启用重新向量化 // TODO @xin 1.禁用删除向量 2.启用重新向量化
} }
@Override
public List<AiKnowledgeSegmentDO> similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO) {
// 0. 校验
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 1.1 获取向量存储实例
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
// 1.2 向量检索
List<Document> documentList = vectorStore.similaritySearch(SearchRequest.query(reqVO.getContent())
//TODO @xin 配置提取
.withTopK(5)
.withSimilarityThreshold(0.5d)
.withFilterExpression(new FilterExpressionBuilder().eq(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, reqVO.getKnowledgeId()).build()));
if (CollUtil.isEmpty(documentList)) {
return ListUtil.empty();
}
// 2.1 段落召回
return segmentMapper.selectList(CollUtil.getFieldValues(documentList, "id", String.class));
}
} }

View File

@ -82,7 +82,7 @@ public class YudaoAiAutoConfiguration {
// TODO @xin 免费版本 // TODO @xin 免费版本
// @Bean // @Bean
// @Lazy // TODO 芋艿临时注释避免无法启动 // @Lazy // TODO 芋艿临时注释避免无法启动
// public EmbeddingModel transformersEmbeddingClient() { // public TransformersEmbeddingModel transformersEmbeddingClient() {
// return new TransformersEmbeddingModel(MetadataMode.EMBED); // return new TransformersEmbeddingModel(MetadataMode.EMBED);
// } // }
@ -91,23 +91,24 @@ public class YudaoAiAutoConfiguration {
*/ */
// @Bean // @Bean
// @Lazy // TODO 芋艿临时注释避免无法启动 // @Lazy // TODO 芋艿临时注释避免无法启动
// public RedisVectorStore vectorStore(TongYiTextEmbeddingModel tongYiTextEmbeddingModel, RedisVectorStoreProperties properties, // public RedisVectorStore vectorStore(TransformersEmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
// RedisProperties redisProperties) { // RedisProperties redisProperties) {
// var config = RedisVectorStore.RedisVectorStoreConfig.builder() // var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(properties.getIndex()) // .withIndexName(properties.getIndex())
// .withPrefix(properties.getPrefix()) // .withPrefix(properties.getPrefix())
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
// .build(); // .build();
// //
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, tongYiTextEmbeddingModel, // RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), // new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
// properties.isInitializeSchema()); // properties.isInitializeSchema());
// redisVectorStore.afterPropertiesSet(); // redisVectorStore.afterPropertiesSet();
// return redisVectorStore; // return redisVectorStore;
// } // }
@Bean @Bean
@Lazy // TODO 芋艿临时注释避免无法启动 @Lazy // TODO 芋艿临时注释避免无法启动
public TokenTextSplitter tokenTextSplitter() { public TokenTextSplitter tokenTextSplitter() {
//TODO @xin 配置提取
return new TokenTextSplitter(500, 100, 5, 10000, true); return new TokenTextSplitter(500, 100, 5, 10000, true);
} }

View File

@ -66,6 +66,7 @@ import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient; import org.springframework.web.client.RestClient;
import redis.clients.jedis.JedisPooled; import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.search.Schema;
import java.util.List; import java.util.List;
@ -200,14 +201,11 @@ public class AiModelFactoryImpl implements AiModelFactory {
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) { public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url); String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> { return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
// TODO 芋艿 @xin 这两个配置取哪好呢 String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
// TODO 不同模型的向量维度可能会不一样目前看貌似是以 index 来做区分的维度不一样存不到一个 index
// TODO 回复好的哈
String index = "default-index";
String prefix = "default:";
var config = RedisVectorStore.RedisVectorStoreConfig.builder() var config = RedisVectorStore.RedisVectorStoreConfig.builder()
.withIndexName(index) .withIndexName(cacheKey)
.withPrefix(prefix) .withPrefix(prefix)
.withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
.build(); .build();
RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class); RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel, RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,