From f97fb0a8fe9e7987233e0f81940b3521544f2031 Mon Sep 17 00:00:00 2001 From: xiaoxin <718949661@qq.com> Date: Thu, 29 Aug 2024 14:13:37 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E6=96=B0=E5=A2=9E=E3=80=91AI=20?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93:=20AiVectorFactory=20=E8=B4=9F?= =?UTF-8?q?=E8=B4=A3=E7=AE=A1=E7=90=86=E4=B8=8D=E5=90=8C=20EmbeddingModel?= =?UTF-8?q?=20=E5=AF=B9=E5=BA=94=E7=9A=84=20VectorStore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dataobject/knowledge/AiKnowledgeDO.java | 2 +- .../knowledge/AiKnowledgeSegmentDO.java | 8 ++- .../AiKnowledgeDocumentServiceImpl.java | 37 ++++++++---- .../service/knowledge/AiKnowledgeService.java | 14 ++++- .../knowledge/AiKnowledgeServiceImpl.java | 9 +-- .../ai/service/model/AiApiKeyService.java | 18 ++++++ .../ai/service/model/AiApiKeyServiceImpl.java | 19 ++++++ .../ai/config/YudaoAiAutoConfiguration.java | 58 ++++++++++--------- .../ai/core/factory/AiModelFactory.java | 13 +++++ .../ai/core/factory/AiModelFactoryImpl.java | 32 +++++++++- .../ai/core/factory/AiVectorFactory.java | 27 +++++++++ .../ai/core/factory/AiVectorFactoryImpl.java | 51 ++++++++++++++++ .../RedisVectorStoreAutoConfiguration.java | 3 +- 13 files changed, 239 insertions(+), 52 deletions(-) create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorFactory.java create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorFactoryImpl.java diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java index 89e7486dc..756d8cdb3 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java @@ -45,7 +45,7 @@ public class AiKnowledgeDO extends BaseDO { @TableField(typeHandler = JacksonTypeHandler.class) private List visibilityPermissions; /** - * 嵌入模型编号,高质量模式时维护 + * 嵌入模型编号 */ private Long modelId; /** diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java index 2032bfd5e..6d5da06ea 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java @@ -24,10 +24,14 @@ public class AiKnowledgeSegmentDO extends BaseDO { * 向量库的编号 */ private String vectorId; - // TODO @新:knowledgeId 加个,会方便点 + /** + * 知识库编号 + * 关联 {@link AiKnowledgeDO#getId()} + */ + private Long knowledgeId; /** * 文档编号 - * + *

* 关联 {@link AiKnowledgeDocumentDO#getId()} */ private Long documentId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java index 2af8b9d90..9be3dd1af 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java @@ -6,24 +6,27 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeDocumentCreateReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum; +import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; +import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.document.Document; import org.springframework.ai.reader.tika.TikaDocumentReader; import org.springframework.ai.tokenizer.TokenCountEstimator; import org.springframework.ai.transformer.splitter.TokenTextSplitter; -import org.springframework.ai.vectorstore.RedisVectorStore; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.core.io.ByteArrayResource; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import java.util.List; -import java.util.Objects; /** * AI 知识库-文档 Service 实现类 @@ -42,9 +45,14 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic @Resource private TokenTextSplitter tokenTextSplitter; @Resource - private TokenCountEstimator TOKEN_COUNT_ESTIMATOR; + private TokenCountEstimator tokenCountEstimator; + @Resource - private RedisVectorStore vectorStore; + private AiApiKeyService apiKeyService; + @Resource + private AiKnowledgeService knowledgeService; + @Resource + private AiChatModelService chatModelService; // TODO 芋艿:需要 review 下,代码格式; @@ -53,18 +61,18 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) { // 1.1 下载文档 String url = createReqVO.getUrl(); - TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url)); // 1.2 加载文档 + TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url)); List documents = loader.get(); Document document = CollUtil.getFirst(documents); - // TODO @xin:是不是不存在,就抛出异常呀;厚泽 return 呀; - Integer tokens = Objects.nonNull(document) ? TOKEN_COUNT_ESTIMATOR.estimate(document.getContent()) : 0; - Integer wordCount = Objects.nonNull(document) ? document.getContent().length() : 0; + String content = document.getContent(); + Integer tokens = tokenCountEstimator.estimate(content); + Integer wordCount = content.length(); + // 1.3 文档记录入库 AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class) .setTokens(tokens).setWordCount(wordCount) .setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus()); - // 1.2 文档记录入库 documentMapper.insert(documentDO); Long documentId = documentDO.getId(); if (CollUtil.isEmpty(documents)) { @@ -75,11 +83,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic List segments = tokenTextSplitter.apply(documents); // 2.2 分段内容入库 List segmentDOList = CollectionUtils.convertList(segments, - segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId) - .setTokens(TOKEN_COUNT_ESTIMATOR.estimate(segment.getContent())).setWordCount(segment.getContent().length()) + segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId).setKnowledgeId(createReqVO.getKnowledgeId()) + .setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length()) .setStatus(CommonStatusEnum.ENABLE.getStatus())); segmentMapper.insertBatch(segmentDOList); - // 3 向量化并存储 + + AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId()); + AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId()); + // 3.1 获取向量存储实例 + VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId()); + // 3.2 向量化并存储 vectorStore.add(segments); return documentId; } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java index 91b0c9b3e..bf7e8886a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java @@ -1,6 +1,8 @@ package cn.iocoder.yudao.module.ai.service.knowledge; + import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeUpdateMyReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; /** * AI 知识库-基础信息 Service 接口 @@ -13,7 +15,7 @@ public interface AiKnowledgeService { * 创建【我的】知识库 * * @param createReqVO 创建信息 - * @param userId 用户编号 + * @param userId 用户编号 * @return 编号 */ Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId); @@ -23,8 +25,16 @@ public interface AiKnowledgeService { * 创建【我的】知识库 * * @param updateReqVO 更新信息 - * @param userId 用户编号 + * @param userId 用户编号 */ void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId); + + /** + * 校验知识库是否存在 + * + * @param id 记录编号 + */ + AiKnowledgeDO validateKnowledgeExists(Long id); + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java index 5889bcef7..70442936e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java @@ -29,7 +29,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService { private AiChatModelService chatModalService; @Resource - private AiKnowledgeMapper knowledgeBaseMapper; + private AiKnowledgeMapper knowledgeMapper; @Override public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) { @@ -39,7 +39,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService { // 2. 插入知识库 AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class) .setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus()); - knowledgeBaseMapper.insert(knowledgeBase); + knowledgeMapper.insert(knowledgeBase); return knowledgeBase.getId(); } @@ -56,11 +56,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService { // 2. 更新知识库 AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class); updateDO.setModel(model.getModel()); - knowledgeBaseMapper.updateById(updateDO); + knowledgeMapper.updateById(updateDO); } + @Override public AiKnowledgeDO validateKnowledgeExists(Long id) { - AiKnowledgeDO knowledgeBase = knowledgeBaseMapper.selectById(id); + AiKnowledgeDO knowledgeBase = knowledgeMapper.selectById(id); if (knowledgeBase == null) { throw exception(KNOWLEDGE_NOT_EXISTS); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java index fe8fdd194..603325da6 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java @@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import jakarta.validation.Valid; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.image.ImageModel; +import org.springframework.ai.vectorstore.VectorStore; import java.util.List; @@ -83,6 +85,14 @@ public interface AiApiKeyService { */ ChatModel getChatModel(Long id); + /** + * 获得 EmbeddingModel 对象 + * + * @param id 编号 + * @return EmbeddingModel 对象 + */ + EmbeddingModel getEmbeddingModel(Long id); + /** * 获得 ImageModel 对象 * @@ -111,4 +121,12 @@ public interface AiApiKeyService { */ SunoApi getSunoApi(); + /** + * 获得 vector 对象 + * + * @param id 编号 + * @return VectorStore 对象 + */ + VectorStore getOrCreateVectorStore(Long id); + } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java index 590b10a4c..d3e9b7cfb 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java @@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.model; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; +import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; @@ -13,7 +14,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import jakarta.annotation.Resource; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.image.ImageModel; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; @@ -36,6 +39,8 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { @Resource private AiModelFactory modelFactory; + @Resource + private AiVectorFactory vectorFactory; @Override public Long createApiKey(AiApiKeySaveReqVO createReqVO) { @@ -104,6 +109,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl()); } + @Override + public EmbeddingModel getEmbeddingModel(Long id) { + AiApiKeyDO apiKey = validateApiKey(id); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); + return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl()); + } + @Override public ImageModel getImageModel(AiPlatformEnum platform) { AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); @@ -132,4 +144,11 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { } return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); } + + @Override + public VectorStore getOrCreateVectorStore(Long id) { + AiApiKeyDO apiKey = validateApiKey(id); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); + return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl()); + } } \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index 8566a0941..50eacd00e 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.config; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl; +import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory; +import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactoryImpl; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; @@ -10,22 +12,15 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties; -import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator; import org.springframework.ai.tokenizer.TokenCountEstimator; import org.springframework.ai.transformer.splitter.TokenTextSplitter; -import org.springframework.ai.transformers.TransformersEmbeddingModel; -import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.boot.autoconfigure.data.redis.RedisProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Lazy; -import redis.clients.jedis.JedisPooled; /** * 芋道 AI 自动配置 @@ -43,6 +38,12 @@ public class YudaoAiAutoConfiguration { return new AiModelFactoryImpl(); } + @Bean + public AiVectorFactory aiVectorFactory() { + return new AiVectorFactoryImpl(); + } + + // ========== 各种 AI Client 创建 ========== @Bean @@ -85,30 +86,31 @@ public class YudaoAiAutoConfiguration { } // ========== rag 相关 ========== - @Bean - @Lazy // TODO 芋艿:临时注释,避免无法启动 - public EmbeddingModel transformersEmbeddingClient() { - return new TransformersEmbeddingModel(MetadataMode.EMBED); - } + // TODO @xin 免费版本 +// @Bean +// @Lazy // TODO 芋艿:临时注释,避免无法启动」 +// public EmbeddingModel transformersEmbeddingClient() { +// return new TransformersEmbeddingModel(MetadataMode.EMBED); +// } /** - * TODO @xin 抽离出去,根据具体模型走 + * TODO @xin 默认版本先不弄,目前都先取对应的 EmbeddingModel */ - @Bean - @Lazy // TODO 芋艿:临时注释,避免无法启动 - public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties, - RedisProperties redisProperties) { - var config = RedisVectorStore.RedisVectorStoreConfig.builder() - .withIndexName(properties.getIndex()) - .withPrefix(properties.getPrefix()) - .build(); - - RedisVectorStore redisVectorStore = new RedisVectorStore(config, transformersEmbeddingModel, - new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), - properties.isInitializeSchema()); - redisVectorStore.afterPropertiesSet(); - return redisVectorStore; - } +// @Bean +// @Lazy // TODO 芋艿:临时注释,避免无法启动 +// public RedisVectorStore vectorStore(TongYiTextEmbeddingModel tongYiTextEmbeddingModel, RedisVectorStoreProperties properties, +// RedisProperties redisProperties) { +// var config = RedisVectorStore.RedisVectorStoreConfig.builder() +// .withIndexName(properties.getIndex()) +// .withPrefix(properties.getPrefix()) +// .build(); +// +// RedisVectorStore redisVectorStore = new RedisVectorStore(config, tongYiTextEmbeddingModel, +// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), +// properties.isInitializeSchema()); +// redisVectorStore.afterPropertiesSet(); +// return redisVectorStore; +// } @Bean @Lazy // TODO 芋艿:临时注释,避免无法启动 diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java index b6d7b3dd0..6f628ea4d 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java @@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.image.ImageModel; /** @@ -25,6 +26,18 @@ public interface AiModelFactory { */ ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url); + /** + * 基于指定配置,获得 EmbeddingModel 对象 + *

+ * 如果不存在,则进行创建 + * + * @param platform 平台 + * @param apiKey API KEY + * @param url API URL + * @return ChatModel 对象 + */ + EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url); + /** * 基于默认配置,获得 ChatModel 对象 * diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index c9b04dc1e..5c3524e66 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -21,6 +21,7 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; import com.alibaba.dashscope.aigc.generation.Generation; import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; +import com.alibaba.dashscope.embeddings.TextEmbedding; import com.azure.ai.openai.OpenAIClient; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties; @@ -37,6 +38,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.image.ImageModel; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.ollama.OllamaChatModel; @@ -97,6 +99,21 @@ public class AiModelFactoryImpl implements AiModelFactory { }); } + @Override + public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) { + String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url); + return Singleton.get(cacheKey, (Func0) () -> { + // TODO @xin 先测试一个 + switch (platform) { + case TONG_YI: + return buildTongYiEmbeddingModel(apiKey); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + }); + } + + @Override public ChatModel getDefaultChatModel(AiPlatformEnum platform) { //noinspection EnhancedSwitchMigration @@ -239,7 +256,7 @@ public class AiModelFactoryImpl implements AiModelFactory { /** * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel( - * ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} + *ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} */ private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) { url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); @@ -249,7 +266,7 @@ public class AiModelFactoryImpl implements AiModelFactory { /** * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel( - * ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} + *ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} */ private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) { url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); @@ -315,4 +332,15 @@ public class AiModelFactoryImpl implements AiModelFactory { return new StabilityAiImageModel(stabilityAiApi); } + // ========== 各种创建 EmbeddingModel 的方法 ========== + + /** + * 可参考 {@link TongYiAutoConfiguration#tongYiTextEmbeddingClient(TextEmbedding, TongYiConnectionProperties)} + */ + private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) { + TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties(); + connectionProperties.setApiKey(apiKey); + return new TongYiAutoConfiguration().tongYiTextEmbeddingClient(SpringUtil.getBean(TextEmbedding.class), connectionProperties); + } + } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorFactory.java new file mode 100644 index 000000000..5e43c9bab --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorFactory.java @@ -0,0 +1,27 @@ +package cn.iocoder.yudao.framework.ai.core.factory; + +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.VectorStore; + +/** + * AI Vector 模型工厂的接口类 + * @author xiaoxin + */ +public interface AiVectorFactory { + + + /** + * 基于指定配置,获得 VectorStore 对象 + *

+ * 如果不存在,则进行创建 + * + * @param embeddingModel 嵌入模型 + * @param platform 平台 + * @param apiKey API KEY + * @param url API URL + * @return VectorStore 对象 + */ + VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url); + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorFactoryImpl.java new file mode 100644 index 000000000..d16b595f3 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorFactoryImpl.java @@ -0,0 +1,51 @@ +package cn.iocoder.yudao.framework.ai.core.factory; + +import cn.hutool.core.lang.Singleton; +import cn.hutool.core.lang.func.Func0; +import cn.hutool.core.util.ArrayUtil; +import cn.hutool.core.util.StrUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.common.util.spring.SpringUtils; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.RedisVectorStore; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.boot.autoconfigure.data.redis.RedisProperties; +import redis.clients.jedis.JedisPooled; + +/** + * AI Vector 模型工厂的实现类 + * 使用 redisVectorStore 实现 VectorStore + * + * @author xiaoxin + */ +public class AiVectorFactoryImpl implements AiVectorFactory { + + @Override + public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) { + String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url); + return Singleton.get(cacheKey, (Func0) () -> { + // TODO 芋艿 @xin 这两个配置取哪好呢 + // TODO 不同模型的向量维度可能会不一样,目前看貌似是以 index 来做区分的,维度不一样存不到一个 index 上 + String index = "default-index"; + String prefix = "default:"; + var config = RedisVectorStore.RedisVectorStoreConfig.builder() + .withIndexName(index) + .withPrefix(prefix) + .build(); + RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class); + RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel, + new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), + true); + redisVectorStore.afterPropertiesSet(); + return redisVectorStore; + }); + } + + + private static String buildClientCacheKey(Class clazz, Object... params) { + if (ArrayUtil.isEmpty(params)) { + return clazz.getName(); + } + return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_")); + } +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java index 615b05f78..a72d50c4a 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java @@ -19,6 +19,7 @@ import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; @@ -38,7 +39,7 @@ import redis.clients.jedis.JedisPooled; */ @AutoConfiguration(after = RedisAutoConfiguration.class) @ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class}) -//@ConditionalOnBean(JedisConnectionFactory.class) +@ConditionalOnBean(JedisConnectionFactory.class) @EnableConfigurationProperties(RedisVectorStoreProperties.class) public class RedisVectorStoreAutoConfiguration {