diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java new file mode 100644 index 000000000..5bc3dee0d --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java @@ -0,0 +1,63 @@ +package cn.iocoder.yudao.module.ai.enums; + +import cn.iocoder.yudao.framework.common.core.IntArrayValuable; +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.Arrays; + +/** + * AI 写作类型的枚举 + * + * @author xiaoxin + */ +@AllArgsConstructor +@Getter +public enum AiChatRoleEnum implements IntArrayValuable { + + AI_WRITE_ROLE(1, "写作助手", """ + 你是一位出色的写作助手,能够帮助用户生成创意和灵感,并在用户提供场景和提示词时生成对应的回复。你的任务包括: + 1. 撰写建议:根据用户提供的主题或问题,提供详细的写作建议、情节发展方向、角色设定以及背景描写,确保内容结构清晰、有逻辑。 + 2. 回复生成:根据用户提供的场景和提示词,生成合适的对话或文字回复,确保语气和风格符合场景需求。 + 除此之外不需要除了正文内容外的其他回复,如标题、开头、任何解释性语句或道歉。 + """), + AI_MIND_MAP_ROLE(2, "脑图助手", """ + 你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子: + # Geek-AI 助手 + ## 完整的开源系统 + ### 前端开源 + ### 后端开源 + ## 支持各种大模型 + ### OpenAI + ### Azure + ### 文心一言 + ### 通义千问 + ## 集成多种收费方式 + ### 支付宝 + ### 微信 + 除此之外不要任何解释性语句。 + """); + + + /** + * 角色 + */ + private final Integer role; + /** + * 角色名 + */ + private final String name; + + /** + * 角色设定 + */ + private final String prompt; + + public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiChatRoleEnum::getRole).toArray(); + + @Override + public int[] array() { + return ARRAYS; + } + +} diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java index 18bf99710..49d825be8 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java @@ -1,7 +1,5 @@ package cn.iocoder.yudao.module.ai.enums.write; -import cn.hutool.core.util.ArrayUtil; -import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.common.core.IntArrayValuable; import lombok.AllArgsConstructor; import lombok.Getter; @@ -41,9 +39,4 @@ public enum AiWriteTypeEnum implements IntArrayValuable { return ARRAYS; } - public static void validateType(Integer type) { - if (ArrayUtil.contains(ARRAYS, type)) return; - throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type)); - } - } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/AiMindMapController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/AiMindMapController.java new file mode 100644 index 000000000..015180265 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/AiMindMapController.java @@ -0,0 +1,35 @@ +package cn.iocoder.yudao.module.ai.controller.admin.mindmap; + +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; +import cn.iocoder.yudao.module.ai.service.mindmap.AiMindMapService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import jakarta.annotation.security.PermitAll; +import jakarta.validation.Valid; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import reactor.core.publisher.Flux; + +import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; + +@Tag(name = "管理后台 - AI 思维导图") +@RestController +@RequestMapping("/ai/mind-map") +public class AiMindMapController { + + @Resource + private AiMindMapService mindMapService; + + @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + @Operation(summary = "脑图生成(流式)", description = "流式返回,响应较快") + @PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 + public Flux> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) { + return mindMapService.generateMindMap(generateReqVO, getLoginUserId()); + } + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapGenerateReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapGenerateReqVO.java new file mode 100644 index 000000000..adc47b8ea --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapGenerateReqVO.java @@ -0,0 +1,13 @@ +package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo; + +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotBlank; +import lombok.Data; + +@Schema(description = "管理后台 - AI 思维导图生成 Request VO") +@Data +public class AiMindMapGenerateReqVO { + @Schema(description = "思维导图内容提示", example = "Java 学习路线") + @NotBlank(message = "思维导图内容提示不能为空") + private String prompt; +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java index 065165150..21c60420d 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java @@ -11,7 +11,7 @@ import lombok.Data; public class AiWriteGenerateReqVO { @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") - @InEnum(AiWriteTypeEnum.class) + @InEnum(value = AiWriteTypeEnum.class, message = "写作类型必须是 {value}") private Integer type; @Schema(description = "写作内容提示", example = "1.撰写:田忌赛马;2.回复:不批") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java new file mode 100644 index 000000000..b6b87e92e --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java @@ -0,0 +1,57 @@ +package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap; + +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + +/** + * AI 思维导图 DO + * + * @author xiaoxin + */ +@TableName(value = "ai_mind_map", autoResultMap = true) +@Data +public class AiMindMapDO extends BaseDO { + + /** + * 编号 + */ + @TableId(type = IdType.AUTO) + private Long id; + + /** + * 用户编号 + */ + private Long userId; + + /** + * 模型 + */ + private String model; + + /** + * 平台 + *

+ * 枚举 {@link AiPlatformEnum} + */ + private String platform; + + /** + * 生成内容提示 + */ + private String prompt; + + /** + * 生成的内容 + */ + private String generatedContent; + + /** + * 错误信息 + */ + private String errorMessage; + +} \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java new file mode 100644 index 000000000..54fa7235a --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java @@ -0,0 +1,14 @@ +package cn.iocoder.yudao.module.ai.dal.mysql.mindmap; + +import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; +import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; +import org.apache.ibatis.annotations.Mapper; + +/** + * AI 音乐 Mapper + * + * @author xiaoxin + */ +@Mapper +public interface AiMindMapMapper extends BaseMapperX { +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatRoleMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatRoleMapper.java index 1ddef8345..ed91edf3f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatRoleMapper.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatRoleMapper.java @@ -4,9 +4,7 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; -import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import org.apache.ibatis.annotations.Mapper; @@ -47,4 +45,10 @@ public interface AiChatRoleMapper extends BaseMapperX { .groupBy(AiChatRoleDO::getCategory)); } + default List selectListByName(String name) { + return selectList(new LambdaQueryWrapperX() + .likeIfPresent(AiChatRoleDO::getName, name) + .orderByAsc(AiChatRoleDO::getSort)); + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapService.java new file mode 100644 index 000000000..2eb1f1b1a --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapService.java @@ -0,0 +1,23 @@ +package cn.iocoder.yudao.module.ai.service.mindmap; + +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; +import reactor.core.publisher.Flux; + +/** + * AI 思维导图 Service 接口 + * + * @author xiaoxin + */ +public interface AiMindMapService { + + /** + * 生成思维导图内容 + * + * @param generateReqVO 请求参数 + * @param userId 用户编号 + * @return 生成结果 + */ + Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId); + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java new file mode 100644 index 000000000..5169ea91b --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java @@ -0,0 +1,112 @@ +package cn.iocoder.yudao.module.ai.service.mindmap; + +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.util.StrUtil; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.util.AiUtils; +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.framework.common.util.object.BeanUtils; +import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; +import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; +import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; +import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper; +import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; +import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; +import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; +import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; +import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; + +/** + * AI 写作 Service 实现类 + * + * @author xiaoxin + */ +@Service +@Slf4j +public class AiMindMapServiceImpl implements AiMindMapService { + + @Resource + private AiApiKeyService apiKeyService; + @Resource + private AiChatModelService chatModalService; + @Resource + private AiChatRoleService chatRoleService; + + @Resource + private AiMindMapMapper mindMapMapper; + + @Override + public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { + // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型 + AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); + AiChatModelDO model; + String systemMessage; + if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) { + model = chatModalService.getChatModel(mindMapRole.getModelId()); + systemMessage = mindMapRole.getSystemMessage(); + } else { + model = chatModalService.getRequiredDefaultChatModel(); + systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getPrompt(); + } + + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); + + // 2 插入思维导图信息 + AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); + mindMapMapper.insert(mindMapDO); + + ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); + // 3.1 角色设定 + List chatMessages = new ArrayList<>(); + if (StrUtil.isNotBlank(systemMessage)) { + chatMessages.add(new SystemMessage(systemMessage)); + } + // 3.2 用户输入 + chatMessages.add(new UserMessage(generateReqVO.getPrompt())); + // 3.3 构建提示词 + Prompt prompt = new Prompt(chatMessages, chatOptions); + + Flux streamResponse = chatModel.stream(prompt); + // 3.4 流式返回 + StringBuffer contentBuffer = new StringBuffer(); + return streamResponse.map(chunk -> { + String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; + newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况 + contentBuffer.append(newContent); + // 响应结果 + return success(newContent); + }).doOnComplete(() -> { + // 忽略租户,因为 Flux 异步无法透传租户 + TenantUtils.executeIgnore(() -> + mindMapMapper.updateById(new AiMindMapDO().setId(mindMapDO.getId()).setGeneratedContent(contentBuffer.toString()))); + }).doOnError(throwable -> { + log.error("[generateWriteContent][generateReqVO({}) 发生异常]", generateReqVO, throwable); + // 忽略租户,因为 Flux 异步无法透传租户 + TenantUtils.executeIgnore(() -> + mindMapMapper.updateById(new AiMindMapDO().setId(mindMapDO.getId()).setErrorMessage(throwable.getMessage()))); + }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); + + } + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java index a602d6537..ce0cfe21d 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java @@ -118,4 +118,11 @@ public interface AiChatRoleService { */ List getChatRoleCategoryList(); + /** + * 根据名字获得聊天角色 + * @param name 名字 + * @return 聊天角色列表 + */ + List getChatRoleListByName(String name); + } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java index 4f358e473..2cf4d46d1 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java @@ -137,5 +137,10 @@ public class AiChatRoleServiceImpl implements AiChatRoleService { return convertList(list, AiChatRoleDO::getCategory, role -> role != null && StrUtil.isNotBlank(role.getCategory())); } + @Override + public List getChatRoleListByName(String name) { + return chatRoleMapper.selectListByName(name); + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index aa2a259a7..01fe5458f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -5,15 +5,14 @@ import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.common.pojo.CommonResult; -import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; +import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum; import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; @@ -23,6 +22,9 @@ import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.system.api.dict.DictDataApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; @@ -30,6 +32,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; +import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -61,13 +64,15 @@ public class AiWriteServiceImpl implements AiWriteService { @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型 - AiChatRoleDO writeRole = selectOneWriteRole(); + AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); AiChatModelDO model; - // TODO @xin:writeRole.getModelId 可能为空。所以,最好是先通过 chatRole 拿。如果它没拿到,通过 getRequiredDefaultChatModel 再拿。 - if (Objects.nonNull(writeRole)) { + String systemMessage; + if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { model = chatModalService.getChatModel(writeRole.getModelId()); + systemMessage = writeRole.getSystemMessage(); } else { model = chatModalService.getRequiredDefaultChatModel(); + systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getPrompt(); } // 1.2 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); @@ -77,9 +82,16 @@ public class AiWriteServiceImpl implements AiWriteService { AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); writeMapper.insert(writeDO); - // 3.1 构建提示词 ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); + // 3.1 角色设定 + List chatMessages = new ArrayList<>(); + if (StrUtil.isNotBlank(systemMessage)) { + chatMessages.add(new SystemMessage(systemMessage)); + } + // 3.2 用户输入 + chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO))); + // 3.3 构建提示词 + Prompt prompt = new Prompt(chatMessages, chatOptions); Flux streamResponse = chatModel.stream(prompt); // 3.2 流式返回 @@ -102,24 +114,8 @@ public class AiWriteServiceImpl implements AiWriteService { }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); } - // TODO @xin:chatRoleService 增加一个 getChatRoleListByName; - private AiChatRoleDO selectOneWriteRole() { - AiChatRoleDO chatRoleDO = null; - // TODO @xin:"写作助手" 枚举下。 - PageResult writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手")); - List list = writeRolePage.getList(); - // TODO @xin:CollUtil.getFirst 简化下 - if (CollUtil.isNotEmpty(list)) { - chatRoleDO = list.get(0); - } - return chatRoleDO; - } - private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { - // 校验写作类型是否合法 Integer type = generateReqVO.getType(); - // TODO @xin:这里可以搞到 validator 的校验。InEnum - AiWriteTypeEnum.validateType(type); String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());