diff --git a/src/main/java/com/huangge1199/aiagent/Service/RagService.java b/src/main/java/com/huangge1199/aiagent/Service/RagService.java index f456ff4..725a9fc 100644 --- a/src/main/java/com/huangge1199/aiagent/Service/RagService.java +++ b/src/main/java/com/huangge1199/aiagent/Service/RagService.java @@ -14,4 +14,6 @@ public interface RagService { String localDoc(String question); List getMultiQueryExpand(String question); + + String queryRewrite(String question); } diff --git a/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java b/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java index 8174879..4da0812 100644 --- a/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java +++ b/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java @@ -9,6 +9,8 @@ import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.rag.Query; import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.stereotype.Service; @@ -63,7 +65,22 @@ public class RagServiceImpl implements RagService { // 生成3个查询变体 .numberOfQueries(4) .build(); - List queries = queryExpander.expand(new Query(question)); - return queries; + return queryExpander.expand(new Query(question)); + } + + @Override + public String queryRewrite(String question) { + ChatClient.Builder builder = ChatClient.builder(ollamaChatModel); + // 创建一个模拟用户学习AI的查询场景 + Query query = new Query(question); + + // 创建查询重写转换器 + QueryTransformer queryTransformer = RewriteQueryTransformer.builder() + .chatClientBuilder(builder) + .build(); + + // 执行查询重写 + Query transformedQuery = queryTransformer.transform(query); + return transformedQuery.text(); } } diff --git a/src/main/java/com/huangge1199/aiagent/controller/RagController.java b/src/main/java/com/huangge1199/aiagent/controller/RagController.java index f148c9f..5f2008a 100644 --- a/src/main/java/com/huangge1199/aiagent/controller/RagController.java +++ b/src/main/java/com/huangge1199/aiagent/controller/RagController.java @@ -40,4 +40,11 @@ public class RagController { List queryList = ragService.getMultiQueryExpand(question); return R.ok(queryList); } + + @PostMapping("/queryRewrite") + @Operation(summary = "查询重写") + public R queryRewrite(@RequestBody String question) { + String queryList = ragService.queryRewrite(question); + return R.ok(queryList); + } }