调用工具:集中注册

This commit is contained in:
huangge1199 2025-05-29 18:09:02 +08:00
parent 4856d11590
commit cb31c10bb5
11 changed files with 98 additions and 5 deletions

View File

@ -26,4 +26,6 @@ public interface ToolsService {
void downloadTool(String url, String name);
void pdfTool(String name, String context);
String doChatWithTools(String question);
}

View File

@ -5,7 +5,9 @@ import com.huangge1199.aiagent.config.MyLoggerAdvisor;
import com.huangge1199.aiagent.tools.*;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@ -26,6 +28,9 @@ public class ToolsServiceImpl implements ToolsService {
@Value("${search-api.api-key}")
private String searchApiKey;
@Resource
private ToolCallback[] allTools;
@Override
public String getWeather(String question) {
return ChatClient.create(ollamaChatModel)
@ -38,7 +43,7 @@ public class ToolsServiceImpl implements ToolsService {
@Override
public String writeFileTest(String context, String name) {
FileTool fileTool = new FileTool();
return fileTool.writeFile(name,context);
return fileTool.writeFile(name, context);
}
@Override
@ -85,4 +90,18 @@ public class ToolsServiceImpl implements ToolsService {
PDFGenerationTool pdfTool = new PDFGenerationTool();
pdfTool.generatePDF(name, context);
}
@Override
public String doChatWithTools(String question) {
ChatResponse chatResponse = ChatClient.create(ollamaChatModel)
.prompt()
.user(question)
// 开启日志便于观察效果
.advisors(new MyLoggerAdvisor())
.tools(allTools)
.call()
.chatResponse();
assert chatResponse != null;
return chatResponse.getResult().getOutput().getText();
}
}

View File

@ -0,0 +1,39 @@
package com.huangge1199.aiagent.config;
import com.huangge1199.aiagent.tools.*;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* ToolRegistration
*
* @author huangge1199
* @since 2025/5/28 16:54:55
*/
@Configuration
public class ToolRegistration {
@Value("${search-api.api-key}")
private String searchApiKey;
@Bean
public ToolCallback[] allTools() {
FileTool fileTool = new FileTool();
WebSearchTool webSearchTool = new WebSearchTool(searchApiKey);
WebScrapTool webScrapingTool = new WebScrapTool();
DownloadTool resourceDownloadTool = new DownloadTool();
TerminalTool terminalOperationTool = new TerminalTool();
PDFGenerationTool pdfGenerationTool = new PDFGenerationTool();
return ToolCallbacks.from(
fileTool,
webSearchTool,
webScrapingTool,
resourceDownloadTool,
terminalOperationTool,
pdfGenerationTool
);
}
}

View File

@ -7,6 +7,7 @@ import com.huangge1199.aiagent.util.CheckUtils;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
@ -108,4 +109,13 @@ public class ToolController {
toolsService.pdfTool(name, context);
return R.ok();
}
@PostMapping("/doChatWithTools")
@Operation(summary = "集中注册")
public R<String> doChatWithTools(String question) {
CheckUtils.checkEmpty(question, "问题");
String result = toolsService.doChatWithTools(question);
return R.ok(result);
}
}

View File

@ -2,6 +2,7 @@ package com.huangge1199.aiagent.tools;
import cn.hutool.core.io.FileUtil;
import cn.hutool.http.HttpUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
@ -13,6 +14,7 @@ import java.io.File;
* @author huangge1199
* @since 2025/5/28 16:11:48
*/
@Slf4j
public class DownloadTool {
@Tool(description = "Download a resource from a given URL")
public String downloadResource(@ToolParam(description = "URL of the resource to download") String url, @ToolParam(description = "Name of the file to save the downloaded resource") String fileName) {
@ -23,6 +25,7 @@ public class DownloadTool {
FileUtil.mkdir(fileDir);
// 使用 Hutool downloadFile 方法下载资源
HttpUtil.downloadFile(url, new File(filePath));
log.info("Resource downloaded successfully to: {}", filePath);
return "Resource downloaded successfully to: " + filePath;
} catch (Exception e) {
return "Error downloading resource: " + e.getMessage();

View File

@ -1,6 +1,7 @@
package com.huangge1199.aiagent.tools;
import cn.hutool.core.io.FileUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
@ -10,6 +11,7 @@ import org.springframework.ai.tool.annotation.ToolParam;
* @author huangge1199
* @since 2025/5/27 16:51:46
*/
@Slf4j
public class FileTool {
private final String FILE_DIR = FileConstant.FILE_SAVE_DIR + "/file";
@ -28,11 +30,13 @@ public class FileTool {
public String writeFile(
@ToolParam(description = "Name of the file to write") String fileName,
@ToolParam(description = "Content to write to the file") String content) {
log.info("Write content to file: " + fileName);
String filePath = FILE_DIR + "/" + fileName;
try {
// 创建目录
FileUtil.mkdir(FILE_DIR);
FileUtil.writeUtf8String(content, filePath);
log.info("File written successfully to: {}", filePath);
return "File written successfully to: " + filePath;
} catch (Exception e) {
return "Error writing to file: " + e.getMessage();

View File

@ -1,12 +1,11 @@
package com.huangge1199.aiagent.tools;
import cn.hutool.core.io.FileUtil;
import com.itextpdf.kernel.font.PdfFont;
import com.itextpdf.kernel.font.PdfFontFactory;
import com.itextpdf.kernel.pdf.PdfDocument;
import com.itextpdf.kernel.pdf.PdfWriter;
import com.itextpdf.layout.Document;
import com.itextpdf.layout.element.Paragraph;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
@ -18,12 +17,14 @@ import java.io.IOException;
* @author huangge1199
* @since 2025/5/28 16:23:46
*/
@Slf4j
public class PDFGenerationTool {
@Tool(description = "Generate a PDF file with given content")
public String generatePDF(
@ToolParam(description = "Name of the file to save the generated PDF") String fileName,
@ToolParam(description = "Content to be included in the PDF") String content) {
log.info("Generate PDF file with given content");
String fileDir = FileConstant.FILE_SAVE_DIR + "/pdf";
String filePath = fileDir + "/" + fileName;
try {
@ -46,6 +47,7 @@ public class PDFGenerationTool {
// 添加段落并关闭文档
document.add(paragraph);
}
log.info("PDF generated successfully to: {}", filePath);
return "PDF generated successfully to: " + filePath;
} catch (IOException e) {
return "Error generating PDF: " + e.getMessage();

View File

@ -1,5 +1,6 @@
package com.huangge1199.aiagent.tools;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
@ -13,10 +14,12 @@ import java.io.InputStreamReader;
* @author huangge1199
* @since 2025/5/28 15:34:30
*/
@Slf4j
public class TerminalTool {
@Tool(description = "Execute a command in the terminal")
public String executeTerminalCommand(@ToolParam(description = "Command to execute in the terminal") String command) {
log.info("Terminal" + command);
StringBuilder output = new StringBuilder();
try {
Process process = Runtime.getRuntime().exec("cmd.exe /c " + command);
@ -33,6 +36,7 @@ public class TerminalTool {
} catch (IOException | InterruptedException e) {
output.append("Error executing command: ").append(e.getMessage());
}
log.info(output.toString());
return output.toString();
}

View File

@ -1,5 +1,6 @@
package com.huangge1199.aiagent.tools;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
@ -9,10 +10,12 @@ import org.springframework.ai.tool.annotation.ToolParam;
* @author huangge1199
* @since 2025/5/27 15:01:04
*/
@Slf4j
public class WeatherTool {
@Tool(description = "Get current weather for a location")
public String getWeather(@ToolParam(description = "The city name") String city) {
return "Current weather in " + city +": Sunny, 25°";
log.info("Current weather in {}: Sunny, 25°", city);
return "Current weather in " + city + ": Sunny, 25°";
}
}

View File

@ -1,5 +1,6 @@
package com.huangge1199.aiagent.tools;
import lombok.extern.slf4j.Slf4j;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.springframework.ai.tool.annotation.Tool;
@ -13,12 +14,15 @@ import java.io.IOException;
* @author huangge1199
* @since 2025/5/28 15:10:11
*/
@Slf4j
public class WebScrapTool {
@Tool(description = "Scrape the content of a web page")
public String scrapeWebPage(@ToolParam(description = "URL of the web page to scrape") String url) {
try {
log.info("Scraping web page {}", url);
Document doc = Jsoup.connect(url).get();
log.info(doc.toString());
return doc.html();
} catch (IOException e) {
return "Error scraping web page: " + e.getMessage();

View File

@ -4,6 +4,7 @@ import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
@ -18,6 +19,7 @@ import java.util.stream.Collectors;
* @author huangge1199
* @since 2025/5/28 11:00:28
*/
@Slf4j
public class WebSearchTool {
// SearchAPI 的搜索接口地址
@ -32,6 +34,7 @@ public class WebSearchTool {
@Tool(description = "Search for information from Baidu Search Engine")
public String searchWeb(
@ToolParam(description = "Search query keyword") String query) {
log.info("Search query keyword: {}", query);
Map<String, Object> paramMap = new HashMap<>();
paramMap.put("q", query);
paramMap.put("api_key", apiKey);
@ -46,7 +49,7 @@ public class WebSearchTool {
// 拼接搜索结果为字符串
return objects.stream().map(obj -> {
JSONObject tmpJsonObject = (JSONObject) obj;
return tmpJsonObject.get("title").toString();
return tmpJsonObject.toString();
}).collect(Collectors.joining(","));
} catch (Exception e) {
return "Error searching Baidu: " + e.getMessage();