Spring AI 高级RAG功能实现:多查询扩展
This commit is contained in:
parent
6dfba8de70
commit
958df784f2
@ -1,5 +1,9 @@
|
|||||||
package com.huangge1199.aiagent.Service;
|
package com.huangge1199.aiagent.Service;
|
||||||
|
|
||||||
|
import org.springframework.ai.rag.Query;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* RagService
|
* RagService
|
||||||
*
|
*
|
||||||
@ -8,4 +12,6 @@ package com.huangge1199.aiagent.Service;
|
|||||||
*/
|
*/
|
||||||
public interface RagService {
|
public interface RagService {
|
||||||
String localDoc(String question);
|
String localDoc(String question);
|
||||||
|
|
||||||
|
List<Query> getMultiQueryExpand(String question);
|
||||||
}
|
}
|
||||||
|
@ -2,12 +2,18 @@ package com.huangge1199.aiagent.Service.impl;
|
|||||||
|
|
||||||
import com.huangge1199.aiagent.Service.RagService;
|
import com.huangge1199.aiagent.Service.RagService;
|
||||||
import com.huangge1199.aiagent.config.MyLoggerAdvisor;
|
import com.huangge1199.aiagent.config.MyLoggerAdvisor;
|
||||||
|
import com.huangge1199.aiagent.rag.MyMultiQueryExpander;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import org.springframework.ai.chat.client.ChatClient;
|
import org.springframework.ai.chat.client.ChatClient;
|
||||||
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
|
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
|
||||||
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
|
import org.springframework.ai.rag.Query;
|
||||||
|
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
|
||||||
import org.springframework.ai.vectorstore.VectorStore;
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* RagServiceImpl
|
* RagServiceImpl
|
||||||
*
|
*
|
||||||
@ -23,6 +29,9 @@ public class RagServiceImpl implements RagService {
|
|||||||
@Resource
|
@Resource
|
||||||
private VectorStore vectorStore;
|
private VectorStore vectorStore;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private ChatModel ollamaChatModel;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String localDoc(String question) {
|
public String localDoc(String question) {
|
||||||
return chatClient.prompt()
|
return chatClient.prompt()
|
||||||
@ -32,4 +41,29 @@ public class RagServiceImpl implements RagService {
|
|||||||
.call()
|
.call()
|
||||||
.content();
|
.content();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Query> getMultiQueryExpand(String question) {
|
||||||
|
ChatClient.Builder builder = ChatClient.builder(ollamaChatModel);
|
||||||
|
|
||||||
|
ChatClient chatClient = builder
|
||||||
|
.defaultSystem("你是一位专业的室内设计顾问,精通各种装修风格、材料选择和空间布局。请基于提供的参考资料,为用户提供专业、详细且实用的建议。在回答时,请注意:\n" +
|
||||||
|
"1. 准确理解用户的具体需求\n" +
|
||||||
|
"2. 结合参考资料中的实际案例\n" +
|
||||||
|
"3. 提供专业的设计理念和原理解释\n" +
|
||||||
|
"4. 考虑实用性、美观性和成本效益\n" +
|
||||||
|
"5. 如有需要,可以提供替代方案")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// MultiQueryExpander queryExpander = MultiQueryExpander.builder()
|
||||||
|
MyMultiQueryExpander queryExpander = MyMultiQueryExpander.builder()
|
||||||
|
.chatClientBuilder(builder)
|
||||||
|
// 不包含原始查询
|
||||||
|
.includeOriginal(false)
|
||||||
|
// 生成3个查询变体
|
||||||
|
.numberOfQueries(4)
|
||||||
|
.build();
|
||||||
|
List<Query> queries = queryExpander.expand(new Query(question));
|
||||||
|
return queries;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,12 +5,13 @@ import com.huangge1199.aiagent.common.R;
|
|||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
|
import org.springframework.ai.rag.Query;
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
|
||||||
import java.util.UUID;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* RagController
|
* RagController
|
||||||
@ -32,4 +33,11 @@ public class RagController {
|
|||||||
String res = ragService.localDoc(question);
|
String res = ragService.localDoc(question);
|
||||||
return R.ok(res);
|
return R.ok(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@PostMapping("/getMultiQueryExpand")
|
||||||
|
@Operation(summary = "多查询扩展")
|
||||||
|
public R<List<Query>> getMultiQueryExpand(@RequestBody String question) {
|
||||||
|
List<Query> queryList = ragService.getMultiQueryExpand(question);
|
||||||
|
return R.ok(queryList);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,155 @@
|
|||||||
|
package com.huangge1199.aiagent.rag;
|
||||||
|
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.ai.chat.client.ChatClient;
|
||||||
|
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||||
|
import org.springframework.ai.rag.Query;
|
||||||
|
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
|
||||||
|
import org.springframework.ai.util.PromptAssert;
|
||||||
|
import org.springframework.lang.Nullable;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MyMultiQueryExpander
|
||||||
|
*
|
||||||
|
* @author huangge1199
|
||||||
|
* @since 2025/5/24 14:39:38
|
||||||
|
*/
|
||||||
|
public class MyMultiQueryExpander implements QueryExpander {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(MyMultiQueryExpander.class);
|
||||||
|
|
||||||
|
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
|
||||||
|
You are an expert at information retrieval and search optimization.
|
||||||
|
Your task is to generate {number} different versions of the given query.
|
||||||
|
|
||||||
|
Each variant must cover different perspectives or aspects of the topic,
|
||||||
|
while maintaining the core intent of the original query. The goal is to
|
||||||
|
expand the search space and improve the chances of finding relevant information.
|
||||||
|
|
||||||
|
Do not explain your choices or add any other text.
|
||||||
|
Provide the query variants separated by newlines.
|
||||||
|
|
||||||
|
Original query: {query}
|
||||||
|
|
||||||
|
Query variants:
|
||||||
|
""");
|
||||||
|
|
||||||
|
private static final Boolean DEFAULT_INCLUDE_ORIGINAL = true;
|
||||||
|
|
||||||
|
private static final Integer DEFAULT_NUMBER_OF_QUERIES = 3;
|
||||||
|
|
||||||
|
private final ChatClient chatClient;
|
||||||
|
|
||||||
|
private final PromptTemplate promptTemplate;
|
||||||
|
|
||||||
|
private final boolean includeOriginal;
|
||||||
|
|
||||||
|
private final int numberOfQueries;
|
||||||
|
|
||||||
|
public MyMultiQueryExpander(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
|
||||||
|
@Nullable Boolean includeOriginal, @Nullable Integer numberOfQueries) {
|
||||||
|
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
|
||||||
|
|
||||||
|
this.chatClient = chatClientBuilder.build();
|
||||||
|
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
|
||||||
|
this.includeOriginal = includeOriginal != null ? includeOriginal : DEFAULT_INCLUDE_ORIGINAL;
|
||||||
|
this.numberOfQueries = numberOfQueries != null ? numberOfQueries : DEFAULT_NUMBER_OF_QUERIES;
|
||||||
|
|
||||||
|
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Query> expand(Query query) {
|
||||||
|
Assert.notNull(query, "query cannot be null");
|
||||||
|
|
||||||
|
logger.debug("Generating {} query variants", this.numberOfQueries);
|
||||||
|
|
||||||
|
var response = this.chatClient.prompt()
|
||||||
|
.user(user -> user.text(this.promptTemplate.getTemplate())
|
||||||
|
.param("number", this.numberOfQueries)
|
||||||
|
.param("query", query.text()))
|
||||||
|
.call()
|
||||||
|
.content();
|
||||||
|
|
||||||
|
if (response == null) {
|
||||||
|
logger.warn("Query expansion result is null. Returning the input query unchanged.");
|
||||||
|
return List.of(query);
|
||||||
|
}
|
||||||
|
|
||||||
|
var queryVariants = Arrays.stream(response.split("\n"))
|
||||||
|
.filter(s -> !s.trim().isEmpty())
|
||||||
|
.toList();
|
||||||
|
;
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(queryVariants)) {
|
||||||
|
logger.warn(
|
||||||
|
"Query expansion result does not contain the requested {} variants. Returning the input query unchanged.",
|
||||||
|
this.numberOfQueries);
|
||||||
|
return List.of(query);
|
||||||
|
}
|
||||||
|
|
||||||
|
var queries = queryVariants.stream()
|
||||||
|
.filter(StringUtils::hasText)
|
||||||
|
.map(queryText -> query.mutate().text(queryText).build())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (this.includeOriginal) {
|
||||||
|
logger.debug("Including the original query in the result");
|
||||||
|
queries.add(0, query);
|
||||||
|
}
|
||||||
|
|
||||||
|
return queries;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MyMultiQueryExpander.Builder builder() {
|
||||||
|
return new MyMultiQueryExpander.Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final class Builder {
|
||||||
|
|
||||||
|
private ChatClient.Builder chatClientBuilder;
|
||||||
|
|
||||||
|
private PromptTemplate promptTemplate;
|
||||||
|
|
||||||
|
private Boolean includeOriginal;
|
||||||
|
|
||||||
|
private Integer numberOfQueries;
|
||||||
|
|
||||||
|
private Builder() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public MyMultiQueryExpander.Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
|
||||||
|
this.chatClientBuilder = chatClientBuilder;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MyMultiQueryExpander.Builder promptTemplate(PromptTemplate promptTemplate) {
|
||||||
|
this.promptTemplate = promptTemplate;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MyMultiQueryExpander.Builder includeOriginal(Boolean includeOriginal) {
|
||||||
|
this.includeOriginal = includeOriginal;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MyMultiQueryExpander.Builder numberOfQueries(Integer numberOfQueries) {
|
||||||
|
this.numberOfQueries = numberOfQueries;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MyMultiQueryExpander build() {
|
||||||
|
return new MyMultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal,
|
||||||
|
this.numberOfQueries);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user