From 77ead4859c61cb216f9d26b651390ac96c7baa92 Mon Sep 17 00:00:00 2001 From: xiaoxin <718949661@qq.com> Date: Wed, 3 Jul 2024 10:52:41 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=A2=9E=E5=8A=A0=E3=80=91AI=20?= =?UTF-8?q?=E5=86=99=E4=BD=9C=EF=BC=9A=E6=94=AF=E6=8C=81=E6=92=B0=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../module/ai/enums/ErrorCodeConstants.java | 5 + .../enums/music/AiMusicGenerateModeEnum.java | 2 +- .../ai/enums/write/AiWriteTypeEnum.java | 37 +++++++ .../admin/write/AiWriteController.java | 4 +- .../admin/write/vo/AiWriteGenerateReqVO.java | 8 +- .../ai/dal/dataobject/write/AiWriteDO.java | 97 +++++++++++++++++++ .../ai/dal/mysql/write/AiWriteMapper.java | 14 +++ .../ai/service/write/AiWriteService.java | 9 +- .../ai/service/write/AiWriteServiceImpl.java | 58 +++++++---- 9 files changed, 207 insertions(+), 27 deletions(-) create mode 100644 yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java create mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java create mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/write/AiWriteMapper.java diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java index 5a3e290a3..fd42fe155 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java @@ -42,4 +42,9 @@ public interface ErrorCodeConstants { // ========== API 音乐 1-040-006-000 ========== ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!"); + + // ========== API 写作 1-022-007-000 ========== + ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!"); + ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "Stream 对话异常!"); + } diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/music/AiMusicGenerateModeEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/music/AiMusicGenerateModeEnum.java index ad4b81b36..651731b60 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/music/AiMusicGenerateModeEnum.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/music/AiMusicGenerateModeEnum.java @@ -7,7 +7,7 @@ import lombok.Getter; import java.util.Arrays; /** - * AI 音乐状态的枚举 + * AI 音乐生成模式的枚举 * * @author xiaoxin */ diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java new file mode 100644 index 000000000..05db29dda --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java @@ -0,0 +1,37 @@ +package cn.iocoder.yudao.module.ai.enums.write; + +import cn.iocoder.yudao.framework.common.core.IntArrayValuable; +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.Arrays; + +/** + * AI 写作类型的枚举 + * + * @author xiaoxin + */ +@AllArgsConstructor +@Getter +public enum AiWriteTypeEnum implements IntArrayValuable { + + DESCRIPTION(1, "撰写"), + LYRIC(2, "回复"); + + /** + * 类型 + */ + private final Integer type; + /** + * 类型名 + */ + private final String name; + + public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray(); + + @Override + public int[] array() { + return ARRAYS; + } + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java index d229ba267..6298d44e3 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java @@ -15,6 +15,8 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import reactor.core.publisher.Flux; +import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; + @Tag(name = "管理后台 - AI 写作") @RestController @RequestMapping("/ai/write") @@ -27,6 +29,6 @@ public class AiWriteController { @PermitAll @Operation(summary = "写作生成(流式)", description = "流式返回,响应较快") public Flux> generateComposition(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) { - return writeService.generateComposition(generateReqVO); + return writeService.generateWriteContent(generateReqVO, getLoginUserId()); } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java index 396c0b8ca..88bcf0568 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java @@ -8,14 +8,14 @@ import lombok.Data; @Data public class AiWriteGenerateReqVO { - @Schema(description = "写作内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马") - private String content; + @Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马") + private String contentPrompt; @Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职") private String originalContent; @Schema(description = "回复内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "准了") - private String replyContent; + private String replyContentPrompt; @Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "中等") @NotBlank(message = "长度不能为空") @@ -35,5 +35,5 @@ public class AiWriteGenerateReqVO { @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") - private Integer writeType; + private Integer writeType; //参见 AiWriteTypeEnum 枚举 } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java new file mode 100644 index 000000000..a569f1096 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java @@ -0,0 +1,97 @@ +package cn.iocoder.yudao.module.ai.dal.dataobject.write; + +import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; +import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; + +/** + * AI 写作 DO + * + * @author xiaoxin + */ +@TableName(value = "ai_write", autoResultMap = true) +@Data +public class AiWriteDO extends BaseDO { + + /** + * 编号 + */ + @TableId(type = IdType.AUTO) + private Long id; + + /** + * 用户编号 + */ + private Long userId; + + /** + * 写作类型 + *

+ * 枚举 {@link AiWriteTypeEnum} + */ + private Integer writeType; + + /** + * 撰写内容提示 + */ + private String contentPrompt; + + /** + * 生成的撰写内容 + */ + private String generatedContent; + + /** + * 原文 + */ + private String originalContent; + + /** + * 回复内容提示 + */ + private String replyContentPrompt; + + /** + * 生成的回复内容 + */ + private String generatedReplyContent; + + /** + * 长度提示词 + */ + private String length; + + /** + * 格式提示词 + */ + private String format; + + /** + * 语气提示词 + */ + private String tone; + + /** + * 语言提示词 + */ + private String language; + + /** + * 模型 + */ + private String model; + + /** + * 平台 + */ + private String platform; + + /** + * 错误信息 + */ + private String errorMessage; + +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/write/AiWriteMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/write/AiWriteMapper.java new file mode 100644 index 000000000..9564466eb --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/write/AiWriteMapper.java @@ -0,0 +1,14 @@ +package cn.iocoder.yudao.module.ai.dal.mysql.write; + +import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; +import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; +import org.apache.ibatis.annotations.Mapper; + +/** + * AI 音乐 Mapper + * + * @author xiaoxin + */ +@Mapper +public interface AiWriteMapper extends BaseMapperX { +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java index f8fb0634c..0dc349cba 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java @@ -12,7 +12,14 @@ import reactor.core.publisher.Flux; public interface AiWriteService { - Flux> generateComposition(AiWriteGenerateReqVO generateReqVO); + /** + * 生成写作内容 + * + * @param generateReqVO 作文生成请求参数 + * @param userId 用户编号 + * @return 生成结果 + */ + Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index 63d2d6192..3cc185884 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -2,22 +2,27 @@ package cn.iocoder.yudao.module.ai.service.write; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; -import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions; import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; +import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; +import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; +import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.qianfan.QianFanChatOptions; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; @@ -35,16 +40,29 @@ public class AiWriteServiceImpl implements AiWriteService { @Resource private AiApiKeyService apiKeyService; + @Resource + private AiChatModelService chatModalService; + @Resource + private AiWriteMapper writeMapper; @Override - public Flux> generateComposition(AiWriteGenerateReqVO generateReqVO) { - StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(6L); - AiPlatformEnum platform = AiPlatformEnum.validatePlatform("QianWen"); - ChatOptions chatOptions = buildChatOptions(platform, "qwen-72b-chat", 1.0, 1000); + public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { + //TODO 芋艿 写作的模型配置放哪好 先用千问测试 + // 1.1 校验模型 + AiChatModelDO model = chatModalService.validateChatModel(14L); + StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); + + //1.2 插入写作信息 + AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class); + writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); + + //2.1 构建提示词 Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); Flux streamResponse = chatClient.stream(prompt); - // 3.3 流式返回 + // 2.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; @@ -53,17 +71,17 @@ public class AiWriteServiceImpl implements AiWriteService { // 响应结果 return success(newContent); }).doOnComplete(() -> { - log.info("generateComposition complete, content: {}", contentBuffer); - }).onErrorResume(error -> { - log.error("[AI 写作] 发生异常", error); - return Flux.just(error(ErrorCodeConstants.AI_CHAT_STREAM_ERROR)); - }); + writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setGeneratedContent(contentBuffer.toString())); + }).doOnError(throwable -> { + log.error("[AI Write][generateReqVO({}) 发生异常]", generateReqVO, throwable); + writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setErrorMessage(throwable.getMessage())); + }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); } private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { - String template = "请直接写一篇关于 [{}] 的文章,格式为:{},语气为:{},语言为:{},长度为:{}。请确保涵盖主要内容,不需要任何额外的解释或道歉。"; - String content = generateReqVO.getContent(); + String template = "请直接写一篇关于 [{}] 的文章,格式为:{},语气为:{},语言为:{},长度为:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; + String content = generateReqVO.getContentPrompt(); String format = generateReqVO.getFormat(); String tone = generateReqVO.getTone(); String language = generateReqVO.getLanguage(); @@ -81,14 +99,14 @@ public class AiWriteServiceImpl implements AiWriteService { case OLLAMA: return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); case YI_YAN: - // TODO @fan:增加一个 model - return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens); + // TODO 芋艿:貌似 model 只要一设置,就报错 +// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build(); case XING_HUO: return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) .setMaxTokens(maxTokens); case QIAN_WEN: - // TODO @fan:增加 model、temperature 参数 - return new QianWenOptions().setMaxTokens(maxTokens); + return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build(); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); }