From 471968eaf23dcfce6b3f7ba36ddccedbff5d639d Mon Sep 17 00:00:00 2001 From: YunaiV Date: Thu, 4 Jul 2024 23:53:22 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E3=80=91AI=EF=BC=9A=E6=96=B0=E5=A2=9E=20AIUtils=EF=BC=8C?= =?UTF-8?q?=E7=94=A8=E4=BA=8E=E5=AF=B9=E6=8E=A5=20spring=20ai=20=E5=90=84?= =?UTF-8?q?=E7=A7=8D=E5=AF=B9=E8=B1=A1=E7=9A=84=E6=9E=84=E5=BB=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../admin/write/AiWriteController.java | 3 +- .../chat/AiChatMessageServiceImpl.java | 40 +------------ .../ai/service/write/AiWriteServiceImpl.java | 38 +++--------- .../yudao/framework/ai/core/util/AiUtils.java | 59 +++++++++++++++++++ 4 files changed, 71 insertions(+), 69 deletions(-) create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java index a032998ed..c9023552e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java @@ -26,9 +26,10 @@ public class AiWriteController { private AiWriteService writeService; @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - @PermitAll @Operation(summary = "写作生成(流式)", description = "流式返回,响应较快") + @PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 public Flux> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) { return writeService.generateWriteContent(generateReqVO, getLoginUserId()); } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 48b1e482d..3a9b34fc8 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -4,8 +4,7 @@ import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; @@ -19,7 +18,6 @@ import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; -import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.*; @@ -28,9 +26,6 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.ollama.api.OllamaOptions; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.qianfan.QianFanChatOptions; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; @@ -148,46 +143,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } // 1.2 history message 历史消息 List contextMessages = filterContextMessages(messages, conversation, sendReqVO); - contextMessages.forEach(message -> { - // TODO @芋艿:看看有没优化空间 - if (MessageType.USER.getValue().equals(message.getType())) { - chatMessages.add(new UserMessage(message.getContent())); - } else { - chatMessages.add(new AssistantMessage(message.getContent())); - } - }); + contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); // 1.3 user message 新发送消息 chatMessages.add(new UserMessage(sendReqVO.getContent())); // 2. 构建 ChatOptions 对象 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); - ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), + ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), conversation.getTemperature(), conversation.getMaxTokens()); return new Prompt(chatMessages, chatOptions); } - private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { - Float temperatureF = temperature != null ? temperature.floatValue() : null; - //noinspection EnhancedSwitchMigration - switch (platform) { - case OPENAI: - return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case OLLAMA: - return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); - case YI_YAN: - // TODO 芋艿:貌似 model 只要一设置,就报错 -// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case XING_HUO: - return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) - .setMaxTokens(maxTokens); - case QIAN_WEN: - return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); - default: - throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); - } - } - /** * 从历史消息中,获得倒序的 n 组消息作为消息上下文 * diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index 25fff0332..9e5c0a7ff 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -2,8 +2,7 @@ package cn.iocoder.yudao.module.ai.service.write; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; @@ -16,16 +15,12 @@ import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.system.api.dict.DictDataApi; -import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.ollama.api.OllamaOptions; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.qianfan.QianFanChatOptions; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; @@ -56,19 +51,21 @@ public class AiWriteServiceImpl implements AiWriteService { @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { - // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok? + // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的; AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); - ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); // 1.2 插入写作信息 + // TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性 AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class); writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); // 2.1 构建提示词 + ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); Flux streamResponse = chatClient.stream(prompt); + // 2.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { @@ -92,7 +89,9 @@ public class AiWriteServiceImpl implements AiWriteService { String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat()); String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat()); + // TODO @xin:建议改成 if return 哈;更简洁; if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) { + // TODO @xin:写成静态枚举哈 template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length); } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) { @@ -103,27 +102,4 @@ public class AiWriteServiceImpl implements AiWriteService { } } - // TODO 芋艿:复用 - private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { - Float temperatureF = temperature != null ? temperature.floatValue() : null; - //noinspection EnhancedSwitchMigration - switch (platform) { - case OPENAI: - return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case OLLAMA: - return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); - case YI_YAN: - // TODO 芋艿:貌似 model 只要一设置,就报错 -// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case XING_HUO: - return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) - .setMaxTokens(maxTokens); - case QIAN_WEN: - return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); - default: - throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); - } - } - } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java new file mode 100644 index 000000000..306b216a3 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java @@ -0,0 +1,59 @@ +package cn.iocoder.yudao.framework.ai.core.util; + +import cn.hutool.core.util.StrUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.qianfan.QianFanChatOptions; + +/** + * Spring AI 工具类 + * + * @author 芋道源码 + */ +public class AiUtils { + + public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { + Float temperatureF = temperature != null ? temperature.floatValue() : null; + //noinspection EnhancedSwitchMigration + switch (platform) { + case OPENAI: + return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + case OLLAMA: + return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); + case YI_YAN: + // TODO @xin:貌似 model 只要一设置,就报错;可以排查下 +// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + case XING_HUO: + return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) + .setMaxTokens(maxTokens); + case QIAN_WEN: + return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + } + + public static Message buildMessage(String type, String content) { + if (MessageType.USER.getValue().equals(type)) { + return new UserMessage(content); + } + if (MessageType.ASSISTANT.getValue().equals(type)) { + return new AssistantMessage(content); + } + if (MessageType.SYSTEM.getValue().equals(type)) { + return new SystemMessage(content); + } + if (MessageType.FUNCTION.getValue().equals(type)) { + return new FunctionMessage(content); + } + throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type)); + } + +} \ No newline at end of file