【新增】AI 知识库: AiVectorFactory 负责管理不同 EmbeddingModel 对应的 VectorStore

This commit is contained in:
xiaoxin 2024-08-29 14:13:37 +08:00
parent 024109dac9
commit f97fb0a8fe
13 changed files with 239 additions and 52 deletions

View File

@ -45,7 +45,7 @@ public class AiKnowledgeDO extends BaseDO {
@TableField(typeHandler = JacksonTypeHandler.class)
private List<Long> visibilityPermissions;
/**
* 嵌入模型编号高质量模式时维护
* 嵌入模型编号
*/
private Long modelId;
/**

View File

@ -24,10 +24,14 @@ public class AiKnowledgeSegmentDO extends BaseDO {
* 向量库的编号
*/
private String vectorId;
// TODO @新knowledgeId 加个会方便点
/**
* 知识库编号
* 关联 {@link AiKnowledgeDO#getId()}
*/
private Long knowledgeId;
/**
* 文档编号
*
* <p>
* 关联 {@link AiKnowledgeDocumentDO#getId()}
*/
private Long documentId;

View File

@ -6,24 +6,27 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import java.util.Objects;
/**
* AI 知识库-文档 Service 实现类
@ -42,9 +45,14 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
@Resource
private TokenTextSplitter tokenTextSplitter;
@Resource
private TokenCountEstimator TOKEN_COUNT_ESTIMATOR;
private TokenCountEstimator tokenCountEstimator;
@Resource
private RedisVectorStore vectorStore;
private AiApiKeyService apiKeyService;
@Resource
private AiKnowledgeService knowledgeService;
@Resource
private AiChatModelService chatModelService;
// TODO 芋艿需要 review 代码格式
@ -53,18 +61,18 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
// 1.1 下载文档
String url = createReqVO.getUrl();
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
// 1.2 加载文档
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
List<Document> documents = loader.get();
Document document = CollUtil.getFirst(documents);
// TODO @xin是不是不存在就抛出异常呀厚泽 return
Integer tokens = Objects.nonNull(document) ? TOKEN_COUNT_ESTIMATOR.estimate(document.getContent()) : 0;
Integer wordCount = Objects.nonNull(document) ? document.getContent().length() : 0;
String content = document.getContent();
Integer tokens = tokenCountEstimator.estimate(content);
Integer wordCount = content.length();
// 1.3 文档记录入库
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
.setTokens(tokens).setWordCount(wordCount)
.setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus());
// 1.2 文档记录入库
documentMapper.insert(documentDO);
Long documentId = documentDO.getId();
if (CollUtil.isEmpty(documents)) {
@ -75,11 +83,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
List<Document> segments = tokenTextSplitter.apply(documents);
// 2.2 分段内容入库
List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments,
segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId)
.setTokens(TOKEN_COUNT_ESTIMATOR.estimate(segment.getContent())).setWordCount(segment.getContent().length())
segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId).setKnowledgeId(createReqVO.getKnowledgeId())
.setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length())
.setStatus(CommonStatusEnum.ENABLE.getStatus()));
segmentMapper.insertBatch(segmentDOList);
// 3 向量化并存储
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 3.1 获取向量存储实例
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
// 3.2 向量化并存储
vectorStore.add(segments);
return documentId;
}

View File

@ -1,6 +1,8 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
/**
* AI 知识库-基础信息 Service 接口
@ -13,7 +15,7 @@ public interface AiKnowledgeService {
* 创建我的知识库
*
* @param createReqVO 创建信息
* @param userId 用户编号
* @param userId 用户编号
* @return 编号
*/
Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId);
@ -23,8 +25,16 @@ public interface AiKnowledgeService {
* 创建我的知识库
*
* @param updateReqVO 更新信息
* @param userId 用户编号
* @param userId 用户编号
*/
void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId);
/**
* Validates that a knowledge base with the given id exists.
*
* @param id id of the knowledge base record
* @return the knowledge base record
*/
AiKnowledgeDO validateKnowledgeExists(Long id);
}

View File

@ -29,7 +29,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
private AiChatModelService chatModalService;
@Resource
private AiKnowledgeMapper knowledgeBaseMapper;
private AiKnowledgeMapper knowledgeMapper;
@Override
public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) {
@ -39,7 +39,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 2. 插入知识库
AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
.setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus());
knowledgeBaseMapper.insert(knowledgeBase);
knowledgeMapper.insert(knowledgeBase);
return knowledgeBase.getId();
}
@ -56,11 +56,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 2. 更新知识库
AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class);
updateDO.setModel(model.getModel());
knowledgeBaseMapper.updateById(updateDO);
knowledgeMapper.updateById(updateDO);
}
@Override
public AiKnowledgeDO validateKnowledgeExists(Long id) {
AiKnowledgeDO knowledgeBase = knowledgeBaseMapper.selectById(id);
AiKnowledgeDO knowledgeBase = knowledgeMapper.selectById(id);
if (knowledgeBase == null) {
throw exception(KNOWLEDGE_NOT_EXISTS);
}

View File

@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import java.util.List;
@ -83,6 +85,14 @@ public interface AiApiKeyService {
*/
ChatModel getChatModel(Long id);
/**
* Gets the EmbeddingModel for the given API key.
*
* @param id id of the API key record
* @return the EmbeddingModel
*/
EmbeddingModel getEmbeddingModel(Long id);
/**
* 获得 ImageModel 对象
*
@ -111,4 +121,12 @@ public interface AiApiKeyService {
*/
SunoApi getSunoApi();
/**
* Gets the VectorStore for the given API key, creating it if it does not exist yet.
*
* @param id id of the API key record
* @return the VectorStore
*/
VectorStore getOrCreateVectorStore(Long id);
}

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@ -13,7 +14,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
@ -36,6 +39,8 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource
private AiModelFactory modelFactory;
@Resource
private AiVectorFactory vectorFactory;
@Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
@ -104,6 +109,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public EmbeddingModel getEmbeddingModel(Long id) {
    // Resolve the API key record first, then the platform it belongs to
    final AiApiKeyDO keyConfig = validateApiKey(id);
    final AiPlatformEnum targetPlatform = AiPlatformEnum.validatePlatform(keyConfig.getPlatform());
    // Delegate to the model factory, which owns creation of the EmbeddingModel
    return modelFactory.getOrCreateEmbeddingModel(
            targetPlatform,
            keyConfig.getApiKey(),
            keyConfig.getUrl());
}
@Override
public ImageModel getImageModel(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
@ -132,4 +144,11 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
}
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
}
/**
 * Gets the VectorStore bound to the given API key, creating it if it does not exist yet.
 *
 * @param id id of the API key record
 * @return the VectorStore produced by the vector factory
 */
@Override
public VectorStore getOrCreateVectorStore(Long id) {
    AiApiKeyDO apiKey = validateApiKey(id);
    AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
    // Build the embedding model from the key loaded above rather than calling getEmbeddingModel(id),
    // which would validate and load the very same API key a second time.
    EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl());
    return vectorFactory.getOrCreateVectorStore(embeddingModel, platform, apiKey.getApiKey(), apiKey.getUrl());
}
}

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
@ -10,22 +12,15 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Lazy;
import redis.clients.jedis.JedisPooled;
/**
* 芋道 AI 自动配置
@ -43,6 +38,12 @@ public class YudaoAiAutoConfiguration {
return new AiModelFactoryImpl();
}
/**
* Registers the {@link AiVectorFactory} bean, implemented by {@link AiVectorFactoryImpl}.
*/
@Bean
public AiVectorFactory aiVectorFactory() {
return new AiVectorFactoryImpl();
}
// ========== 各种 AI Client 创建 ==========
@Bean
@ -85,30 +86,31 @@ public class YudaoAiAutoConfiguration {
}
// ========== rag 相关 ==========
@Bean
@Lazy // TODO 芋艿临时注释避免无法启动
public EmbeddingModel transformersEmbeddingClient() {
return new TransformersEmbeddingModel(MetadataMode.EMBED);
}
// TODO @xin 免费版本
// @Bean
// @Lazy // TODO 芋艿临时注释避免无法启动
// public EmbeddingModel transformersEmbeddingClient() {
// return new TransformersEmbeddingModel(MetadataMode.EMBED);
// }
/**
* TODO @xin 抽离出去根据具体模型走
* TODO @xin 默认版本先不弄目前都先取对应的 EmbeddingModel
*/
@Bean
@Lazy // TODO 芋艿临时注释避免无法启动
public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties,
RedisProperties redisProperties) {
var config = RedisVectorStore.RedisVectorStoreConfig.builder()
.withIndexName(properties.getIndex())
.withPrefix(properties.getPrefix())
.build();
RedisVectorStore redisVectorStore = new RedisVectorStore(config, transformersEmbeddingModel,
new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
properties.isInitializeSchema());
redisVectorStore.afterPropertiesSet();
return redisVectorStore;
}
// @Bean
// @Lazy // TODO 芋艿临时注释避免无法启动
// public RedisVectorStore vectorStore(TongYiTextEmbeddingModel tongYiTextEmbeddingModel, RedisVectorStoreProperties properties,
// RedisProperties redisProperties) {
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(properties.getIndex())
// .withPrefix(properties.getPrefix())
// .build();
//
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, tongYiTextEmbeddingModel,
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
// properties.isInitializeSchema());
// redisVectorStore.afterPropertiesSet();
// return redisVectorStore;
// }
@Bean
@Lazy // TODO 芋艿临时注释避免无法启动

View File

@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
/**
@ -25,6 +26,18 @@ public interface AiModelFactory {
*/
ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url);
/**
* Gets the EmbeddingModel for the given configuration.
* <p>
* If it does not exist yet, it is created.
*
* @param platform platform
* @param apiKey API key
* @param url API URL
* @return the EmbeddingModel
*/
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于默认配置获得 ChatModel 对象
*

View File

@ -21,6 +21,7 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.azure.ai.openai.OpenAIClient;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
@ -37,6 +38,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.ollama.OllamaChatModel;
@ -97,6 +99,21 @@ public class AiModelFactoryImpl implements AiModelFactory {
});
}
@Override
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
    // Supplier handed to the Singleton cache; it builds the model for the requested platform
    final Func0<EmbeddingModel> creator = () -> {
        // TODO @xin: only one platform is wired up for now, for testing
        switch (platform) {
            case TONG_YI:
                return buildTongYiEmbeddingModel(apiKey);
            default:
                throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
        }
    };
    // Instances are cached under a key derived from the model type, platform, API key and URL
    return Singleton.get(buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url), creator);
}
@Override
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
@ -239,7 +256,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
* ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
*ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
*/
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@ -249,7 +266,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
* ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@ -315,4 +332,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new StabilityAiImageModel(stabilityAiApi);
}
// ========== 各种创建 EmbeddingModel 的方法 ==========
/**
 * Builds the TongYi EmbeddingModel for the given API key.
 * <p>
 * For reference, see {@link TongYiAutoConfiguration#tongYiTextEmbeddingClient(TextEmbedding, TongYiConnectionProperties)}.
 *
 * @param apiKey API key set on the connection properties
 * @return the EmbeddingModel created by the TongYi auto-configuration method
 */
private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) {
    TongYiConnectionProperties tongYiProperties = new TongYiConnectionProperties();
    tongYiProperties.setApiKey(apiKey);
    // Reuse the auto-configuration's factory method; the TextEmbedding client is taken from the Spring context
    TongYiAutoConfiguration autoConfiguration = new TongYiAutoConfiguration();
    TextEmbedding textEmbedding = SpringUtil.getBean(TextEmbedding.class);
    return autoConfiguration.tongYiTextEmbeddingClient(textEmbedding, tongYiProperties);
}
}

View File

@ -0,0 +1,27 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
/**
* Factory interface for AI vector stores.
*
* @author xiaoxin
*/
public interface AiVectorFactory {
/**
* Gets the VectorStore for the given configuration.
* <p>
* If it does not exist yet, it is created.
*
* @param embeddingModel embedding model
* @param platform platform
* @param apiKey API key
* @param url API URL
* @return the VectorStore
*/
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
}

View File

@ -0,0 +1,51 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.hutool.core.lang.Singleton;
import cn.hutool.core.lang.func.Func0;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import redis.clients.jedis.JedisPooled;
/**
 * {@link AiVectorFactory} implementation that backs every {@link VectorStore} with a {@link RedisVectorStore}.
 *
 * @author xiaoxin
 */
public class AiVectorFactoryImpl implements AiVectorFactory {

    /**
     * Gets the VectorStore for the given configuration, creating it through the Singleton cache if needed.
     *
     * @param embeddingModel embedding model handed to the RedisVectorStore
     * @param platform       platform, part of the cache key
     * @param apiKey         API key, part of the cache key
     * @param url            API URL, part of the cache key
     * @return the RedisVectorStore registered under the cache key
     */
    @Override
    public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
        String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
        return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
            // TODO 芋艿 @xin: where should these two settings come from?
            // TODO Different models may produce vectors of different dimensions; it looks like the index is what
            //  separates them, and vectors of different dimensions cannot be stored in the same index
            String index = "default-index";
            String prefix = "default:";
            var config = RedisVectorStore.RedisVectorStoreConfig.builder()
                    .withIndexName(index)
                    .withPrefix(prefix)
                    .build();
            RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
            RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
                    buildJedisPooled(redisProperties),
                    true);
            // The store is created with "new" here rather than as a Spring bean, so the init callback is invoked by hand
            redisVectorStore.afterPropertiesSet();
            return redisVectorStore;
        });
    }

    /**
     * Creates the Jedis client from the application's Redis settings.
     * <p>
     * Credentials are passed on when a password is configured, so that a password-protected Redis
     * can be used; without a password the plain host/port client is created as before.
     * NOTE(review): the configured Redis database index is still not applied — confirm whether it is needed.
     *
     * @param redisProperties Spring Boot Redis connection properties
     * @return the Jedis client used by the RedisVectorStore
     */
    private static JedisPooled buildJedisPooled(RedisProperties redisProperties) {
        String host = redisProperties.getHost();
        int port = redisProperties.getPort();
        String password = redisProperties.getPassword();
        if (StrUtil.isEmpty(password)) {
            return new JedisPooled(host, port);
        }
        return new JedisPooled(host, port, redisProperties.getUsername(), password);
    }

    /**
     * Builds the Singleton cache key: the class name, followed by "#" and the parameters joined with "_".
     *
     * @param clazz  type of the cached object
     * @param params values that distinguish one cached instance from another
     * @return the cache key; just the class name when no parameters are given
     */
    private static String buildClientCacheKey(Class<?> clazz, Object... params) {
        if (ArrayUtil.isEmpty(params)) {
            return clazz.getName();
        }
        return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
    }

}

View File

@ -19,6 +19,7 @@ import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
@ -38,7 +39,7 @@ import redis.clients.jedis.JedisPooled;
*/
@AutoConfiguration(after = RedisAutoConfiguration.class)
@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class})
//@ConditionalOnBean(JedisConnectionFactory.class)
@ConditionalOnBean(JedisConnectionFactory.class)
@EnableConfigurationProperties(RedisVectorStoreProperties.class)
public class RedisVectorStoreAutoConfiguration {