diff --git a/src/main/java/com/huangge1199/aiagent/Service/AiService.java b/src/main/java/com/huangge1199/aiagent/Service/AiService.java new file mode 100644 index 0000000..1253fd5 --- /dev/null +++ b/src/main/java/com/huangge1199/aiagent/Service/AiService.java @@ -0,0 +1,30 @@ +package com.huangge1199.aiagent.Service; + +import reactor.core.publisher.Flux; + +/** + * AiService + * + * @author huangge1199 + * @since 2025/6/10 14:53:53 + */ +public interface AiService { + + /** + * AI 基础对话(支持多轮对话记忆) + * + * @param message 传入信息 + * @param chatId 会话ID + * @return 返回信息 + */ + String doChat(String message, String chatId); + + /** + * AI 基础对话(支持多轮对话记忆,SSE 流式传输) + * + * @param message 传入信息 + * @param chatId 会话ID + * @return 返回信息 + */ + Flux doChatByStream(String message, String chatId); +} diff --git a/src/main/java/com/huangge1199/aiagent/Service/impl/AiServiceImpl.java b/src/main/java/com/huangge1199/aiagent/Service/impl/AiServiceImpl.java new file mode 100644 index 0000000..520d440 --- /dev/null +++ b/src/main/java/com/huangge1199/aiagent/Service/impl/AiServiceImpl.java @@ -0,0 +1,52 @@ +package com.huangge1199.aiagent.Service.impl; + +import com.huangge1199.aiagent.Service.AiService; +import com.huangge1199.aiagent.config.MyLoggerAdvisor; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; + +import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY; +import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY; + +/** + * AiServiceImpl + * + * @author huangge1199 + * @since 2025/6/10 14:54:05 + */ +@Service +@Slf4j +public class AiServiceImpl implements AiService { + + @Resource + private ChatClient chatClient; + + + @Override + public String doChat(String message, String chatId) { + ChatResponse chatResponse = chatClient + .prompt() + .user(message) + .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId) + .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) + .advisors(new MyLoggerAdvisor()) + .call() + .chatResponse(); + return chatResponse.getResult().getOutput().getText(); + } + + @Override + public Flux doChatByStream(String message, String chatId) { + return chatClient + .prompt() + .user(message) + .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId) + .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) + .stream() + .content(); + } +} diff --git a/src/main/java/com/huangge1199/aiagent/controller/AiController.java b/src/main/java/com/huangge1199/aiagent/controller/AiController.java new file mode 100644 index 0000000..99a5894 --- /dev/null +++ b/src/main/java/com/huangge1199/aiagent/controller/AiController.java @@ -0,0 +1,112 @@ +package com.huangge1199.aiagent.controller; + +import com.alibaba.dashscope.aigc.generation.GenerationOutput; +import com.alibaba.dashscope.common.Message; +import com.alibaba.dashscope.common.Role; +import com.alibaba.dashscope.exception.InputRequiredException; +import com.alibaba.dashscope.exception.NoApiKeyException; +import com.huangge1199.aiagent.Service.AiService; +import com.huangge1199.aiagent.common.R; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import reactor.core.publisher.Flux; +import com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.aigc.generation.GenerationParam; + +import java.io.IOException; +import java.util.Arrays; + +/** + * AiController + * + * @author huangge1199 + * @since 2025/6/10 14:35:39 + */ +@RestController +@RequestMapping("/ai") +@Tag(name = "AI") +public class AiController { + + @Resource + private AiService aiService; + + @Value("${bailian.API-KEY}") + private String apiKey; + + @Operation(summary = "同步调用") + @GetMapping("/sync") + public String doChatWithSync(String message, String chatId) { + return aiService.doChat(message, chatId); + } + + @Operation(summary = "SSE 流式调用") + @GetMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux doChatWithSse(String message, String chatId) { + return aiService.doChatByStream(message, chatId); + } + +// @GetMapping(value = "/sse") +// public Flux> doChatWithSse(String message, String chatId) { +// return aiService.doChatByStream(message, chatId) +// .map(chunk -> ServerSentEvent.builder() +// .data(chunk) +// .build()); +// } + + @Operation(summary = "SSE Emitter 流式调用") + @GetMapping("/emitter") + public SseEmitter doChatWithSseEmitter(String message, String chatId) { + // 创建一个超时时间较长的 SseEmitter + // 3分钟超时 + SseEmitter emitter = new SseEmitter(180000L); + // 获取 Flux 数据流并直接订阅 + aiService.doChatByStream(message, chatId) + .subscribe( + // 处理每条消息 + chunk -> { + try { + emitter.send(chunk); + } catch (IOException e) { + emitter.completeWithError(e); + } + }, + // 处理错误 + emitter::completeWithError, + // 处理完成 + emitter::complete + ); + // 返回emitter + return emitter; + } + + @Operation(summary = "云百炼测试") + @GetMapping("/yun") + public R yunTest(String message, String model) throws NoApiKeyException, InputRequiredException { + Generation gen = new Generation(); + Message systemMsg = Message.builder() + .role(Role.SYSTEM.getValue()) + .content("You are a helpful assistant.") + .build(); + Message userMsg = Message.builder() + .role(Role.USER.getValue()) + .content(message) + .build(); + GenerationParam param = GenerationParam.builder() + // 若没有配置环境变量,请用百炼API Key将下行替换为:.apiKey("sk-xxx") + .apiKey(apiKey) + .model(model) + .messages(Arrays.asList(systemMsg, userMsg)) + .resultFormat(GenerationParam.ResultFormat.MESSAGE) + .build(); + GenerationOutput output = gen.call(param).getOutput(); + String text = output.getChoices().get(0).getMessage().getContent(); + return R.ok(text); + } +} diff --git a/src/main/java/com/huangge1199/aiagent/rag/RagConfig.java b/src/main/java/com/huangge1199/aiagent/rag/RagConfig.java index b0c079e..affc723 100644 --- a/src/main/java/com/huangge1199/aiagent/rag/RagConfig.java +++ b/src/main/java/com/huangge1199/aiagent/rag/RagConfig.java @@ -15,7 +15,7 @@ public class RagConfig { @Bean ChatClient chatClient(ChatClient.Builder builder) { - return builder.defaultSystem("你将作为一名恋爱大师,对于用户的问题作出解答") + return builder.defaultSystem("你将作为一名旅游规划大师,对于用户的问题作出解答") .build(); } }