【新增】AI 知识库: AiVectorFactory 负责管理不同 EmbeddingModel 对应的 VectorStore

This commit is contained in:
xiaoxin 2024-08-29 14:13:37 +08:00
parent 024109dac9
commit f97fb0a8fe
13 changed files with 239 additions and 52 deletions

View File

@ -45,7 +45,7 @@ public class AiKnowledgeDO extends BaseDO {
@TableField(typeHandler = JacksonTypeHandler.class) @TableField(typeHandler = JacksonTypeHandler.class)
private List<Long> visibilityPermissions; private List<Long> visibilityPermissions;
/** /**
* 嵌入模型编号高质量模式时维护 * 嵌入模型编号
*/ */
private Long modelId; private Long modelId;
/** /**

View File

@ -24,10 +24,14 @@ public class AiKnowledgeSegmentDO extends BaseDO {
* 向量库的编号 * 向量库的编号
*/ */
private String vectorId; private String vectorId;
// TODO @新knowledgeId 加个会方便点 /**
* 知识库编号
* 关联 {@link AiKnowledgeDO#getId()}
*/
private Long knowledgeId;
/** /**
* 文档编号 * 文档编号
* * <p>
* 关联 {@link AiKnowledgeDocumentDO#getId()} * 关联 {@link AiKnowledgeDocumentDO#getId()}
*/ */
private Long documentId; private Long documentId;

View File

@ -6,24 +6,27 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils; import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeDocumentCreateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum; import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader; import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.tokenizer.TokenCountEstimator; import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.ByteArrayResource;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.List; import java.util.List;
import java.util.Objects;
/** /**
* AI 知识库-文档 Service 实现类 * AI 知识库-文档 Service 实现类
@ -42,9 +45,14 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
@Resource @Resource
private TokenTextSplitter tokenTextSplitter; private TokenTextSplitter tokenTextSplitter;
@Resource @Resource
private TokenCountEstimator TOKEN_COUNT_ESTIMATOR; private TokenCountEstimator tokenCountEstimator;
@Resource @Resource
private RedisVectorStore vectorStore; private AiApiKeyService apiKeyService;
@Resource
private AiKnowledgeService knowledgeService;
@Resource
private AiChatModelService chatModelService;
// TODO 芋艿需要 review 代码格式 // TODO 芋艿需要 review 代码格式
@ -53,18 +61,18 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) { public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
// 1.1 下载文档 // 1.1 下载文档
String url = createReqVO.getUrl(); String url = createReqVO.getUrl();
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
// 1.2 加载文档 // 1.2 加载文档
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
List<Document> documents = loader.get(); List<Document> documents = loader.get();
Document document = CollUtil.getFirst(documents); Document document = CollUtil.getFirst(documents);
// TODO @xin是不是不存在就抛出异常呀厚泽 return String content = document.getContent();
Integer tokens = Objects.nonNull(document) ? TOKEN_COUNT_ESTIMATOR.estimate(document.getContent()) : 0; Integer tokens = tokenCountEstimator.estimate(content);
Integer wordCount = Objects.nonNull(document) ? document.getContent().length() : 0; Integer wordCount = content.length();
// 1.3 文档记录入库
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class) AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
.setTokens(tokens).setWordCount(wordCount) .setTokens(tokens).setWordCount(wordCount)
.setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus()); .setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus());
// 1.2 文档记录入库
documentMapper.insert(documentDO); documentMapper.insert(documentDO);
Long documentId = documentDO.getId(); Long documentId = documentDO.getId();
if (CollUtil.isEmpty(documents)) { if (CollUtil.isEmpty(documents)) {
@ -75,11 +83,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
List<Document> segments = tokenTextSplitter.apply(documents); List<Document> segments = tokenTextSplitter.apply(documents);
// 2.2 分段内容入库 // 2.2 分段内容入库
List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments, List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments,
segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId) segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId).setKnowledgeId(createReqVO.getKnowledgeId())
.setTokens(TOKEN_COUNT_ESTIMATOR.estimate(segment.getContent())).setWordCount(segment.getContent().length()) .setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length())
.setStatus(CommonStatusEnum.ENABLE.getStatus())); .setStatus(CommonStatusEnum.ENABLE.getStatus()));
segmentMapper.insertBatch(segmentDOList); segmentMapper.insertBatch(segmentDOList);
// 3 向量化并存储
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 3.1 获取向量存储实例
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
// 3.2 向量化并存储
vectorStore.add(segments); vectorStore.add(segments);
return documentId; return documentId;
} }

View File

@ -1,6 +1,8 @@
package cn.iocoder.yudao.module.ai.service.knowledge; package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeUpdateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
/** /**
* AI 知识库-基础信息 Service 接口 * AI 知识库-基础信息 Service 接口
@ -27,4 +29,12 @@ public interface AiKnowledgeService {
*/ */
void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId); void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId);
/**
 * Validates that a knowledge base exists.
 * <p>
 * The implementation shown in this change looks the record up by id and raises the
 * KNOWLEDGE_NOT_EXISTS error when the lookup returns nothing.
 *
 * @param id knowledge base record id
 * @return the knowledge base record for {@code id}
 */
AiKnowledgeDO validateKnowledgeExists(Long id);
} }

View File

@ -29,7 +29,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
private AiChatModelService chatModalService; private AiChatModelService chatModalService;
@Resource @Resource
private AiKnowledgeMapper knowledgeBaseMapper; private AiKnowledgeMapper knowledgeMapper;
@Override @Override
public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) { public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) {
@ -39,7 +39,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 2. 插入知识库 // 2. 插入知识库
AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class) AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
.setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus()); .setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus());
knowledgeBaseMapper.insert(knowledgeBase); knowledgeMapper.insert(knowledgeBase);
return knowledgeBase.getId(); return knowledgeBase.getId();
} }
@ -56,11 +56,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 2. 更新知识库 // 2. 更新知识库
AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class); AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class);
updateDO.setModel(model.getModel()); updateDO.setModel(model.getModel());
knowledgeBaseMapper.updateById(updateDO); knowledgeMapper.updateById(updateDO);
} }
@Override
public AiKnowledgeDO validateKnowledgeExists(Long id) { public AiKnowledgeDO validateKnowledgeExists(Long id) {
AiKnowledgeDO knowledgeBase = knowledgeBaseMapper.selectById(id); AiKnowledgeDO knowledgeBase = knowledgeMapper.selectById(id);
if (knowledgeBase == null) { if (knowledgeBase == null) {
throw exception(KNOWLEDGE_NOT_EXISTS); throw exception(KNOWLEDGE_NOT_EXISTS);
} }

View File

@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import java.util.List; import java.util.List;
@ -83,6 +85,14 @@ public interface AiApiKeyService {
*/ */
ChatModel getChatModel(Long id); ChatModel getChatModel(Long id);
/**
 * Gets the EmbeddingModel object for an API key.
 *
 * @param id API key id
 * @return the EmbeddingModel object
 */
EmbeddingModel getEmbeddingModel(Long id);
/** /**
* 获得 ImageModel 对象 * 获得 ImageModel 对象
* *
@ -111,4 +121,12 @@ public interface AiApiKeyService {
*/ */
SunoApi getSunoApi(); SunoApi getSunoApi();
/**
 * Gets the VectorStore object for an API key, creating it if it does not exist yet.
 *
 * @param id API key id
 * @return the VectorStore object
 */
VectorStore getOrCreateVectorStore(Long id);
} }

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@ -13,7 +14,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
@ -36,6 +39,8 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource @Resource
private AiModelFactory modelFactory; private AiModelFactory modelFactory;
@Resource
private AiVectorFactory vectorFactory;
@Override @Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) { public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
@ -104,6 +109,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
} }
@Override
public EmbeddingModel getEmbeddingModel(Long id) {
    // Resolve the API key record, then the platform it is configured for.
    final AiApiKeyDO keyRecord = validateApiKey(id);
    final AiPlatformEnum keyPlatform = AiPlatformEnum.validatePlatform(keyRecord.getPlatform());
    // The factory hands back the model for this platform / key / url combination.
    final String secret = keyRecord.getApiKey();
    final String endpoint = keyRecord.getUrl();
    return modelFactory.getOrCreateEmbeddingModel(keyPlatform, secret, endpoint);
}
@Override @Override
public ImageModel getImageModel(AiPlatformEnum platform) { public ImageModel getImageModel(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
@ -132,4 +144,11 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
} }
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
} }
/**
 * Gets the VectorStore for an API key, creating it if it does not exist yet.
 *
 * @param id API key id
 * @return the VectorStore built on the EmbeddingModel of that API key
 */
@Override
public VectorStore getOrCreateVectorStore(Long id) {
    AiApiKeyDO apiKey = validateApiKey(id);
    AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
    // Build the embedding model from the record already loaded above instead of calling
    // getEmbeddingModel(id), which would run validateApiKey / validatePlatform a second time.
    EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl());
    return vectorFactory.getOrCreateVectorStore(embeddingModel, platform, apiKey.getApiKey(), apiKey.getUrl());
}
} }

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
@ -10,22 +12,15 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration; import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator; import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator; import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import redis.clients.jedis.JedisPooled;
/** /**
* 芋道 AI 自动配置 * 芋道 AI 自动配置
@ -43,6 +38,12 @@ public class YudaoAiAutoConfiguration {
return new AiModelFactoryImpl(); return new AiModelFactoryImpl();
} }
/**
 * Registers the {@link AiVectorFactory} bean, implemented by {@link AiVectorFactoryImpl}.
 */
@Bean
public AiVectorFactory aiVectorFactory() {
return new AiVectorFactoryImpl();
}
// ========== 各种 AI Client 创建 ========== // ========== 各种 AI Client 创建 ==========
@Bean @Bean
@ -85,30 +86,31 @@ public class YudaoAiAutoConfiguration {
} }
// ========== rag 相关 ========== // ========== rag 相关 ==========
@Bean // TODO @xin 免费版本
@Lazy // TODO 芋艿临时注释避免无法启动 // @Bean
public EmbeddingModel transformersEmbeddingClient() { // @Lazy // TODO 芋艿临时注释避免无法启动
return new TransformersEmbeddingModel(MetadataMode.EMBED); // public EmbeddingModel transformersEmbeddingClient() {
} // return new TransformersEmbeddingModel(MetadataMode.EMBED);
// }
/** /**
* TODO @xin 抽离出去根据具体模型走 * TODO @xin 默认版本先不弄目前都先取对应的 EmbeddingModel
*/ */
@Bean // @Bean
@Lazy // TODO 芋艿临时注释避免无法启动 // @Lazy // TODO 芋艿临时注释避免无法启动
public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties, // public RedisVectorStore vectorStore(TongYiTextEmbeddingModel tongYiTextEmbeddingModel, RedisVectorStoreProperties properties,
RedisProperties redisProperties) { // RedisProperties redisProperties) {
var config = RedisVectorStore.RedisVectorStoreConfig.builder() // var config = RedisVectorStore.RedisVectorStoreConfig.builder()
.withIndexName(properties.getIndex()) // .withIndexName(properties.getIndex())
.withPrefix(properties.getPrefix()) // .withPrefix(properties.getPrefix())
.build(); // .build();
//
RedisVectorStore redisVectorStore = new RedisVectorStore(config, transformersEmbeddingModel, // RedisVectorStore redisVectorStore = new RedisVectorStore(config, tongYiTextEmbeddingModel,
new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), // new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
properties.isInitializeSchema()); // properties.isInitializeSchema());
redisVectorStore.afterPropertiesSet(); // redisVectorStore.afterPropertiesSet();
return redisVectorStore; // return redisVectorStore;
} // }
@Bean @Bean
@Lazy // TODO 芋艿临时注释避免无法启动 @Lazy // TODO 芋艿临时注释避免无法启动

View File

@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
/** /**
@ -25,6 +26,18 @@ public interface AiModelFactory {
*/ */
ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url); ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url);
/**
 * Gets the EmbeddingModel object for the given configuration.
 * <p>
 * If it does not exist yet, it is created.
 *
 * @param platform platform
 * @param apiKey API key
 * @param url API URL
 * @return the EmbeddingModel object
 */
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
/** /**
* 基于默认配置获得 ChatModel 对象 * 基于默认配置获得 ChatModel 对象
* *

View File

@ -21,6 +21,7 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation; import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClient;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
@ -37,6 +38,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.OllamaChatModel;
@ -97,6 +99,21 @@ public class AiModelFactoryImpl implements AiModelFactory {
}); });
} }
@Override
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
    final String singletonKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url);
    // TODO @xin: only one platform is wired up for now, for testing
    final Func0<EmbeddingModel> creator = () -> {
        switch (platform) {
            case TONG_YI:
                return buildTongYiEmbeddingModel(apiKey);
            default:
                throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
        }
    };
    // The supplier is handed to Singleton together with the cache key built above.
    return Singleton.get(singletonKey, creator);
}
@Override @Override
public ChatModel getDefaultChatModel(AiPlatformEnum platform) { public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
@ -239,7 +256,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
/** /**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel( * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
* ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} *ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
*/ */
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) { private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@ -249,7 +266,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
/** /**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel( * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
* ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} *ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/ */
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) { private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@ -315,4 +332,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new StabilityAiImageModel(stabilityAiApi); return new StabilityAiImageModel(stabilityAiApi);
} }
// ========== 各种创建 EmbeddingModel 的方法 ==========
/**
 * Builds a TongYi EmbeddingModel for the given API key.
 * <p>
 * See {@link TongYiAutoConfiguration#tongYiTextEmbeddingClient(TextEmbedding, TongYiConnectionProperties)}.
 */
private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) {
    final TongYiConnectionProperties props = new TongYiConnectionProperties();
    props.setApiKey(apiKey);
    final TongYiAutoConfiguration autoConfiguration = new TongYiAutoConfiguration();
    // The TextEmbedding client is looked up from the Spring context rather than constructed here.
    final TextEmbedding textEmbedding = SpringUtil.getBean(TextEmbedding.class);
    return autoConfiguration.tongYiTextEmbeddingClient(textEmbedding, props);
}
} }

View File

@ -0,0 +1,27 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
/**
 * Factory interface for AI vector stores.
 *
 * @author xiaoxin
 */
public interface AiVectorFactory {
/**
 * Gets the VectorStore object for the given configuration.
 * <p>
 * If it does not exist yet, it is created.
 *
 * @param embeddingModel embedding model
 * @param platform platform
 * @param apiKey API key
 * @param url API URL
 * @return the VectorStore object
 */
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
}

View File

@ -0,0 +1,51 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.hutool.core.lang.Singleton;
import cn.hutool.core.lang.func.Func0;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import redis.clients.jedis.JedisPooled;
/**
 * {@link AiVectorFactory} implementation that uses {@link RedisVectorStore} as the {@link VectorStore}.
 * <p>
 * Created stores are cached through Hutool's {@link Singleton}, keyed by platform, API key and URL.
 *
 * @author xiaoxin
 */
public class AiVectorFactoryImpl implements AiVectorFactory {

    /**
     * Gets the VectorStore for the given configuration, creating a Redis-backed one on first use.
     * <p>
     * The cache key is built from {@code platform}, {@code apiKey} and {@code url} only;
     * {@code embeddingModel} is not part of it, so it is only used when the store is first created.
     *
     * @param embeddingModel embedding model used by a newly created store
     * @param platform       platform
     * @param apiKey         API key
     * @param url            API URL
     * @return the VectorStore object
     */
    @Override
    public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
        String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
        return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
            // TODO 芋艿 @xin: where should these two settings come from?
            // TODO: different models may produce vectors of different dimensions; the index appears to be
            //  what separates them, and vectors of different dimensions cannot be stored in one index.
            String index = "default-index";
            String prefix = "default:";
            var config = RedisVectorStore.RedisVectorStoreConfig.builder()
                    .withIndexName(index)
                    .withPrefix(prefix)
                    .build();

            RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
            // Pass the configured credentials as well: with host and port alone, a Redis instance
            // that requires authentication rejects the connection.
            // NOTE(review): null username/password are expected to behave like the two-argument
            // constructor - confirm against the Jedis version in use.
            JedisPooled jedisPooled = new JedisPooled(redisProperties.getHost(), redisProperties.getPort(),
                    redisProperties.getUsername(), redisProperties.getPassword());
            RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel, jedisPooled, true);
            // The store is constructed by hand here, so its initialization callback is invoked explicitly.
            redisVectorStore.afterPropertiesSet();
            return redisVectorStore;
        });
    }

    /**
     * Builds the cache key: the class name alone when there are no params, otherwise
     * the class name, a {@code #}, and the params joined by {@code _}.
     *
     * @param clazz  class whose name prefixes the key
     * @param params values appended to the key
     * @return the cache key
     */
    private static String buildClientCacheKey(Class<?> clazz, Object... params) {
        if (ArrayUtil.isEmpty(params)) {
            return clazz.getName();
        }
        return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
    }

}

View File

@ -19,6 +19,7 @@ import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
@ -38,7 +39,7 @@ import redis.clients.jedis.JedisPooled;
*/ */
@AutoConfiguration(after = RedisAutoConfiguration.class) @AutoConfiguration(after = RedisAutoConfiguration.class)
@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class}) @ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class})
//@ConditionalOnBean(JedisConnectionFactory.class) @ConditionalOnBean(JedisConnectionFactory.class)
@EnableConfigurationProperties(RedisVectorStoreProperties.class) @EnableConfigurationProperties(RedisVectorStoreProperties.class)
public class RedisVectorStoreAutoConfiguration { public class RedisVectorStoreAutoConfiguration {