From 79bd78998d25dda9ab7f308c7ca8e9ee67912eed Mon Sep 17 00:00:00 2001 From: cherishsince Date: Sun, 12 May 2024 14:20:24 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E3=80=90=E5=A2=9E=E5=8A=A0=E3=80=91?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=A8=A1=E5=9E=8B=E5=A4=B4=E5=83=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat/vo/message/AiChatMessageRespVO.java | 10 ++++++++++ .../module/ai/dal/mysql/AiChatModelMapper.java | 12 +++++++++++- .../module/ai/service/AiChatModelService.java | 12 ++++++++++++ .../service/impl/AiChatModalServiceImpl.java | 7 +++++++ .../ai/service/impl/AiChatServiceImpl.java | 18 +++++++++++++++++- 5 files changed, 57 insertions(+), 2 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java index d5f830d17..f117c67c6 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java @@ -3,6 +3,8 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message; import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; +import java.time.LocalDateTime; + @Schema(description = "管理后台 - AI 聊天消息 Response VO") @Data public class AiChatMessageRespVO { @@ -19,6 +21,9 @@ public class AiChatMessageRespVO { @Schema(description = "用户编号", example = "4096") private Long userId; // 仅当 user 发送时非空 + @Schema(description = "用户头像", example = "http://xxx") + private Long avatarUrl; // 仅当 user 发送时非空 + @Schema(description = "角色编号", example = "888") private Long roleId; // 仅当 assistant 回复时非空 @@ -28,10 +33,15 @@ public class AiChatMessageRespVO { @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "123") private Long modelId; + @Schema(description = "模型图片", requiredMode = Schema.RequiredMode.REQUIRED, example = "123") + private String modelImage; + @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊") private String content; @Schema(description = "消耗 Token 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "80") private Integer tokens; + @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") + private LocalDateTime createTime; } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java index 46d66ff5f..418b25975 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java @@ -9,6 +9,10 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import org.apache.ibatis.annotations.Mapper; import org.springframework.stereotype.Repository; +import java.util.Collection; +import java.util.List; +import java.util.Set; + /** * chat modal * @@ -36,5 +40,11 @@ public interface AiChatModelMapper extends BaseMapperX { return pageResult.getList().get(0); } - + /** + * 查询 - 根据 ids + * + * @param modalIds + * @return + */ + List selectByIds(Collection modalIds); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java index 07a902dd8..d133c729e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java @@ -4,6 +4,9 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; +import java.util.List; +import java.util.Set; + /** * ai modal * @@ -64,4 +67,13 @@ public interface AiChatModelService { * @param chatModal */ void validateAvailable(AiChatModalRespVO chatModal); + + /** + * 获取 - 根据 ids 批量获取 + * + * @param modalIds + * @return + */ + List getModalByIds(Set modalIds); + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java index 0aec7a94c..e25110e1f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java @@ -17,6 +17,7 @@ import cn.iocoder.yudao.module.ai.service.AiChatModelService; import jakarta.validation.ConstraintViolation; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.mybatis.spring.annotation.MapperScannerRegistrar; import org.springframework.stereotype.Service; import java.util.List; @@ -35,6 +36,7 @@ import java.util.Set; public class AiChatModalServiceImpl implements AiChatModelService { private final AiChatModelMapper aiChatModelMapper; + private final MapperScannerRegistrar mapperScannerRegistrar; @Override public PageResult list(AiChatModelListReqVO req) { @@ -102,6 +104,11 @@ public class AiChatModalServiceImpl implements AiChatModelService { } } + @Override + public List getModalByIds(Set modalIds) { + return aiChatModelMapper.selectByIds(modalIds); + } + public AiChatModelDO validateExists(Long id) { AiChatModelDO aiChatModalDO = aiChatModelMapper.selectById(id); if (aiChatModalDO == null) { diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java index 90649f417..de33e65e4 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java @@ -15,6 +15,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRespVO; import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; +import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; @@ -29,8 +30,11 @@ import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import java.util.stream.Collectors; /** * 聊天 service @@ -188,8 +192,20 @@ public class AiChatServiceImpl implements AiChatService { chatConversationService.validateExists(conversationId); // 获取对话所有 message List aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId); + // 获取模型信息 + Set modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet()); + List modalList = aiChatModalService.getModalByIds(modalIds); + Map modalIdMap = modalList.stream().collect(Collectors.toMap(AiChatModelDO::getId, o -> o)); // 转换 AiChatMessageRespVO - return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); + List aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); + // 设置用户头像 和 模型头像 todo @芋艿 这里需要转换 用户头像、模型头像 + return aiChatMessageRespList.stream().map(item -> { + if (modalIdMap.containsKey(item.getModelId())) { +// modalIdMap.get(item.getModelId()); +// item.setModelImage() + } + return item; + }).collect(Collectors.toList()); } @Override From 5a4162cdc13bc9cb8939133cb24f1a9a936d3435 Mon Sep 17 00:00:00 2001 From: cherishsince Date: Sun, 12 May 2024 19:04:58 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91?= =?UTF-8?q?=E4=BC=98=E5=8C=96=20chat=20event=20stream=20=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=E4=BA=A4=E4=BA=92=EF=BC=8C=E5=A2=9E=E5=8A=A0=20add=20message?= =?UTF-8?q?=20=E4=BC=98=E5=85=88=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../yudao/module/ai/ErrorCodeConstants.java | 4 + .../admin/chat/AiChatMessageController.java | 11 +- .../vo/message/AiChatMessageAddReqVO.java | 20 +++ .../vo/message/AiChatMessageAddRespVO.java | 17 +++ .../message/AiChatMessageSendStreamReqVO.java | 16 +++ .../ai/convert/AiChatMessageConvert.java | 9 ++ .../ai/dal/mysql/AiChatModelMapper.java | 1 - .../module/ai/service/AiChatModelService.java | 78 ----------- .../module/ai/service/AiChatService.java | 13 +- .../service/impl/AiChatModalServiceImpl.java | 132 ------------------ .../ai/service/impl/AiChatServiceImpl.java | 121 ++++++++++------ .../ai/service/model/AiChatModelService.java | 10 ++ .../service/model/AiChatModelServiceImpl.java | 8 ++ 13 files changed, 181 insertions(+), 259 deletions(-) create mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddReqVO.java create mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddRespVO.java create mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java delete mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java delete mode 100644 yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java index a3c343e12..4e101cd7e 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java @@ -36,5 +36,9 @@ public interface ErrorCodeConstants { ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "AI 角色不存在!"); ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!"); + // chat + + ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_022_000_100, "AI 提问的 MessageId 不存在!"); + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java index 82392ed27..fb0d9f5ad 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java @@ -1,8 +1,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat; import cn.iocoder.yudao.framework.common.pojo.CommonResult; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*; import cn.iocoder.yudao.module.ai.service.AiChatService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -38,10 +37,16 @@ public class AiChatMessageController { // TODO @fan:要不要使用 Flux 来返回;可以使用 Flux @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { + public Flux sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) { return chatService.chatStream(sendReqVO); } + @Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染") + @PostMapping(value = "/add") + public CommonResult add(@Validated @RequestBody AiChatMessageAddReqVO req) { + return success(chatService.add(req)); + } + @Operation(summary = "获得指定会话的消息列表") @GetMapping("/list-by-conversation-id") @Parameter(name = "conversationId", required = true, description = "会话编号", example = "1024") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddReqVO.java new file mode 100644 index 000000000..994472d03 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddReqVO.java @@ -0,0 +1,20 @@ +package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message; + +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotEmpty; +import jakarta.validation.constraints.NotNull; +import lombok.Data; + +@Schema(description = "管理后台 - AI 聊天消息发送 Request VO") +@Data +public class AiChatMessageAddReqVO { + + @Schema(description = "聊天对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") + @NotNull(message = "聊天对话编号不能为空") + private Long conversationId; + + @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "帮我写个 Java 算法") + @NotEmpty(message = "聊天内容不能为空") + private String content; + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddRespVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddRespVO.java new file mode 100644 index 000000000..70cfb5b40 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddRespVO.java @@ -0,0 +1,17 @@ +package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.time.LocalDateTime; + +@Schema(description = "管理后台 - AI 聊天消息 Add Response VO") +@Data +public class AiChatMessageAddRespVO { + + @Schema(description = "用户信息") + private AiChatMessageRespVO userMessage; + + @Schema(description = "系统信息") + private AiChatMessageRespVO systemMessage; +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java new file mode 100644 index 000000000..cfd67ccba --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java @@ -0,0 +1,16 @@ +package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message; + +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotEmpty; +import jakarta.validation.constraints.NotNull; +import lombok.Data; + +@Schema(description = "管理后台 - AI 聊天消息发送 Request VO") +@Data +public class AiChatMessageSendStreamReqVO { + + @Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") + @NotNull(message = "提问的 messageId 不能为空") + private Long id; + +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java index a5019b2cd..05f7b83b6 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java @@ -26,4 +26,13 @@ public interface AiChatMessageConvert { * @return */ List convertAiChatMessageRespVOList(List aiChatMessageDOList); + + /** + * 转换 - aiChatMessageDO + * + * @param aiChatMessageDO + * @return + */ + AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO); + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java index 8d00cd66c..7eaef6602 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java @@ -11,7 +11,6 @@ import org.apache.ibatis.annotations.Mapper; import java.util.Collection; import java.util.List; -import java.util.Set; /** * API 聊天模型 Mapper diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java deleted file mode 100644 index aeb7bbccb..000000000 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java +++ /dev/null @@ -1,78 +0,0 @@ -package cn.iocoder.yudao.module.ai.service; - -import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; - -import java.util.List; -import java.util.Set; - -/** - * ai modal - * - * @author fansili - * @time 2024/4/24 19:42 - * @since 1.0 - */ -public interface AiChatModelService { - - /** - * ai modal - 列表 - * - * @param req - * @return - */ - PageResult list(AiChatModelListReqVO req); - - /** - * ai modal - 添加 - * - * @param req - */ - void add(AiChatModelAddReqVO req); - - /** - * ai modal - 更新 - * - * @param req - */ - void update(AiChatModelUpdateReqVO req); - - /** - * ai modal - 删除 - * - * @param id - */ - void delete(Long id); - - /** - * 获取 - 获取 modal - * - * @param modalId - * @return - */ - AiChatModalRespVO getChatModalOfValidate(Long modalId); - - /** - * 校验 - 是否存在 - * - * @param id - * @return - */ - AiChatModelDO validateExists(Long id); - - /** - * 校验 - 校验是否可用 - * - * @param chatModal - */ - void validateAvailable(AiChatModalRespVO chatModal); - - /** - * 获取 - 根据 ids 批量获取 - * - * @param modalIds - * @return - */ - List getModalByIds(Set modalIds); -} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java index a5e97ce5f..7be2b8afc 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java @@ -1,7 +1,6 @@ package cn.iocoder.yudao.module.ai.service; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; -import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*; import reactor.core.publisher.Flux; import java.util.List; @@ -29,7 +28,15 @@ public interface AiChatService { * @param sendReqVO * @return */ - Flux chatStream(AiChatMessageSendReqVO sendReqVO); + Flux chatStream(AiChatMessageSendStreamReqVO sendReqVO); + + /** + * 添加 - message + * + * @param sendReqVO + * @return + */ + AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO); /** * 获取 - 获取对话 message list diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java deleted file mode 100644 index 51a972fa7..000000000 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java +++ /dev/null @@ -1,132 +0,0 @@ -package cn.iocoder.yudao.module.ai.service.impl; - -import cn.hutool.core.util.StrUtil; -import cn.hutool.extra.validation.ValidationUtil; -import cn.iocoder.yudao.framework.ai.AiPlatformEnum; -import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; -import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; -import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; -import cn.iocoder.yudao.module.ai.ErrorCodeConstants; -import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*; -import cn.iocoder.yudao.module.ai.convert.AiChatModelConvert; -import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; -import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModelMapper; -import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO; -import cn.iocoder.yudao.module.ai.service.AiChatModelService; -import jakarta.validation.ConstraintViolation; -import lombok.AllArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.springframework.stereotype.Service; - -import java.util.List; -import java.util.Set; - -/** - * ai 模型 - * - * @author fansili - * @time 2024/4/24 19:42 - * @since 1.0 - */ -@AllArgsConstructor -@Service -@Slf4j -public class AiChatModalServiceImpl implements AiChatModelService { - - private final AiChatModelMapper aiChatModelMapper; - - @Override - public PageResult list(AiChatModelListReqVO req) { - LambdaQueryWrapperX queryWrapperX = new LambdaQueryWrapperX<>(); - // 查询的都是未禁用的模型 - queryWrapperX.eq(AiChatModelDO::getStatus, CommonStatusEnum.ENABLE.getStatus()); - // search - if (!StrUtil.isBlank(req.getSearch())) { - queryWrapperX.like(AiChatModelDO::getName, req.getSearch().trim()); - } - // 默认排序 - queryWrapperX.orderByAsc(AiChatModelDO::getSort); - // 查询 - PageResult aiChatModalDOPageResult = aiChatModelMapper.selectPage(req, queryWrapperX); - // 转换 res - List resList = AiChatModelConvert.INSTANCE.convertAiChatModalListRes(aiChatModalDOPageResult.getList()); - return new PageResult<>(resList, aiChatModalDOPageResult.getTotal()); - } - - @Override - public void add(AiChatModelAddReqVO req) { - // 校验 platform、type - validatePlatform(req.getPlatform()); - // 转换 do - AiChatModelDO insertChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req); - // 设置默认属性 - insertChatModalDO.setStatus(CommonStatusEnum.ENABLE.getStatus()); - // 保存数据库 - aiChatModelMapper.insert(insertChatModalDO); - } - - @Override - public void update(AiChatModelUpdateReqVO req) { - // 校验 platform - validatePlatform(req.getPlatform()); - // 校验模型是否存在 - validateExists(req.getId()); - // 转换 updateChatModalDO - AiChatModelDO updateChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req); - updateChatModalDO.setId(req.getId()); - // 更新数据库 - aiChatModelMapper.updateById(updateChatModalDO); - } - - @Override - public void delete(Long id) { - // 检查 modal 是否存在 - validateExists(id); - // 删除 delete - aiChatModelMapper.deleteById(id); - } - - @Override - public AiChatModalRespVO getChatModalOfValidate(Long modalId) { - // 检查 modal 是否存在 - AiChatModelDO aiChatModalDO = validateExists(modalId); - return AiChatModelConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO); - } - - @Override - public void validateAvailable(AiChatModalRespVO chatModal) { - // 对话模型是否可用 - if (!CommonStatusEnum.ENABLE.getStatus().equals(chatModal.getStatus())) { - throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED); - } - } - - @Override - public List getModalByIds(Set modalIds) { - return aiChatModelMapper.selectByIds(modalIds); - } - - public AiChatModelDO validateExists(Long id) { - AiChatModelDO aiChatModalDO = aiChatModelMapper.selectById(id); - if (aiChatModalDO == null) { - throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_EXIST); - } - return aiChatModalDO; - } - - private void validatePlatform(String platform) { - try { - AiPlatformEnum.valueOfPlatform(platform); - } catch (IllegalArgumentException e) { - throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_PLATFORM_PARAMS_INCORRECT, e.getMessage()); - } - } - - private void validateModalConfig(AiChatModalConfigVO aiChatModalConfigVO) { - Set> validate = ValidationUtil.validate(aiChatModalConfigVO); - for (ConstraintViolation constraintViolation : validate) { - throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, constraintViolation.getMessage()); - } - } -} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java index c3842e409..a7dcd122d 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java @@ -7,11 +7,15 @@ import cn.iocoder.yudao.framework.ai.chat.ChatResponse; import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient; import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; +import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; +import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; +import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO; import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; @@ -19,11 +23,12 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.service.AiChatConversationService; -import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.AiChatService; +import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.http.HttpMessageConverters; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; @@ -53,6 +58,7 @@ public class AiChatServiceImpl implements AiChatService { private final AiChatConversationService chatConversationService; private final AiChatModelService aiChatModalService; private final AiChatRoleService aiChatRoleService; + private final HttpMessageConverters messageConverters; @Transactional(rollbackFor = Exception.class) public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) { @@ -124,7 +130,75 @@ public class AiChatServiceImpl implements AiChatService { return insertChatMessageDO; } - public Flux chatStream(AiChatMessageSendReqVO req) { + public Flux chatStream(AiChatMessageSendStreamReqVO req) { + Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); + // 查询提问的 message + AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId()); + if (aiChatMessageDO == null) { + throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST); + } + // 查询对话 + AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(aiChatMessageDO.getConversationId()); + // 获取对话模型 + AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); + // 获取角色信息 + AiChatRoleDO aiChatRoleDO = null; + if (conversation.getRoleId() != null) { + aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId()); + } + // 校验角色是否公开 + aiChatRoleService.validateIsPublic(aiChatRoleDO); + // 创建 chat 需要的 Prompt + Prompt prompt = new Prompt(aiChatMessageDO.getContent()); + // 提前创建一个 system message + AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), + chatModel.getModel(), chatModel.getId(), "", + 0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); +// req.setTopK(req.getTopK()); +// req.setTopP(req.getTopP()); +// req.setTemperature(req.getTemperature()); + // 获取 client 类型 + AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); + StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); + Flux streamResponse = streamingChatClient.stream(prompt); + // 转换 flex AiChatMessageRespVO + StringBuffer contentBuffer = new StringBuffer(); + AtomicInteger tokens = new AtomicInteger(0); + return streamResponse.map(res -> { + AiChatMessageRespVO aiChatMessageRespVO = + AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage); + aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent()); + contentBuffer.append(res.getResult().getOutput().getContent()); + tokens.incrementAndGet(); + return aiChatMessageRespVO; + } + ).doOnComplete(new Runnable() { + @Override + public void run() { + log.info("发送完成!"); + // 保存 chat message + aiChatMessageMapper.updateById(new AiChatMessageDO() + .setId(systemMessage.getId()) + .setContent(contentBuffer.toString()) + .setTokens(tokens.get()) + ); + } + }).doOnError(new Consumer() { + @Override + public void accept(Throwable throwable) { + log.error("发送错误 {}!", throwable.getMessage()); + // 更新错误信息 + aiChatMessageMapper.updateById(new AiChatMessageDO() + .setId(systemMessage.getId()) + .setContent(throwable.getMessage()) + .setTokens(tokens.get()) + ); + } + }); + } + + @Override + public AiChatMessageRespVO add(AiChatMessageAddReqVO req) { Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); // 查询对话 AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId()); @@ -137,48 +211,10 @@ public class AiChatServiceImpl implements AiChatService { } // 校验角色是否公开 aiChatRoleService.validateIsPublic(aiChatRoleDO); - // 创建 chat 需要的 Prompt - Prompt prompt = new Prompt(req.getContent()); -// req.setTopK(req.getTopK()); -// req.setTopP(req.getTopP()); -// req.setTemperature(req.getTemperature()); - // 保存 chat message - insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), + AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), chatModel.getModel(), chatModel.getId(), req.getContent(), null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); - // 获取 client 类型 - AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); - StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); - Flux streamResponse = streamingChatClient.stream(prompt); - // 转换 flex AiChatMessageRespVO - StringBuffer contentBuffer = new StringBuffer(); - AtomicInteger tokens = new AtomicInteger(0); - return streamResponse.map(res -> { - AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO(); - aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent()); - contentBuffer.append(res.getResult().getOutput().getContent()); - tokens.incrementAndGet(); - return aiChatMessageRespVO; - } - ).doOnComplete(new Runnable() { - @Override - public void run() { - log.info("发送完成!"); - // 保存 chat message - insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), - chatModel.getModel(), chatModel.getId(), contentBuffer.toString(), - tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); - } - }).doOnError(new Consumer() { - @Override - public void accept(Throwable throwable) { - log.error("发送错误 {}!", throwable.getMessage()); - // 保存 chat message - insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), - chatModel.getModel(), chatModel.getId(), throwable.getMessage(), - tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); - } - }); + return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage); } @Override @@ -207,4 +243,5 @@ public class AiChatServiceImpl implements AiChatService { public Boolean deleteMessage(Long id) { return aiChatMessageMapper.deleteById(id) > 0; } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java index d05941989..de203661e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java @@ -6,6 +6,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatMode import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import jakarta.validation.Valid; +import java.util.List; +import java.util.Set; + /** * AI 聊天模型 Service 接口 * @@ -60,4 +63,11 @@ public interface AiChatModelService { */ AiChatModelDO validateChatModel(Long id); + /** + * 获取 - 根据多个 ids 获取 + * + * @param modalIds + * @return + */ + List getModalByIds(Set modalIds); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java index 0c0386ccc..5e87fd180 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java @@ -12,6 +12,9 @@ import jakarta.annotation.Resource; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; +import java.util.List; +import java.util.Set; + import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*; @@ -89,4 +92,9 @@ public class AiChatModelServiceImpl implements AiChatModelService { return model; } + @Override + public List getModalByIds(Set modalIds) { + return chatModelMapper.selectByIds(modalIds); + } + } \ No newline at end of file