mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-22 23:31:52 +08:00
【新增】AI 知识库: AiVectorFactory 负责管理不同 EmbeddingModel 对应的 VectorStore
This commit is contained in:
parent
024109dac9
commit
f97fb0a8fe
@ -45,7 +45,7 @@ public class AiKnowledgeDO extends BaseDO {
|
|||||||
@TableField(typeHandler = JacksonTypeHandler.class)
|
@TableField(typeHandler = JacksonTypeHandler.class)
|
||||||
private List<Long> visibilityPermissions;
|
private List<Long> visibilityPermissions;
|
||||||
/**
|
/**
|
||||||
* 嵌入模型编号,高质量模式时维护
|
* 嵌入模型编号
|
||||||
*/
|
*/
|
||||||
private Long modelId;
|
private Long modelId;
|
||||||
/**
|
/**
|
||||||
|
@ -24,10 +24,14 @@ public class AiKnowledgeSegmentDO extends BaseDO {
|
|||||||
* 向量库的编号
|
* 向量库的编号
|
||||||
*/
|
*/
|
||||||
private String vectorId;
|
private String vectorId;
|
||||||
// TODO @新:knowledgeId 加个,会方便点
|
/**
|
||||||
|
* 知识库编号
|
||||||
|
* 关联 {@link AiKnowledgeDO#getId()}
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
/**
|
/**
|
||||||
* 文档编号
|
* 文档编号
|
||||||
*
|
* <p>
|
||||||
* 关联 {@link AiKnowledgeDocumentDO#getId()}
|
* 关联 {@link AiKnowledgeDocumentDO#getId()}
|
||||||
*/
|
*/
|
||||||
private Long documentId;
|
private Long documentId;
|
||||||
|
@ -6,24 +6,27 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
|||||||
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
|
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeDocumentCreateReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeDocumentCreateReqVO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
|
||||||
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
|
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
|
||||||
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.document.Document;
|
import org.springframework.ai.document.Document;
|
||||||
import org.springframework.ai.reader.tika.TikaDocumentReader;
|
import org.springframework.ai.reader.tika.TikaDocumentReader;
|
||||||
import org.springframework.ai.tokenizer.TokenCountEstimator;
|
import org.springframework.ai.tokenizer.TokenCountEstimator;
|
||||||
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||||||
import org.springframework.ai.vectorstore.RedisVectorStore;
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
import org.springframework.core.io.ByteArrayResource;
|
import org.springframework.core.io.ByteArrayResource;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI 知识库-文档 Service 实现类
|
* AI 知识库-文档 Service 实现类
|
||||||
@ -42,9 +45,14 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
|
|||||||
@Resource
|
@Resource
|
||||||
private TokenTextSplitter tokenTextSplitter;
|
private TokenTextSplitter tokenTextSplitter;
|
||||||
@Resource
|
@Resource
|
||||||
private TokenCountEstimator TOKEN_COUNT_ESTIMATOR;
|
private TokenCountEstimator tokenCountEstimator;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private RedisVectorStore vectorStore;
|
private AiApiKeyService apiKeyService;
|
||||||
|
@Resource
|
||||||
|
private AiKnowledgeService knowledgeService;
|
||||||
|
@Resource
|
||||||
|
private AiChatModelService chatModelService;
|
||||||
|
|
||||||
|
|
||||||
// TODO 芋艿:需要 review 下,代码格式;
|
// TODO 芋艿:需要 review 下,代码格式;
|
||||||
@ -53,18 +61,18 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
|
|||||||
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
|
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
|
||||||
// 1.1 下载文档
|
// 1.1 下载文档
|
||||||
String url = createReqVO.getUrl();
|
String url = createReqVO.getUrl();
|
||||||
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
|
|
||||||
// 1.2 加载文档
|
// 1.2 加载文档
|
||||||
|
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
|
||||||
List<Document> documents = loader.get();
|
List<Document> documents = loader.get();
|
||||||
Document document = CollUtil.getFirst(documents);
|
Document document = CollUtil.getFirst(documents);
|
||||||
// TODO @xin:是不是不存在,就抛出异常呀;厚泽 return 呀;
|
String content = document.getContent();
|
||||||
Integer tokens = Objects.nonNull(document) ? TOKEN_COUNT_ESTIMATOR.estimate(document.getContent()) : 0;
|
Integer tokens = tokenCountEstimator.estimate(content);
|
||||||
Integer wordCount = Objects.nonNull(document) ? document.getContent().length() : 0;
|
Integer wordCount = content.length();
|
||||||
|
|
||||||
|
// 1.3 文档记录入库
|
||||||
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
|
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
|
||||||
.setTokens(tokens).setWordCount(wordCount)
|
.setTokens(tokens).setWordCount(wordCount)
|
||||||
.setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus());
|
.setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus());
|
||||||
// 1.2 文档记录入库
|
|
||||||
documentMapper.insert(documentDO);
|
documentMapper.insert(documentDO);
|
||||||
Long documentId = documentDO.getId();
|
Long documentId = documentDO.getId();
|
||||||
if (CollUtil.isEmpty(documents)) {
|
if (CollUtil.isEmpty(documents)) {
|
||||||
@ -75,11 +83,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
|
|||||||
List<Document> segments = tokenTextSplitter.apply(documents);
|
List<Document> segments = tokenTextSplitter.apply(documents);
|
||||||
// 2.2 分段内容入库
|
// 2.2 分段内容入库
|
||||||
List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments,
|
List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments,
|
||||||
segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId)
|
segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId).setKnowledgeId(createReqVO.getKnowledgeId())
|
||||||
.setTokens(TOKEN_COUNT_ESTIMATOR.estimate(segment.getContent())).setWordCount(segment.getContent().length())
|
.setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length())
|
||||||
.setStatus(CommonStatusEnum.ENABLE.getStatus()));
|
.setStatus(CommonStatusEnum.ENABLE.getStatus()));
|
||||||
segmentMapper.insertBatch(segmentDOList);
|
segmentMapper.insertBatch(segmentDOList);
|
||||||
// 3 向量化并存储
|
|
||||||
|
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
|
||||||
|
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
|
||||||
|
// 3.1 获取向量存储实例
|
||||||
|
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
|
||||||
|
// 3.2 向量化并存储
|
||||||
vectorStore.add(segments);
|
vectorStore.add(segments);
|
||||||
return documentId;
|
return documentId;
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package cn.iocoder.yudao.module.ai.service.knowledge;
|
package cn.iocoder.yudao.module.ai.service.knowledge;
|
||||||
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeCreateMyReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeCreateMyReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeUpdateMyReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeUpdateMyReqVO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI 知识库-基础信息 Service 接口
|
* AI 知识库-基础信息 Service 接口
|
||||||
@ -13,7 +15,7 @@ public interface AiKnowledgeService {
|
|||||||
* 创建【我的】知识库
|
* 创建【我的】知识库
|
||||||
*
|
*
|
||||||
* @param createReqVO 创建信息
|
* @param createReqVO 创建信息
|
||||||
* @param userId 用户编号
|
* @param userId 用户编号
|
||||||
* @return 编号
|
* @return 编号
|
||||||
*/
|
*/
|
||||||
Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId);
|
Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId);
|
||||||
@ -23,8 +25,16 @@ public interface AiKnowledgeService {
|
|||||||
* 创建【我的】知识库
|
* 创建【我的】知识库
|
||||||
*
|
*
|
||||||
* @param updateReqVO 更新信息
|
* @param updateReqVO 更新信息
|
||||||
* @param userId 用户编号
|
* @param userId 用户编号
|
||||||
*/
|
*/
|
||||||
void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId);
|
void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 校验知识库是否存在
|
||||||
|
*
|
||||||
|
* @param id 记录编号
|
||||||
|
*/
|
||||||
|
AiKnowledgeDO validateKnowledgeExists(Long id);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
|
|||||||
private AiChatModelService chatModalService;
|
private AiChatModelService chatModalService;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private AiKnowledgeMapper knowledgeBaseMapper;
|
private AiKnowledgeMapper knowledgeMapper;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) {
|
public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) {
|
||||||
@ -39,7 +39,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
|
|||||||
// 2. 插入知识库
|
// 2. 插入知识库
|
||||||
AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
|
AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
|
||||||
.setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus());
|
.setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus());
|
||||||
knowledgeBaseMapper.insert(knowledgeBase);
|
knowledgeMapper.insert(knowledgeBase);
|
||||||
return knowledgeBase.getId();
|
return knowledgeBase.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,11 +56,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
|
|||||||
// 2. 更新知识库
|
// 2. 更新知识库
|
||||||
AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class);
|
AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class);
|
||||||
updateDO.setModel(model.getModel());
|
updateDO.setModel(model.getModel());
|
||||||
knowledgeBaseMapper.updateById(updateDO);
|
knowledgeMapper.updateById(updateDO);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public AiKnowledgeDO validateKnowledgeExists(Long id) {
|
public AiKnowledgeDO validateKnowledgeExists(Long id) {
|
||||||
AiKnowledgeDO knowledgeBase = knowledgeBaseMapper.selectById(id);
|
AiKnowledgeDO knowledgeBase = knowledgeMapper.selectById(id);
|
||||||
if (knowledgeBase == null) {
|
if (knowledgeBase == null) {
|
||||||
throw exception(KNOWLEDGE_NOT_EXISTS);
|
throw exception(KNOWLEDGE_NOT_EXISTS);
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
|
|||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||||
import jakarta.validation.Valid;
|
import jakarta.validation.Valid;
|
||||||
import org.springframework.ai.chat.model.ChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@ -83,6 +85,14 @@ public interface AiApiKeyService {
|
|||||||
*/
|
*/
|
||||||
ChatModel getChatModel(Long id);
|
ChatModel getChatModel(Long id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获得 EmbeddingModel 对象
|
||||||
|
*
|
||||||
|
* @param id 编号
|
||||||
|
* @return EmbeddingModel 对象
|
||||||
|
*/
|
||||||
|
EmbeddingModel getEmbeddingModel(Long id);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获得 ImageModel 对象
|
* 获得 ImageModel 对象
|
||||||
*
|
*
|
||||||
@ -111,4 +121,12 @@ public interface AiApiKeyService {
|
|||||||
*/
|
*/
|
||||||
SunoApi getSunoApi();
|
SunoApi getSunoApi();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获得 vector 对象
|
||||||
|
*
|
||||||
|
* @param id 编号
|
||||||
|
* @return VectorStore 对象
|
||||||
|
*/
|
||||||
|
VectorStore getOrCreateVectorStore(Long id);
|
||||||
|
|
||||||
}
|
}
|
@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.model;
|
|||||||
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||||
@ -13,7 +14,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
|||||||
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import org.springframework.ai.chat.model.ChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
|
|
||||||
@ -36,6 +39,8 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
|
|||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private AiModelFactory modelFactory;
|
private AiModelFactory modelFactory;
|
||||||
|
@Resource
|
||||||
|
private AiVectorFactory vectorFactory;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
|
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
|
||||||
@ -104,6 +109,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
|
|||||||
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public EmbeddingModel getEmbeddingModel(Long id) {
|
||||||
|
AiApiKeyDO apiKey = validateApiKey(id);
|
||||||
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
||||||
|
return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ImageModel getImageModel(AiPlatformEnum platform) {
|
public ImageModel getImageModel(AiPlatformEnum platform) {
|
||||||
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
|
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
|
||||||
@ -132,4 +144,11 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
|
|||||||
}
|
}
|
||||||
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
|
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public VectorStore getOrCreateVectorStore(Long id) {
|
||||||
|
AiApiKeyDO apiKey = validateApiKey(id);
|
||||||
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
||||||
|
return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||||
|
}
|
||||||
}
|
}
|
@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.config;
|
|||||||
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
||||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
|
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactoryImpl;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
|
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||||
@ -10,22 +12,15 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
|||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
|
||||||
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
|
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
|
|
||||||
import org.springframework.ai.document.MetadataMode;
|
|
||||||
import org.springframework.ai.embedding.EmbeddingModel;
|
|
||||||
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
|
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
|
||||||
import org.springframework.ai.tokenizer.TokenCountEstimator;
|
import org.springframework.ai.tokenizer.TokenCountEstimator;
|
||||||
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||||||
import org.springframework.ai.transformers.TransformersEmbeddingModel;
|
|
||||||
import org.springframework.ai.vectorstore.RedisVectorStore;
|
|
||||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||||
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
|
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Import;
|
import org.springframework.context.annotation.Import;
|
||||||
import org.springframework.context.annotation.Lazy;
|
import org.springframework.context.annotation.Lazy;
|
||||||
import redis.clients.jedis.JedisPooled;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 芋道 AI 自动配置
|
* 芋道 AI 自动配置
|
||||||
@ -43,6 +38,12 @@ public class YudaoAiAutoConfiguration {
|
|||||||
return new AiModelFactoryImpl();
|
return new AiModelFactoryImpl();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public AiVectorFactory aiVectorFactory() {
|
||||||
|
return new AiVectorFactoryImpl();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// ========== 各种 AI Client 创建 ==========
|
// ========== 各种 AI Client 创建 ==========
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
@ -85,30 +86,31 @@ public class YudaoAiAutoConfiguration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ========== rag 相关 ==========
|
// ========== rag 相关 ==========
|
||||||
@Bean
|
// TODO @xin 免费版本
|
||||||
@Lazy // TODO 芋艿:临时注释,避免无法启动
|
// @Bean
|
||||||
public EmbeddingModel transformersEmbeddingClient() {
|
// @Lazy // TODO 芋艿:临时注释,避免无法启动」
|
||||||
return new TransformersEmbeddingModel(MetadataMode.EMBED);
|
// public EmbeddingModel transformersEmbeddingClient() {
|
||||||
}
|
// return new TransformersEmbeddingModel(MetadataMode.EMBED);
|
||||||
|
// }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO @xin 抽离出去,根据具体模型走
|
* TODO @xin 默认版本先不弄,目前都先取对应的 EmbeddingModel
|
||||||
*/
|
*/
|
||||||
@Bean
|
// @Bean
|
||||||
@Lazy // TODO 芋艿:临时注释,避免无法启动
|
// @Lazy // TODO 芋艿:临时注释,避免无法启动
|
||||||
public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties,
|
// public RedisVectorStore vectorStore(TongYiTextEmbeddingModel tongYiTextEmbeddingModel, RedisVectorStoreProperties properties,
|
||||||
RedisProperties redisProperties) {
|
// RedisProperties redisProperties) {
|
||||||
var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
||||||
.withIndexName(properties.getIndex())
|
// .withIndexName(properties.getIndex())
|
||||||
.withPrefix(properties.getPrefix())
|
// .withPrefix(properties.getPrefix())
|
||||||
.build();
|
// .build();
|
||||||
|
//
|
||||||
RedisVectorStore redisVectorStore = new RedisVectorStore(config, transformersEmbeddingModel,
|
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, tongYiTextEmbeddingModel,
|
||||||
new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
||||||
properties.isInitializeSchema());
|
// properties.isInitializeSchema());
|
||||||
redisVectorStore.afterPropertiesSet();
|
// redisVectorStore.afterPropertiesSet();
|
||||||
return redisVectorStore;
|
// return redisVectorStore;
|
||||||
}
|
// }
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
@Lazy // TODO 芋艿:临时注释,避免无法启动
|
@Lazy // TODO 芋艿:临时注释,避免无法启动
|
||||||
|
@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|||||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||||
import org.springframework.ai.chat.model.ChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -25,6 +26,18 @@ public interface AiModelFactory {
|
|||||||
*/
|
*/
|
||||||
ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url);
|
ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基于指定配置,获得 EmbeddingModel 对象
|
||||||
|
* <p>
|
||||||
|
* 如果不存在,则进行创建
|
||||||
|
*
|
||||||
|
* @param platform 平台
|
||||||
|
* @param apiKey API KEY
|
||||||
|
* @param url API URL
|
||||||
|
* @return ChatModel 对象
|
||||||
|
*/
|
||||||
|
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 基于默认配置,获得 ChatModel 对象
|
* 基于默认配置,获得 ChatModel 对象
|
||||||
*
|
*
|
||||||
|
@ -21,6 +21,7 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
|||||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
|
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
|
||||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||||
|
import com.alibaba.dashscope.embeddings.TextEmbedding;
|
||||||
import com.azure.ai.openai.OpenAIClient;
|
import com.azure.ai.openai.OpenAIClient;
|
||||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
||||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
||||||
@ -37,6 +38,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
|
|||||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
|
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
|
||||||
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
||||||
import org.springframework.ai.chat.model.ChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||||
import org.springframework.ai.ollama.OllamaChatModel;
|
import org.springframework.ai.ollama.OllamaChatModel;
|
||||||
@ -97,6 +99,21 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
|
||||||
|
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url);
|
||||||
|
return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
|
||||||
|
// TODO @xin 先测试一个
|
||||||
|
switch (platform) {
|
||||||
|
case TONG_YI:
|
||||||
|
return buildTongYiEmbeddingModel(apiKey);
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
|
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
|
||||||
//noinspection EnhancedSwitchMigration
|
//noinspection EnhancedSwitchMigration
|
||||||
@ -239,7 +256,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
|
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
|
||||||
* ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
|
*ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
|
||||||
*/
|
*/
|
||||||
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
|
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
|
||||||
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
||||||
@ -249,7 +266,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
|
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
|
||||||
* ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
*ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
||||||
*/
|
*/
|
||||||
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
|
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
|
||||||
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
||||||
@ -315,4 +332,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||||||
return new StabilityAiImageModel(stabilityAiApi);
|
return new StabilityAiImageModel(stabilityAiApi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ========== 各种创建 EmbeddingModel 的方法 ==========
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 可参考 {@link TongYiAutoConfiguration#tongYiTextEmbeddingClient(TextEmbedding, TongYiConnectionProperties)}
|
||||||
|
*/
|
||||||
|
private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) {
|
||||||
|
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
|
||||||
|
connectionProperties.setApiKey(apiKey);
|
||||||
|
return new TongYiAutoConfiguration().tongYiTextEmbeddingClient(SpringUtil.getBean(TextEmbedding.class), connectionProperties);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
package cn.iocoder.yudao.framework.ai.core.factory;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI Vector 模型工厂的接口类
|
||||||
|
* @author xiaoxin
|
||||||
|
*/
|
||||||
|
public interface AiVectorFactory {
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基于指定配置,获得 VectorStore 对象
|
||||||
|
* <p>
|
||||||
|
* 如果不存在,则进行创建
|
||||||
|
*
|
||||||
|
* @param embeddingModel 嵌入模型
|
||||||
|
* @param platform 平台
|
||||||
|
* @param apiKey API KEY
|
||||||
|
* @param url API URL
|
||||||
|
* @return VectorStore 对象
|
||||||
|
*/
|
||||||
|
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,51 @@
|
|||||||
|
package cn.iocoder.yudao.framework.ai.core.factory;
|
||||||
|
|
||||||
|
import cn.hutool.core.lang.Singleton;
|
||||||
|
import cn.hutool.core.lang.func.Func0;
|
||||||
|
import cn.hutool.core.util.ArrayUtil;
|
||||||
|
import cn.hutool.core.util.StrUtil;
|
||||||
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
|
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
|
import org.springframework.ai.vectorstore.RedisVectorStore;
|
||||||
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
|
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
|
||||||
|
import redis.clients.jedis.JedisPooled;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI Vector 模型工厂的实现类
|
||||||
|
* 使用 redisVectorStore 实现 VectorStore
|
||||||
|
*
|
||||||
|
* @author xiaoxin
|
||||||
|
*/
|
||||||
|
public class AiVectorFactoryImpl implements AiVectorFactory {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
|
||||||
|
String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
|
||||||
|
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
|
||||||
|
// TODO 芋艿 @xin 这两个配置取哪好呢
|
||||||
|
// TODO 不同模型的向量维度可能会不一样,目前看貌似是以 index 来做区分的,维度不一样存不到一个 index 上
|
||||||
|
String index = "default-index";
|
||||||
|
String prefix = "default:";
|
||||||
|
var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
||||||
|
.withIndexName(index)
|
||||||
|
.withPrefix(prefix)
|
||||||
|
.build();
|
||||||
|
RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
|
||||||
|
RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
|
||||||
|
new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
||||||
|
true);
|
||||||
|
redisVectorStore.afterPropertiesSet();
|
||||||
|
return redisVectorStore;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
|
||||||
|
if (ArrayUtil.isEmpty(params)) {
|
||||||
|
return clazz.getName();
|
||||||
|
}
|
||||||
|
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
|
||||||
|
}
|
||||||
|
}
|
@ -19,6 +19,7 @@ import org.springframework.ai.embedding.EmbeddingModel;
|
|||||||
import org.springframework.ai.vectorstore.RedisVectorStore;
|
import org.springframework.ai.vectorstore.RedisVectorStore;
|
||||||
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
|
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
|
||||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||||
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||||
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
|
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
|
||||||
@ -38,7 +39,7 @@ import redis.clients.jedis.JedisPooled;
|
|||||||
*/
|
*/
|
||||||
@AutoConfiguration(after = RedisAutoConfiguration.class)
|
@AutoConfiguration(after = RedisAutoConfiguration.class)
|
||||||
@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class})
|
@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class})
|
||||||
//@ConditionalOnBean(JedisConnectionFactory.class)
|
@ConditionalOnBean(JedisConnectionFactory.class)
|
||||||
@EnableConfigurationProperties(RedisVectorStoreProperties.class)
|
@EnableConfigurationProperties(RedisVectorStoreProperties.class)
|
||||||
public class RedisVectorStoreAutoConfiguration {
|
public class RedisVectorStoreAutoConfiguration {
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user