Spring AI 高级RAG功能实现:检索增强顾问(基础用法和高级用法)

This commit is contained in:
huangge1199 2025-05-26 14:59:02 +08:00
parent 77ef26470a
commit ffcfb67e5b
3 changed files with 94 additions and 0 deletions

View File

@ -20,4 +20,8 @@ public interface RagService {
String queryTranslation(String question); String queryTranslation(String question);
String contextAwareQueries(String question); String contextAwareQueries(String question);
String baseAdvisor(String question);
String advancedAdvisor(String question);
} }

View File

@ -6,15 +6,23 @@ import com.huangge1199.aiagent.rag.MyMultiQueryExpander;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.rag.Query; import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
@ -37,6 +45,9 @@ public class RagServiceImpl implements RagService {
@Resource @Resource
private ChatModel ollamaChatModel; private ChatModel ollamaChatModel;
@Resource
private EmbeddingModel ollamaEmbeddingModel;
@Override @Override
public String localDoc(String question) { public String localDoc(String question) {
return chatClient.prompt().user(question).advisors(new MyLoggerAdvisor()).advisors(new QuestionAnswerAdvisor(vectorStore)).call().content(); return chatClient.prompt().user(question).advisors(new MyLoggerAdvisor()).advisors(new QuestionAnswerAdvisor(vectorStore)).call().content();
@ -113,4 +124,69 @@ public class RagServiceImpl implements RagService {
return transformedQuery.text(); return transformedQuery.text();
} }
@Override
public String baseAdvisor(String question) {
Advisor advisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.build())
.build();
return advisor(question, advisor);
}
@Override
public String advancedAdvisor(String question) {
Advisor advisor = RetrievalAugmentationAdvisor.builder()
// 配置查询增强器
.queryAugmenter(ContextualQueryAugmenter.builder()
// 允许空上下文查询
.allowEmptyContext(true)
.build())
// 配置文档检索器
.documentRetriever(VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
// 相似度阈值
.similarityThreshold(0.5)
// 返回文档数量
.topK(3)
.filterExpression(new FilterExpressionBuilder()
.eq("genre", "fairytale")
.build()) // 文档过滤表达式
.build())
.build();
return advisor(question, advisor);
}
/**
* 检索增强顾问
* @param question 问题
* @param advisor 检索增强顾问
* @return 查询结果
*/
private String advisor(String question, Advisor advisor) {
// 1. 初始化向量存储
SimpleVectorStore vectorStore = SimpleVectorStore.builder(ollamaEmbeddingModel)
.build();
// 2. 添加文档到向量存储
List<Document> documents = List.of(
new Document("产品说明书:产品名称:智能机器人\n" +
"产品描述:智能机器人是一个智能设备,能够自动完成各种任务。\n" +
"功能:\n" +
"1. 自动导航:机器人能够自动导航到指定位置。\n" +
"2. 自动抓取:机器人能够自动抓取物品。\n" +
"3. 自动放置:机器人能够自动放置物品。\n"));
vectorStore.add(documents);
// 3. 创建检索增强顾问 advisor
// 4. 在聊天客户端中使用顾问
return chatClient.prompt()
.user(question)
// 添加检索增强顾问
.advisors(advisor)
.call()
.content();
}
} }

View File

@ -61,4 +61,18 @@ public class RagController {
String queryList = ragService.contextAwareQueries(question); String queryList = ragService.contextAwareQueries(question);
return R.ok(queryList); return R.ok(queryList);
} }
@PostMapping("/baseAdvisor")
@Operation(summary = "检索增强顾问:基础用法")
public R<String> baseAdvisor(@RequestBody String question) {
String queryList = ragService.baseAdvisor(question);
return R.ok(queryList);
}
@PostMapping("/advancedAdvisor")
@Operation(summary = "检索增强顾问:高级用法")
public R<String> advancedAdvisor(@RequestBody String question) {
String queryList = ragService.advancedAdvisor(question);
return R.ok(queryList);
}
} }