diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java new file mode 100644 index 000000000..e090aaf2d --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java @@ -0,0 +1,29 @@ +package cn.iocoder.yudao.module.ai.controller.admin.write; + +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.module.ai.service.write.AiWriteService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import jakarta.annotation.security.PermitAll; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import reactor.core.publisher.Flux; + +@Tag(name = "管理后台 - AI 写作") +@RestController +@RequestMapping("/ai/write") +public class AiWriteController { + + @Resource + private AiWriteService writeService; + + @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + @PermitAll + @Operation(summary = "作文生产(流式)", description = "流式返回,响应较快") + public Flux> generateComposition() { + return writeService.generateComposition(); + } +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java new file mode 100644 index 000000000..94a92dbb8 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java @@ -0,0 +1,17 @@ +package cn.iocoder.yudao.module.ai.service.write; + +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import reactor.core.publisher.Flux; + +/** + * AI 写作 Service 接口 + * + * @author xiaoxin + */ +public interface AiWriteService { + + + Flux> generateComposition(); + + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java new file mode 100644 index 000000000..241c177c2 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -0,0 +1,84 @@ +package cn.iocoder.yudao.module.ai.service.write; + +import cn.hutool.core.util.StrUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions; +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; +import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; + +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; + +/** + * AI 写作 Service 实现类 + * + * @author xiaoxin + */ +@Service +@Slf4j +public class AiWriteServiceImpl implements AiWriteService { + + @Resource + private AiApiKeyService apiKeyService; + + + @Override + public Flux> generateComposition() { + StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(6L); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform("QianWen"); + ChatOptions chatOptions = buildChatOptions(platform, "qwen-72b-chat", 1.0, 1000); + Prompt prompt = new Prompt("请直接写一篇关于 气候变化 的文章,格式为自动,语气为自动,语言为自动,长度为自动。请确保涵盖主要观点,不需要标题,不需要任何额外的解释或道歉。", chatOptions); + Flux streamResponse = chatClient.stream(prompt); + // 3.3 流式返回 + StringBuffer contentBuffer = new StringBuffer(); + return streamResponse.map(chunk -> { + String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; + newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况 + contentBuffer.append(newContent); + // 响应结果 + return success(newContent); + }).doOnComplete(() -> { + log.info("generateComposition complete, content: {}", contentBuffer); + }).onErrorResume(error -> { + return Flux.just(error(ErrorCodeConstants.AI_CHAT_STREAM_ERROR)); + }); + } + + + private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { + Float temperatureF = temperature != null ? temperature.floatValue() : null; + //noinspection EnhancedSwitchMigration + switch (platform) { + case OPENAI: + return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + case OLLAMA: + return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); + case YI_YAN: + // TODO @fan:增加一个 model + return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens); + case XING_HUO: + return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) + .setMaxTokens(maxTokens); + case QIAN_WEN: + // TODO @fan:增加 model、temperature 参数 + return new QianWenOptions().setMaxTokens(maxTokens); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + } + +}