Spring AI 高级RAG功能实现:查询重写

This commit is contained in:
huangge1199 2025-05-26 11:21:32 +08:00
parent 958df784f2
commit 94dac544e2
3 changed files with 28 additions and 2 deletions

View File

@ -14,4 +14,6 @@ public interface RagService {
String localDoc(String question); String localDoc(String question);
List<Query> getMultiQueryExpand(String question); List<Query> getMultiQueryExpand(String question);
String queryRewrite(String question);
} }

View File

@ -9,6 +9,8 @@ import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query; import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -63,7 +65,22 @@ public class RagServiceImpl implements RagService {
// 生成3个查询变体 // 生成3个查询变体
.numberOfQueries(4) .numberOfQueries(4)
.build(); .build();
List<Query> queries = queryExpander.expand(new Query(question)); return queryExpander.expand(new Query(question));
return queries; }
@Override
public String queryRewrite(String question) {
ChatClient.Builder builder = ChatClient.builder(ollamaChatModel);
// 创建一个模拟用户学习AI的查询场景
Query query = new Query(question);
// 创建查询重写转换器
QueryTransformer queryTransformer = RewriteQueryTransformer.builder()
.chatClientBuilder(builder)
.build();
// 执行查询重写
Query transformedQuery = queryTransformer.transform(query);
return transformedQuery.text();
} }
} }

View File

@ -40,4 +40,11 @@ public class RagController {
List<Query> queryList = ragService.getMultiQueryExpand(question); List<Query> queryList = ragService.getMultiQueryExpand(question);
return R.ok(queryList); return R.ok(queryList);
} }
@PostMapping("/queryRewrite")
@Operation(summary = "查询重写")
public R<String> queryRewrite(@RequestBody String question) {
String queryList = ragService.queryRewrite(question);
return R.ok(queryList);
}
} }