【代码评审】AI 大模型:知识库的逻辑

This commit is contained in:
YunaiV 2024-08-31 14:18:35 +08:00
parent 261011d1a8
commit e30795fb69
21 changed files with 111 additions and 112 deletions

View File

@ -52,7 +52,6 @@ public interface ErrorCodeConstants {
// ========== API 思维导图 1-040-008-000 ========== // ========== API 思维导图 1-040-008-000 ==========
ErrorCode MIND_MAP_NOT_EXISTS = new ErrorCode(1_040_008_000, "思维导图不存在!"); ErrorCode MIND_MAP_NOT_EXISTS = new ErrorCode(1_040_008_000, "思维导图不存在!");
// ========== API 知识库 1-022-008-000 ========== // ========== API 知识库 1-022-008-000 ==========
ErrorCode KNOWLEDGE_NOT_EXISTS = new ErrorCode(1_022_008_000, "知识库不存在!"); ErrorCode KNOWLEDGE_NOT_EXISTS = new ErrorCode(1_022_008_000, "知识库不存在!");
ErrorCode KNOWLEDGE_DOCUMENT_NOT_EXISTS = new ErrorCode(1_022_008_001, "文档不存在!"); ErrorCode KNOWLEDGE_DOCUMENT_NOT_EXISTS = new ErrorCode(1_022_008_001, "文档不存在!");

View File

@ -22,6 +22,7 @@ import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUti
@Tag(name = "管理后台 - AI 知识库") @Tag(name = "管理后台 - AI 知识库")
@RestController @RestController
@RequestMapping("/ai/knowledge") @RequestMapping("/ai/knowledge")
@Validated
public class AiKnowledgeController { public class AiKnowledgeController {
@Resource @Resource
@ -34,14 +35,12 @@ public class AiKnowledgeController {
return success(BeanUtils.toBean(pageResult, AiKnowledgeRespVO.class)); return success(BeanUtils.toBean(pageResult, AiKnowledgeRespVO.class));
} }
@PostMapping("/create-my") @PostMapping("/create-my")
@Operation(summary = "创建【我的】知识库") @Operation(summary = "创建【我的】知识库")
public CommonResult<Long> createKnowledgeMy(@RequestBody @Valid AiKnowledgeCreateMyReqVO createReqVO) { public CommonResult<Long> createKnowledgeMy(@RequestBody @Valid AiKnowledgeCreateMyReqVO createReqVO) {
return success(knowledgeService.createKnowledgeMy(createReqVO, getLoginUserId())); return success(knowledgeService.createKnowledgeMy(createReqVO, getLoginUserId()));
} }
@PutMapping("/update-my") @PutMapping("/update-my")
@Operation(summary = "更新【我的】知识库") @Operation(summary = "更新【我的】知识库")
public CommonResult<Boolean> updateKnowledgeMy(@RequestBody @Valid AiKnowledgeUpdateMyReqVO updateReqVO) { public CommonResult<Boolean> updateKnowledgeMy(@RequestBody @Valid AiKnowledgeUpdateMyReqVO updateReqVO) {
@ -49,5 +48,4 @@ public class AiKnowledgeController {
return success(true); return success(true);
} }
} }

View File

@ -12,43 +12,40 @@ import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import jakarta.validation.Valid;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
@Tag(name = "管理后台 - AI 知识库文档")
@Tag(name = "管理后台 - AI 知识库-文档")
@RestController @RestController
@RequestMapping("/ai/knowledge/document") @RequestMapping("/ai/knowledge/document")
@Validated
public class AiKnowledgeDocumentController { public class AiKnowledgeDocumentController {
@Resource @Resource
private AiKnowledgeDocumentService documentService; private AiKnowledgeDocumentService documentService;
@PostMapping("/create") @PostMapping("/create")
@Operation(summary = "新建文档") @Operation(summary = "新建文档")
public CommonResult<Long> createKnowledgeDocument(@Validated AiKnowledgeDocumentCreateReqVO reqVO) { public CommonResult<Long> createKnowledgeDocument(@Valid AiKnowledgeDocumentCreateReqVO reqVO) {
Long knowledgeDocumentId = documentService.createKnowledgeDocument(reqVO); Long knowledgeDocumentId = documentService.createKnowledgeDocument(reqVO);
return success(knowledgeDocumentId); return success(knowledgeDocumentId);
} }
@GetMapping("/page") @GetMapping("/page")
@Operation(summary = "获取文档分页") @Operation(summary = "获取文档分页")
public CommonResult<PageResult<AiKnowledgeDocumentRespVO>> getKnowledgeDocumentPageMy(@Validated AiKnowledgeDocumentPageReqVO pageReqVO) { public CommonResult<PageResult<AiKnowledgeDocumentRespVO>> getKnowledgeDocumentPageMy(@Valid AiKnowledgeDocumentPageReqVO pageReqVO) {
PageResult<AiKnowledgeDocumentDO> pageResult = documentService.getKnowledgeDocumentPage(pageReqVO); PageResult<AiKnowledgeDocumentDO> pageResult = documentService.getKnowledgeDocumentPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiKnowledgeDocumentRespVO.class)); return success(BeanUtils.toBean(pageResult, AiKnowledgeDocumentRespVO.class));
} }
@PutMapping("/update") @PutMapping("/update")
@Operation(summary = "更新文档") @Operation(summary = "更新文档")
public CommonResult<Boolean> updateKnowledgeDocument(@Validated @RequestBody AiKnowledgeDocumentUpdateReqVO reqVO) { public CommonResult<Boolean> updateKnowledgeDocument(@Valid @RequestBody AiKnowledgeDocumentUpdateReqVO reqVO) {
documentService.updateKnowledgeDocument(reqVO); documentService.updateKnowledgeDocument(reqVO);
return success(true); return success(true);
} }
} }

View File

@ -12,15 +12,16 @@ import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import jakarta.validation.Valid;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
@Tag(name = "管理后台 - AI 知识库段落")
@Tag(name = "管理后台 - AI 知识库-段落")
@RestController @RestController
@RequestMapping("/ai/knowledge/segment") @RequestMapping("/ai/knowledge/segment")
@Validated
public class AiKnowledgeSegmentController { public class AiKnowledgeSegmentController {
@Resource @Resource
@ -28,22 +29,21 @@ public class AiKnowledgeSegmentController {
@GetMapping("/page") @GetMapping("/page")
@Operation(summary = "获取段落分页") @Operation(summary = "获取段落分页")
public CommonResult<PageResult<AiKnowledgeSegmentRespVO>> getKnowledgeSegmentPageMy(@Validated AiKnowledgeSegmentPageReqVO pageReqVO) { public CommonResult<PageResult<AiKnowledgeSegmentRespVO>> getKnowledgeSegmentPageMy(@Valid AiKnowledgeSegmentPageReqVO pageReqVO) {
PageResult<AiKnowledgeSegmentDO> pageResult = segmentService.getKnowledgeSegmentPage(pageReqVO); PageResult<AiKnowledgeSegmentDO> pageResult = segmentService.getKnowledgeSegmentPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiKnowledgeSegmentRespVO.class)); return success(BeanUtils.toBean(pageResult, AiKnowledgeSegmentRespVO.class));
} }
@PutMapping("/update") @PutMapping("/update")
@Operation(summary = "更新段落内容") @Operation(summary = "更新段落内容")
public CommonResult<Boolean> updateKnowledgeSegment(@Validated @RequestBody AiKnowledgeSegmentUpdateReqVO reqVO) { public CommonResult<Boolean> updateKnowledgeSegment(@Valid @RequestBody AiKnowledgeSegmentUpdateReqVO reqVO) {
segmentService.updateKnowledgeSegment(reqVO); segmentService.updateKnowledgeSegment(reqVO);
return success(true); return success(true);
} }
@PutMapping("/update-status") @PutMapping("/update-status")
@Operation(summary = "启禁用段落内容") @Operation(summary = "启禁用段落内容")
public CommonResult<Boolean> updateKnowledgeSegmentStatus(@Validated @RequestBody AiKnowledgeSegmentUpdateStatusReqVO reqVO) { public CommonResult<Boolean> updateKnowledgeSegmentStatus(@Valid @RequestBody AiKnowledgeSegmentUpdateStatusReqVO reqVO) {
segmentService.updateKnowledgeSegmentStatus(reqVO); segmentService.updateKnowledgeSegmentStatus(reqVO);
return success(true); return success(true);
} }

View File

@ -4,10 +4,11 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
@Schema(description = "管理后台 - AI 知识库-文档 分页 Request VO") @Schema(description = "管理后台 - AI 知识库文档的分页 Request VO")
@Data @Data
public class AiKnowledgeDocumentPageReqVO extends PageParam { public class AiKnowledgeDocumentPageReqVO extends PageParam {
@Schema(description = "文档名称", example = "Java 开发手册") @Schema(description = "文档名称", example = "Java 开发手册")
private String name; private String name;
} }

View File

@ -17,21 +17,22 @@ public class AiKnowledgeDocumentRespVO extends PageParam {
@Schema(description = "名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册") @Schema(description = "名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册")
private String name; private String name;
@Schema(description = "内容", example = "Java 是一门面向对象的语言.....") @Schema(description = "内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 是一门面向对象的语言.....")
private String content; private String content;
@Schema(description = "文档 url", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://doc.iocoder.cn") @Schema(description = "文档 url", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://doc.iocoder.cn")
private String url; private String url;
@Schema(description = "token 数量", example = "1024") @Schema(description = "token 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private Integer tokens; private Integer tokens;
@Schema(description = "字符数", example = "1008") @Schema(description = "字符数", requiredMode = Schema.RequiredMode.REQUIRED, example = "1008")
private Integer wordCount; private Integer wordCount;
@Schema(description = "切片状态", example = "1") @Schema(description = "切片状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Integer sliceStatus; private Integer sliceStatus;
@Schema(description = "文档状态", example = "1") @Schema(description = "文档状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Integer status; private Integer status;
} }

View File

@ -1,5 +1,7 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document; package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.validation.InEnum;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
@ -15,6 +17,7 @@ public class AiKnowledgeDocumentUpdateReqVO {
private Long id; private Long id;
@Schema(description = "是否启用", example = "1") @Schema(description = "是否启用", example = "1")
@InEnum(CommonStatusEnum.class)
private Integer status; private Integer status;
@Schema(description = "名称", example = "Java 开发手册") @Schema(description = "名称", example = "Java 开发手册")

View File

@ -7,7 +7,7 @@ import lombok.Data;
import org.hibernate.validator.constraints.URL; import org.hibernate.validator.constraints.URL;
@Schema(description = "管理后台 - AI 知识库创建【文档】 Request VO") @Schema(description = "管理后台 - AI 知识库文档的创建 Request VO")
@Data @Data
public class AiKnowledgeDocumentCreateReqVO { public class AiKnowledgeDocumentCreateReqVO {

View File

@ -14,12 +14,13 @@ public class AiKnowledgeRespVO {
@Schema(description = "知识库名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "ruoyi-vue-pro 用户指南") @Schema(description = "知识库名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "ruoyi-vue-pro 用户指南")
private String name; private String name;
@Schema(description = "知识库描述", example = "ruoyi-vue-pro 用户指南") @Schema(description = "知识库描述", example = "帮助你快速构建系统")
private String description; private String description;
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "14") @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "14")
private Long modelId; private Long modelId;
@Schema(description = "模型标识", example = "qwen-72b-chat") @Schema(description = "模型标识", requiredMode = Schema.RequiredMode.REQUIRED, example = "qwen-72b-chat")
private String model; private String model;
} }

View File

@ -4,15 +4,14 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
@Schema(description = "管理后台 - AI 知识库分页 Request VO") @Schema(description = "管理后台 - AI 知识库分段的分页 Request VO")
@Data @Data
public class AiKnowledgeSegmentPageReqVO extends PageParam { public class AiKnowledgeSegmentPageReqVO extends PageParam {
@Schema(description = "分段状态", example = "1") @Schema(description = "分段状态", example = "1")
private Integer status; private Integer status;
@Schema(description = "文档编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "文档编号", example = "1")
private Integer documentId; private Integer documentId;
@Schema(description = "分段内容关键字", example = "Java 开发") @Schema(description = "分段内容关键字", example = "Java 开发")

View File

@ -22,12 +22,13 @@ public class AiKnowledgeSegmentRespVO {
@Schema(description = "切片内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册") @Schema(description = "切片内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册")
private String content; private String content;
@Schema(description = "token 数量", example = "1024") @Schema(description = "token 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private Integer tokens; private Integer tokens;
@Schema(description = "字符数", example = "1008") @Schema(description = "字符数", requiredMode = Schema.RequiredMode.REQUIRED, example = "1008")
private Integer wordCount; private Integer wordCount;
@Schema(description = "文档状态", example = "1") @Schema(description = "文档状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Integer status; private Integer status;
} }

View File

@ -1,10 +1,13 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment; package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.validation.InEnum;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
@Schema(description = "管理后台 - AI 更新 知识库-段落 request VO") @Schema(description = "管理后台 - AI 知识库段落的更新状态 Request VO")
@Data @Data
public class AiKnowledgeSegmentUpdateStatusReqVO { public class AiKnowledgeSegmentUpdateStatusReqVO {
@ -12,6 +15,8 @@ public class AiKnowledgeSegmentUpdateStatusReqVO {
private Long id; private Long id;
@Schema(description = "是否启用", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "是否启用", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "是否启用不能为空")
@InEnum(CommonStatusEnum.class)
private Integer status; private Integer status;
} }

View File

@ -28,12 +28,13 @@ public class AiKnowledgeSegmentDO extends BaseDO {
private String vectorId; private String vectorId;
/** /**
* 知识库编号 * 知识库编号
*
* 关联 {@link AiKnowledgeDO#getId()} * 关联 {@link AiKnowledgeDO#getId()}
*/ */
private Long knowledgeId; private Long knowledgeId;
/** /**
* 文档编号 * 文档编号
* <p> *
* 关联 {@link AiKnowledgeDocumentDO#getId()} * 关联 {@link AiKnowledgeDocumentDO#getId()}
*/ */
private Long documentId; private Long documentId;
@ -51,7 +52,7 @@ public class AiKnowledgeSegmentDO extends BaseDO {
private Integer tokens; private Integer tokens;
/** /**
* 状态 * 状态
* <p> *
* 枚举 {@link CommonStatusEnum} * 枚举 {@link CommonStatusEnum}
*/ */
private Integer status; private Integer status;

View File

@ -30,13 +30,12 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.List; import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_DOCUMENT_NOT_EXISTS; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_DOCUMENT_NOT_EXISTS;
/** /**
* AI 知识库-文档 Service 实现类 * AI 知识库文档 Service 实现类
* *
* @author xiaoxin * @author xiaoxin
*/ */
@ -61,24 +60,21 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
@Resource @Resource
private AiChatModelService chatModelService; private AiChatModelService chatModelService;
// TODO 芋艿需要 review 代码格式
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) { public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
// 0. 校验
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 1.1 下载文档 // 1.1 下载文档
String url = createReqVO.getUrl(); TikaDocumentReader loader = new TikaDocumentReader(downloadFile(createReqVO.getUrl()));
// 1.2 加载文档
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
List<Document> documents = loader.get(); List<Document> documents = loader.get();
Document document = CollUtil.getFirst(documents); Document document = CollUtil.getFirst(documents);
// 1.2 文档记录入库
String content = document.getContent(); String content = document.getContent();
Integer tokens = tokenCountEstimator.estimate(content);
Integer wordCount = content.length();
// 1.3 文档记录入库
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class) AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
.setTokens(tokens).setWordCount(wordCount) .setTokens(tokenCountEstimator.estimate(content)).setWordCount(content.length())
.setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus()); .setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus());
documentMapper.insert(documentDO); documentMapper.insert(documentDO);
Long documentId = documentDO.getId(); Long documentId = documentDO.getId();
@ -90,22 +86,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
List<Document> segments = tokenTextSplitter.apply(documents); List<Document> segments = tokenTextSplitter.apply(documents);
// 2.2 分段内容入库 // 2.2 分段内容入库
List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments, List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments,
segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId).setKnowledgeId(createReqVO.getKnowledgeId()).setVectorId(segment.getId()) segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId)
.setKnowledgeId(createReqVO.getKnowledgeId()).setVectorId(segment.getId())
.setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length()) .setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length())
.setStatus(CommonStatusEnum.ENABLE.getStatus())); .setStatus(CommonStatusEnum.ENABLE.getStatus()));
segmentMapper.insertBatch(segmentDOList); segmentMapper.insertBatch(segmentDOList);
// 3.1 document 补充源数据 // 3.1 获取向量存储实例
segments.forEach(segment -> {
Map<String, Object> metadata = segment.getMetadata();
metadata.put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId());
});
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 3.2 获取向量存储实例
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId()); VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
// 3.3 向量化并存储 // 3.2 向量化并存储
segments.forEach(segment -> segment.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId()));
vectorStore.add(segments); vectorStore.add(segments);
return documentId; return documentId;
} }
@ -117,7 +107,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
@Override @Override
public void updateKnowledgeDocument(AiKnowledgeDocumentUpdateReqVO reqVO) { public void updateKnowledgeDocument(AiKnowledgeDocumentUpdateReqVO reqVO) {
// 1. 校验文档是否存在
validateKnowledgeDocumentExists(reqVO.getId()); validateKnowledgeDocumentExists(reqVO.getId());
// 2. 更新文档
AiKnowledgeDocumentDO document = BeanUtils.toBean(reqVO, AiKnowledgeDocumentDO.class); AiKnowledgeDocumentDO document = BeanUtils.toBean(reqVO, AiKnowledgeDocumentDO.class);
documentMapper.updateById(document); documentMapper.updateById(document);
} }

View File

@ -7,7 +7,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowle
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
/** /**
* AI 知识库分片 Service 接口 * AI 知识库段落 Service 接口
* *
* @author xiaoxin * @author xiaoxin
*/ */
@ -22,16 +22,17 @@ public interface AiKnowledgeSegmentService {
PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO); PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO);
/** /**
* 更新段落内容 * 更新段落内容
* *
* @param reqVO 更新内容 * @param reqVO 更新内容
*/ */
void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO); void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO);
/** /**
* 更新状态 * 更新段落的状态
* *
* @param reqVO 更新内容 * @param reqVO 更新内容
*/ */
void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO); void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO);
} }

View File

@ -85,14 +85,6 @@ public interface AiApiKeyService {
*/ */
ChatModel getChatModel(Long id); ChatModel getChatModel(Long id);
/**
* 获得 EmbeddingModel 对象
*
* @param id 编号
* @return EmbeddingModel 对象
*/
EmbeddingModel getEmbeddingModel(Long id);
/** /**
* 获得 ImageModel 对象 * 获得 ImageModel 对象
* *
@ -122,7 +114,15 @@ public interface AiApiKeyService {
SunoApi getSunoApi(); SunoApi getSunoApi();
/** /**
* 获得 vector 对象 * 获得 EmbeddingModel 对象
*
* @param id 编号
* @return EmbeddingModel 对象
*/
EmbeddingModel getEmbeddingModel(Long id);
/**
* 获得 VectorStore 对象
* *
* @param id 编号 * @param id 编号
* @return VectorStore 对象 * @return VectorStore 对象

View File

@ -109,13 +109,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
} }
@Override
public EmbeddingModel getEmbeddingModel(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override @Override
public ImageModel getImageModel(AiPlatformEnum platform) { public ImageModel getImageModel(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
@ -145,10 +138,18 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
} }
@Override
public EmbeddingModel getEmbeddingModel(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override @Override
public VectorStore getOrCreateVectorStore(Long id) { public VectorStore getOrCreateVectorStore(Long id) {
AiApiKeyDO apiKey = validateApiKey(id); AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl()); return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
} }
} }

View File

@ -26,18 +26,6 @@ public interface AiModelFactory {
*/ */
ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url); ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于指定配置获得 EmbeddingModel 对象
* <p>
* 如果不存在则进行创建
*
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @return ChatModel 对象
*/
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
/** /**
* 基于默认配置获得 ChatModel 对象 * 基于默认配置获得 ChatModel 对象
* *
@ -92,4 +80,16 @@ public interface AiModelFactory {
*/ */
SunoApi getOrCreateSunoApi(String apiKey, String url); SunoApi getOrCreateSunoApi(String apiKey, String url);
/**
* 基于指定配置获得 EmbeddingModel 对象
*
* 如果不存在则进行创建
*
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @return ChatModel 对象
*/
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
} }

View File

@ -99,21 +99,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
}); });
} }
@Override
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
// TODO @xin 先测试一个
switch (platform) {
case TONG_YI:
return buildTongYiEmbeddingModel(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
});
}
@Override @Override
public ChatModel getDefaultChatModel(AiPlatformEnum platform) { public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
@ -192,6 +177,20 @@ public class AiModelFactoryImpl implements AiModelFactory {
return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url)); return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url));
} }
@Override
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
// TODO @xin 先测试一个
switch (platform) {
case TONG_YI:
return buildTongYiEmbeddingModel(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
});
}
private static String buildClientCacheKey(Class<?> clazz, Object... params) { private static String buildClientCacheKey(Class<?> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) { if (ArrayUtil.isEmpty(params)) {
return clazz.getName(); return clazz.getName();
@ -255,8 +254,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
} }
/** /**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel( * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
*ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
*/ */
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) { private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@ -265,8 +263,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
} }
/** /**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel( * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/ */
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) { private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);

View File

@ -4,13 +4,14 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
// TODO @xin也放到 AiModelFactory 里面好了后续改成 AiFactory
/** /**
* AI Vector 模型工厂的接口类 * AI Vector 模型工厂的接口类
*
* @author xiaoxin * @author xiaoxin
*/ */
public interface AiVectorStoreFactory { public interface AiVectorStoreFactory {
/** /**
* 基于指定配置获得 VectorStore 对象 * 基于指定配置获得 VectorStore 对象
* <p> * <p>

View File

@ -26,6 +26,7 @@ public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory {
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> { return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
// TODO 芋艿 @xin 这两个配置取哪好呢 // TODO 芋艿 @xin 这两个配置取哪好呢
// TODO 不同模型的向量维度可能会不一样目前看貌似是以 index 来做区分的维度不一样存不到一个 index // TODO 不同模型的向量维度可能会不一样目前看貌似是以 index 来做区分的维度不一样存不到一个 index
// TODO 回复好的哈
String index = "default-index"; String index = "default-index";
String prefix = "default:"; String prefix = "default:";
var config = RedisVectorStore.RedisVectorStoreConfig.builder() var config = RedisVectorStore.RedisVectorStoreConfig.builder()
@ -41,11 +42,11 @@ public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory {
}); });
} }
private static String buildClientCacheKey(Class<?> clazz, Object... params) { private static String buildClientCacheKey(Class<?> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) { if (ArrayUtil.isEmpty(params)) {
return clazz.getName(); return clazz.getName();
} }
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_")); return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
} }
} }