diff --git a/src/main/java/com/huangge1199/aiagent/Service/RagService.java b/src/main/java/com/huangge1199/aiagent/Service/RagService.java index ee639c9..054fb58 100644 --- a/src/main/java/com/huangge1199/aiagent/Service/RagService.java +++ b/src/main/java/com/huangge1199/aiagent/Service/RagService.java @@ -20,4 +20,8 @@ public interface RagService { String queryTranslation(String question); String contextAwareQueries(String question); + + String baseAdvisor(String question); + + String advancedAdvisor(String question); } diff --git a/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java b/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java index a8f398d..531d18e 100644 --- a/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java +++ b/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java @@ -6,15 +6,23 @@ import com.huangge1199.aiagent.rag.MyMultiQueryExpander; import jakarta.annotation.Resource; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; +import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer; +import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; import org.springframework.stereotype.Service; import java.util.List; @@ -37,6 +45,9 @@ public class RagServiceImpl implements RagService { @Resource private ChatModel ollamaChatModel; + @Resource + private EmbeddingModel ollamaEmbeddingModel; + @Override public String localDoc(String question) { return chatClient.prompt().user(question).advisors(new MyLoggerAdvisor()).advisors(new QuestionAnswerAdvisor(vectorStore)).call().content(); @@ -113,4 +124,69 @@ public class RagServiceImpl implements RagService { return transformedQuery.text(); } + + @Override + public String baseAdvisor(String question) { + Advisor advisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder() + .vectorStore(vectorStore) + .build()) + .build(); + return advisor(question, advisor); + } + + @Override + public String advancedAdvisor(String question) { + Advisor advisor = RetrievalAugmentationAdvisor.builder() + // 配置查询增强器 + .queryAugmenter(ContextualQueryAugmenter.builder() + // 允许空上下文查询 + .allowEmptyContext(true) + .build()) + // 配置文档检索器 + .documentRetriever(VectorStoreDocumentRetriever.builder() + .vectorStore(vectorStore) + // 相似度阈值 + .similarityThreshold(0.5) + // 返回文档数量 + .topK(3) + .filterExpression(new FilterExpressionBuilder() + .eq("genre", "fairytale") + .build()) // 文档过滤表达式 + .build()) + .build(); + return advisor(question, advisor); + } + + /** + * 检索增强顾问 + * @param question 问题 + * @param advisor 检索增强顾问 + * @return 查询结果 + */ + private String advisor(String question, Advisor advisor) { + // 1. 初始化向量存储 + SimpleVectorStore vectorStore = SimpleVectorStore.builder(ollamaEmbeddingModel) + .build(); + + // 2. 添加文档到向量存储 + List documents = List.of( + new Document("产品说明书:产品名称:智能机器人\n" + + "产品描述:智能机器人是一个智能设备,能够自动完成各种任务。\n" + + "功能:\n" + + "1. 自动导航:机器人能够自动导航到指定位置。\n" + + "2. 自动抓取:机器人能够自动抓取物品。\n" + + "3. 自动放置:机器人能够自动放置物品。\n")); + vectorStore.add(documents); + + // 3. 创建检索增强顾问 advisor + + // 4. 在聊天客户端中使用顾问 + return chatClient.prompt() + .user(question) + // 添加检索增强顾问 + .advisors(advisor) + .call() + .content(); + } } diff --git a/src/main/java/com/huangge1199/aiagent/controller/RagController.java b/src/main/java/com/huangge1199/aiagent/controller/RagController.java index 89b77b3..e26e329 100644 --- a/src/main/java/com/huangge1199/aiagent/controller/RagController.java +++ b/src/main/java/com/huangge1199/aiagent/controller/RagController.java @@ -61,4 +61,18 @@ public class RagController { String queryList = ragService.contextAwareQueries(question); return R.ok(queryList); } + + @PostMapping("/baseAdvisor") + @Operation(summary = "检索增强顾问:基础用法") + public R baseAdvisor(@RequestBody String question) { + String queryList = ragService.baseAdvisor(question); + return R.ok(queryList); + } + + @PostMapping("/advancedAdvisor") + @Operation(summary = "检索增强顾问:高级用法") + public R advancedAdvisor(@RequestBody String question) { + String queryList = ragService.advancedAdvisor(question); + return R.ok(queryList); + } }