Spring AI 高级RAG功能实现:多查询扩展

This commit is contained in:
huangge1199 2025-05-26 08:50:56 +08:00
parent 6dfba8de70
commit 958df784f2
4 changed files with 204 additions and 1 deletions

View File

@ -1,5 +1,9 @@
package com.huangge1199.aiagent.Service;
import org.springframework.ai.rag.Query;
import java.util.List;
/**
* RagService
*
@ -8,4 +12,6 @@ package com.huangge1199.aiagent.Service;
*/
public interface RagService {
String localDoc(String question);
List<Query> getMultiQueryExpand(String question);
}

View File

@ -2,12 +2,18 @@ package com.huangge1199.aiagent.Service.impl;
import com.huangge1199.aiagent.Service.RagService;
import com.huangge1199.aiagent.config.MyLoggerAdvisor;
import com.huangge1199.aiagent.rag.MyMultiQueryExpander;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import java.util.List;
/**
* RagServiceImpl
*
@ -23,6 +29,9 @@ public class RagServiceImpl implements RagService {
@Resource
private VectorStore vectorStore;
@Resource
private ChatModel ollamaChatModel;
@Override
public String localDoc(String question) {
return chatClient.prompt()
@ -32,4 +41,29 @@ public class RagServiceImpl implements RagService {
.call()
.content();
}
@Override
public List<Query> getMultiQueryExpand(String question) {
ChatClient.Builder builder = ChatClient.builder(ollamaChatModel);
ChatClient chatClient = builder
.defaultSystem("你是一位专业的室内设计顾问,精通各种装修风格、材料选择和空间布局。请基于提供的参考资料,为用户提供专业、详细且实用的建议。在回答时,请注意:\n" +
"1. 准确理解用户的具体需求\n" +
"2. 结合参考资料中的实际案例\n" +
"3. 提供专业的设计理念和原理解释\n" +
"4. 考虑实用性、美观性和成本效益\n" +
"5. 如有需要,可以提供替代方案")
.build();
// MultiQueryExpander queryExpander = MultiQueryExpander.builder()
MyMultiQueryExpander queryExpander = MyMultiQueryExpander.builder()
.chatClientBuilder(builder)
// 不包含原始查询
.includeOriginal(false)
// 生成3个查询变体
.numberOfQueries(4)
.build();
List<Query> queries = queryExpander.expand(new Query(question));
return queries;
}
}

View File

@ -5,12 +5,13 @@ import com.huangge1199.aiagent.common.R;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import org.springframework.ai.rag.Query;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.UUID;
import java.util.List;
/**
* RagController
@ -32,4 +33,11 @@ public class RagController {
String res = ragService.localDoc(question);
return R.ok(res);
}
@PostMapping("/getMultiQueryExpand")
@Operation(summary = "多查询扩展")
public R<List<Query>> getMultiQueryExpand(@RequestBody String question) {
List<Query> queryList = ragService.getMultiQueryExpand(question);
return R.ok(queryList);
}
}

View File

@ -0,0 +1,155 @@
package com.huangge1199.aiagent.rag;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
* MyMultiQueryExpander
*
* @author huangge1199
* @since 2025/5/24 14:39:38
*/
public class MyMultiQueryExpander implements QueryExpander {
private static final Logger logger = LoggerFactory.getLogger(MyMultiQueryExpander.class);
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
You are an expert at information retrieval and search optimization.
Your task is to generate {number} different versions of the given query.
Each variant must cover different perspectives or aspects of the topic,
while maintaining the core intent of the original query. The goal is to
expand the search space and improve the chances of finding relevant information.
Do not explain your choices or add any other text.
Provide the query variants separated by newlines.
Original query: {query}
Query variants:
""");
private static final Boolean DEFAULT_INCLUDE_ORIGINAL = true;
private static final Integer DEFAULT_NUMBER_OF_QUERIES = 3;
private final ChatClient chatClient;
private final PromptTemplate promptTemplate;
private final boolean includeOriginal;
private final int numberOfQueries;
public MyMultiQueryExpander(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
@Nullable Boolean includeOriginal, @Nullable Integer numberOfQueries) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.includeOriginal = includeOriginal != null ? includeOriginal : DEFAULT_INCLUDE_ORIGINAL;
this.numberOfQueries = numberOfQueries != null ? numberOfQueries : DEFAULT_NUMBER_OF_QUERIES;
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query");
}
@Override
public List<Query> expand(Query query) {
Assert.notNull(query, "query cannot be null");
logger.debug("Generating {} query variants", this.numberOfQueries);
var response = this.chatClient.prompt()
.user(user -> user.text(this.promptTemplate.getTemplate())
.param("number", this.numberOfQueries)
.param("query", query.text()))
.call()
.content();
if (response == null) {
logger.warn("Query expansion result is null. Returning the input query unchanged.");
return List.of(query);
}
var queryVariants = Arrays.stream(response.split("\n"))
.filter(s -> !s.trim().isEmpty())
.toList();
;
if (CollectionUtils.isEmpty(queryVariants)) {
logger.warn(
"Query expansion result does not contain the requested {} variants. Returning the input query unchanged.",
this.numberOfQueries);
return List.of(query);
}
var queries = queryVariants.stream()
.filter(StringUtils::hasText)
.map(queryText -> query.mutate().text(queryText).build())
.collect(Collectors.toList());
if (this.includeOriginal) {
logger.debug("Including the original query in the result");
queries.add(0, query);
}
return queries;
}
public static MyMultiQueryExpander.Builder builder() {
return new MyMultiQueryExpander.Builder();
}
public static final class Builder {
private ChatClient.Builder chatClientBuilder;
private PromptTemplate promptTemplate;
private Boolean includeOriginal;
private Integer numberOfQueries;
private Builder() {
}
public MyMultiQueryExpander.Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
this.chatClientBuilder = chatClientBuilder;
return this;
}
public MyMultiQueryExpander.Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}
public MyMultiQueryExpander.Builder includeOriginal(Boolean includeOriginal) {
this.includeOriginal = includeOriginal;
return this;
}
public MyMultiQueryExpander.Builder numberOfQueries(Integer numberOfQueries) {
this.numberOfQueries = numberOfQueries;
return this;
}
public MyMultiQueryExpander build() {
return new MyMultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal,
this.numberOfQueries);
}
}
}