From 94dac544e2d1e48ce605371ca71fd95396d9cc8f Mon Sep 17 00:00:00 2001 From: huangge1199 Date: Mon, 26 May 2025 11:21:32 +0800 Subject: [PATCH] =?UTF-8?q?Spring=20AI=20=E9=AB=98=E7=BA=A7RAG=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E5=AE=9E=E7=8E=B0=EF=BC=9A=E6=9F=A5=E8=AF=A2=E9=87=8D?= =?UTF-8?q?=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../aiagent/Service/RagService.java | 2 ++ .../aiagent/Service/impl/RagServiceImpl.java | 21 +++++++++++++++++-- .../aiagent/controller/RagController.java | 7 +++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/huangge1199/aiagent/Service/RagService.java b/src/main/java/com/huangge1199/aiagent/Service/RagService.java index f456ff4..725a9fc 100644 --- a/src/main/java/com/huangge1199/aiagent/Service/RagService.java +++ b/src/main/java/com/huangge1199/aiagent/Service/RagService.java @@ -14,4 +14,6 @@ public interface RagService { String localDoc(String question); List getMultiQueryExpand(String question); + + String queryRewrite(String question); } diff --git a/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java b/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java index 8174879..4da0812 100644 --- a/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java +++ b/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java @@ -9,6 +9,8 @@ import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.rag.Query; import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.stereotype.Service; @@ -63,7 +65,22 @@ public class RagServiceImpl implements RagService { // 生成3个查询变体 .numberOfQueries(4) .build(); - List queries = queryExpander.expand(new Query(question)); - return queries; + return queryExpander.expand(new Query(question)); + } + + @Override + public String queryRewrite(String question) { + ChatClient.Builder builder = ChatClient.builder(ollamaChatModel); + // 创建一个模拟用户学习AI的查询场景 + Query query = new Query(question); + + // 创建查询重写转换器 + QueryTransformer queryTransformer = RewriteQueryTransformer.builder() + .chatClientBuilder(builder) + .build(); + + // 执行查询重写 + Query transformedQuery = queryTransformer.transform(query); + return transformedQuery.text(); } } diff --git a/src/main/java/com/huangge1199/aiagent/controller/RagController.java b/src/main/java/com/huangge1199/aiagent/controller/RagController.java index f148c9f..5f2008a 100644 --- a/src/main/java/com/huangge1199/aiagent/controller/RagController.java +++ b/src/main/java/com/huangge1199/aiagent/controller/RagController.java @@ -40,4 +40,11 @@ public class RagController { List queryList = ragService.getMultiQueryExpand(question); return R.ok(queryList); } + + @PostMapping("/queryRewrite") + @Operation(summary = "查询重写") + public R queryRewrite(@RequestBody String question) { + String queryList = ragService.queryRewrite(question); + return R.ok(queryList); + } }