diff --git a/src/main/java/com/huangge1199/aiagent/Service/RagService.java b/src/main/java/com/huangge1199/aiagent/Service/RagService.java index 25b87f3..f456ff4 100644 --- a/src/main/java/com/huangge1199/aiagent/Service/RagService.java +++ b/src/main/java/com/huangge1199/aiagent/Service/RagService.java @@ -1,5 +1,9 @@ package com.huangge1199.aiagent.Service; +import org.springframework.ai.rag.Query; + +import java.util.List; + /** * RagService * @@ -8,4 +12,6 @@ package com.huangge1199.aiagent.Service; */ public interface RagService { String localDoc(String question); + + List getMultiQueryExpand(String question); } diff --git a/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java b/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java index fd6196e..8174879 100644 --- a/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java +++ b/src/main/java/com/huangge1199/aiagent/Service/impl/RagServiceImpl.java @@ -2,12 +2,18 @@ package com.huangge1199.aiagent.Service.impl; import com.huangge1199.aiagent.Service.RagService; import com.huangge1199.aiagent.config.MyLoggerAdvisor; +import com.huangge1199.aiagent.rag.MyMultiQueryExpander; import jakarta.annotation.Resource; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.stereotype.Service; +import java.util.List; + /** * RagServiceImpl * @@ -23,6 +29,9 @@ public class RagServiceImpl implements RagService { @Resource private VectorStore vectorStore; + @Resource + private ChatModel ollamaChatModel; + @Override public String localDoc(String question) { return chatClient.prompt() @@ -32,4 +41,29 @@ public class RagServiceImpl implements RagService { .call() .content(); } + + @Override + public List getMultiQueryExpand(String question) { + ChatClient.Builder builder = ChatClient.builder(ollamaChatModel); + + ChatClient chatClient = builder + .defaultSystem("你是一位专业的室内设计顾问,精通各种装修风格、材料选择和空间布局。请基于提供的参考资料,为用户提供专业、详细且实用的建议。在回答时,请注意:\n" + + "1. 准确理解用户的具体需求\n" + + "2. 结合参考资料中的实际案例\n" + + "3. 提供专业的设计理念和原理解释\n" + + "4. 考虑实用性、美观性和成本效益\n" + + "5. 如有需要,可以提供替代方案") + .build(); + +// MultiQueryExpander queryExpander = MultiQueryExpander.builder() + MyMultiQueryExpander queryExpander = MyMultiQueryExpander.builder() + .chatClientBuilder(builder) + // 不包含原始查询 + .includeOriginal(false) + // 生成3个查询变体 + .numberOfQueries(4) + .build(); + List queries = queryExpander.expand(new Query(question)); + return queries; + } } diff --git a/src/main/java/com/huangge1199/aiagent/controller/RagController.java b/src/main/java/com/huangge1199/aiagent/controller/RagController.java index 26b0901..f148c9f 100644 --- a/src/main/java/com/huangge1199/aiagent/controller/RagController.java +++ b/src/main/java/com/huangge1199/aiagent/controller/RagController.java @@ -5,12 +5,13 @@ import com.huangge1199.aiagent.common.R; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.annotation.Resource; +import org.springframework.ai.rag.Query; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; -import java.util.UUID; +import java.util.List; /** * RagController @@ -32,4 +33,11 @@ public class RagController { String res = ragService.localDoc(question); return R.ok(res); } + + @PostMapping("/getMultiQueryExpand") + @Operation(summary = "多查询扩展") + public R> getMultiQueryExpand(@RequestBody String question) { + List queryList = ragService.getMultiQueryExpand(question); + return R.ok(queryList); + } } diff --git a/src/main/java/com/huangge1199/aiagent/rag/MyMultiQueryExpander.java b/src/main/java/com/huangge1199/aiagent/rag/MyMultiQueryExpander.java new file mode 100644 index 0000000..26be770 --- /dev/null +++ b/src/main/java/com/huangge1199/aiagent/rag/MyMultiQueryExpander.java @@ -0,0 +1,155 @@ +package com.huangge1199.aiagent.rag; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander; +import org.springframework.ai.util.PromptAssert; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** + * MyMultiQueryExpander + * + * @author huangge1199 + * @since 2025/5/24 14:39:38 + */ +public class MyMultiQueryExpander implements QueryExpander { + + private static final Logger logger = LoggerFactory.getLogger(MyMultiQueryExpander.class); + + private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" + You are an expert at information retrieval and search optimization. + Your task is to generate {number} different versions of the given query. + + Each variant must cover different perspectives or aspects of the topic, + while maintaining the core intent of the original query. The goal is to + expand the search space and improve the chances of finding relevant information. + + Do not explain your choices or add any other text. + Provide the query variants separated by newlines. + + Original query: {query} + + Query variants: + """); + + private static final Boolean DEFAULT_INCLUDE_ORIGINAL = true; + + private static final Integer DEFAULT_NUMBER_OF_QUERIES = 3; + + private final ChatClient chatClient; + + private final PromptTemplate promptTemplate; + + private final boolean includeOriginal; + + private final int numberOfQueries; + + public MyMultiQueryExpander(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, + @Nullable Boolean includeOriginal, @Nullable Integer numberOfQueries) { + Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); + + this.chatClient = chatClientBuilder.build(); + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.includeOriginal = includeOriginal != null ? includeOriginal : DEFAULT_INCLUDE_ORIGINAL; + this.numberOfQueries = numberOfQueries != null ? numberOfQueries : DEFAULT_NUMBER_OF_QUERIES; + + PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query"); + } + + @Override + public List expand(Query query) { + Assert.notNull(query, "query cannot be null"); + + logger.debug("Generating {} query variants", this.numberOfQueries); + + var response = this.chatClient.prompt() + .user(user -> user.text(this.promptTemplate.getTemplate()) + .param("number", this.numberOfQueries) + .param("query", query.text())) + .call() + .content(); + + if (response == null) { + logger.warn("Query expansion result is null. Returning the input query unchanged."); + return List.of(query); + } + + var queryVariants = Arrays.stream(response.split("\n")) + .filter(s -> !s.trim().isEmpty()) + .toList(); + ; + + if (CollectionUtils.isEmpty(queryVariants)) { + logger.warn( + "Query expansion result does not contain the requested {} variants. Returning the input query unchanged.", + this.numberOfQueries); + return List.of(query); + } + + var queries = queryVariants.stream() + .filter(StringUtils::hasText) + .map(queryText -> query.mutate().text(queryText).build()) + .collect(Collectors.toList()); + + if (this.includeOriginal) { + logger.debug("Including the original query in the result"); + queries.add(0, query); + } + + return queries; + } + + public static MyMultiQueryExpander.Builder builder() { + return new MyMultiQueryExpander.Builder(); + } + + public static final class Builder { + + private ChatClient.Builder chatClientBuilder; + + private PromptTemplate promptTemplate; + + private Boolean includeOriginal; + + private Integer numberOfQueries; + + private Builder() { + } + + public MyMultiQueryExpander.Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) { + this.chatClientBuilder = chatClientBuilder; + return this; + } + + public MyMultiQueryExpander.Builder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public MyMultiQueryExpander.Builder includeOriginal(Boolean includeOriginal) { + this.includeOriginal = includeOriginal; + return this; + } + + public MyMultiQueryExpander.Builder numberOfQueries(Integer numberOfQueries) { + this.numberOfQueries = numberOfQueries; + return this; + } + + public MyMultiQueryExpander build() { + return new MyMultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal, + this.numberOfQueries); + } + + } +}