Spring AI 高级RAG功能实现:查询重写
This commit is contained in:
parent
958df784f2
commit
94dac544e2
@ -14,4 +14,6 @@ public interface RagService {
|
|||||||
String localDoc(String question);
|
String localDoc(String question);
|
||||||
|
|
||||||
List<Query> getMultiQueryExpand(String question);
|
List<Query> getMultiQueryExpand(String question);
|
||||||
|
|
||||||
|
String queryRewrite(String question);
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,8 @@ import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
|
|||||||
import org.springframework.ai.chat.model.ChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.ai.rag.Query;
|
import org.springframework.ai.rag.Query;
|
||||||
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
|
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
|
||||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
|
||||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
|
||||||
import org.springframework.ai.vectorstore.VectorStore;
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@ -63,7 +65,22 @@ public class RagServiceImpl implements RagService {
|
|||||||
// 生成3个查询变体
|
// 生成3个查询变体
|
||||||
.numberOfQueries(4)
|
.numberOfQueries(4)
|
||||||
.build();
|
.build();
|
||||||
List<Query> queries = queryExpander.expand(new Query(question));
|
return queryExpander.expand(new Query(question));
|
||||||
return queries;
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String queryRewrite(String question) {
|
||||||
|
ChatClient.Builder builder = ChatClient.builder(ollamaChatModel);
|
||||||
|
// 创建一个模拟用户学习AI的查询场景
|
||||||
|
Query query = new Query(question);
|
||||||
|
|
||||||
|
// 创建查询重写转换器
|
||||||
|
QueryTransformer queryTransformer = RewriteQueryTransformer.builder()
|
||||||
|
.chatClientBuilder(builder)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// 执行查询重写
|
||||||
|
Query transformedQuery = queryTransformer.transform(query);
|
||||||
|
return transformedQuery.text();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,4 +40,11 @@ public class RagController {
|
|||||||
List<Query> queryList = ragService.getMultiQueryExpand(question);
|
List<Query> queryList = ragService.getMultiQueryExpand(question);
|
||||||
return R.ok(queryList);
|
return R.ok(queryList);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@PostMapping("/queryRewrite")
|
||||||
|
@Operation(summary = "查询重写")
|
||||||
|
public R<String> queryRewrite(@RequestBody String question) {
|
||||||
|
String queryList = ragService.queryRewrite(question);
|
||||||
|
return R.ok(queryList);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user