From 6e71b721e85a5e31f0522733a9f32621a83e3be1 Mon Sep 17 00:00:00 2001 From: xiaoxin <718949661@qq.com> Date: Tue, 9 Jul 2024 22:05:42 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91AI=20?= =?UTF-8?q?=E5=86=99=E4=BD=9C=EF=BC=9A1.=20=E4=BC=98=E5=85=88=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E5=86=99=E4=BD=9C=E8=A7=92=E8=89=B2=EF=BC=9B2.=20?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=86=99=E4=BD=9C=E6=8F=90=E7=A4=BA=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/enums/write/AiWriteTypeEnum.java | 16 ++++- .../ai/service/write/AiWriteServiceImpl.java | 62 +++++++++++++------ 2 files changed, 56 insertions(+), 22 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java index 3a62e1626..69989e5b6 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java @@ -1,5 +1,7 @@ package cn.iocoder.yudao.module.ai.enums.write; +import cn.hutool.core.util.ArrayUtil; +import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.common.core.IntArrayValuable; import lombok.AllArgsConstructor; import lombok.Getter; @@ -15,8 +17,8 @@ import java.util.Arrays; @Getter public enum AiWriteTypeEnum implements IntArrayValuable { - WRITING(1, "撰写"), - REPLY(2, "回复"); + WRITING(1, "撰写", "请撰写一篇关于 [{}] 的文章。文章的内容格式:{},语气:{},语言:{},长度:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"), + REPLY(2, "回复", "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复格式:{},语气:{},语言:{},长度:{}。不需要除了正文内容外的其他回复,如标题、开头、额外的解释或道歉。"); /** * 类型 @@ -27,6 +29,11 @@ public enum AiWriteTypeEnum implements IntArrayValuable { */ private final String name; + /** + * 模版 + */ + private final String template; + public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray(); @Override @@ -34,4 +41,9 @@ public enum AiWriteTypeEnum implements IntArrayValuable { return ARRAYS; } + public static void validateType(Integer type) { + if (ArrayUtil.contains(ARRAYS, type)) return; + throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type)); + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index d43c11d3a..e958f06f8 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -1,13 +1,17 @@ package cn.iocoder.yudao.module.ai.service.write; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; +import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; @@ -15,6 +19,7 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; +import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.system.api.dict.DictDataApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; @@ -25,6 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; +import java.util.List; import java.util.Objects; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; @@ -43,6 +49,8 @@ public class AiWriteServiceImpl implements AiWriteService { private AiApiKeyService apiKeyService; @Resource private AiChatModelService chatModalService; + @Resource + private AiChatRoleService chatRoleService; @Resource private DictDataApi dictDataApi; @@ -52,15 +60,22 @@ public class AiWriteServiceImpl implements AiWriteService { @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { - // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的; - AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); - StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); + // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型 + AiChatRoleDO writeRole = selectOneWriteRole(); + AiChatModelDO model; + if (Objects.nonNull(writeRole)) { + model = chatModalService.getChatModel(writeRole.getModelId()); + } else { + model = chatModalService.getRequiredDefaultChatModel(); + } + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); + // 1.2 插入写作信息 - // TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性 - AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class); - writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); + AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); + writeMapper.insert(writeDO); // 2.1 构建提示词 ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); @@ -87,23 +102,30 @@ public class AiWriteServiceImpl implements AiWriteService { }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); } + private AiChatRoleDO selectOneWriteRole() { + AiChatRoleDO chatRoleDO = null; + PageResult writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手")); + List list = writeRolePage.getList(); + if (CollUtil.isNotEmpty(list)) { + chatRoleDO = list.get(0); + } + return chatRoleDO; + } + private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { - String template; - Integer writeType = generateReqVO.getType(); + Integer type = generateReqVO.getType(); String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); - String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat()); - String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat()); - String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat()); - // TODO @xin:建议改成 if return 哈;更简洁; - if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) { - // TODO @xin:写成静态枚举哈 - template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; - return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length); - } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) { - template = "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; - return StrUtil.format(template, generateReqVO.getOriginalContent(), generateReqVO.getPrompt(), format, tone, language, length); + String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); + String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage()); + String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength()); + String prompt = generateReqVO.getPrompt(); + // 校验写作类型是否合法 + AiWriteTypeEnum.validateType(type); + + if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) { + return StrUtil.format(AiWriteTypeEnum.WRITING.getTemplate(), prompt, format, tone, language, length); } else { - throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", writeType)); + return StrUtil.format(AiWriteTypeEnum.REPLY.getTemplate(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length); } }