diff --git a/src/main/java/com/huangge1199/ai/config/SafeInputGuardrail.java b/src/main/java/com/huangge1199/ai/config/SafeInputGuardrail.java new file mode 100644 index 0000000..51b0d70 --- /dev/null +++ b/src/main/java/com/huangge1199/ai/config/SafeInputGuardrail.java @@ -0,0 +1,37 @@ +package com.huangge1199.ai.config; + +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailResult; + +import java.util.Set; + +/** + * 安全检测输入护轨 + * + * @author huangge1199 + * @since 2025/7/14 11:19:01 + */ +public class SafeInputGuardrail implements InputGuardrail { + + private static final Set sensitiveWords = Set.of("kill", "evil"); + + /** + * 检测用户输入是否安全 + */ + @Override + public InputGuardrailResult validate(UserMessage userMessage) { + // 获取用户输入并转换为小写以确保大小写不敏感 + String inputText = userMessage.singleText().toLowerCase(); + // 使用正则表达式分割输入文本为单词 + String[] words = inputText.split("\\W+"); + // 遍历所有单词,检查是否存在敏感词 + for (String word : words) { + if (sensitiveWords.contains(word)) { + return fatal("Sensitive word detected: " + word); + } + } + return success(); + } +} + diff --git a/src/main/java/com/huangge1199/ai/service/LangChainService.java b/src/main/java/com/huangge1199/ai/service/LangChainService.java index 0e1728d..b74118d 100644 --- a/src/main/java/com/huangge1199/ai/service/LangChainService.java +++ b/src/main/java/com/huangge1199/ai/service/LangChainService.java @@ -1,7 +1,9 @@ package com.huangge1199.ai.service; +import com.huangge1199.ai.config.SafeInputGuardrail; import dev.langchain4j.service.Result; import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import java.util.List; @@ -11,6 +13,7 @@ import java.util.List; * @author huangge1199 * @since 2025/7/12 10:25:27 */ +@InputGuardrails({SafeInputGuardrail.class}) public interface LangChainService { @SystemMessage(fromResource = "system-prompt.txt")