【新增】AI:对话管理 50%

This commit is contained in:
YunaiV 2024-05-25 00:08:27 +08:00
parent 4fddec5f02
commit 2d11f085c8
16 changed files with 196 additions and 68 deletions

View File

@ -1,27 +1,35 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat; package cn.iocoder.yudao.module.ai.controller.admin.chat;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.chat.AiChatMessageService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
@Tag(name = "管理后台 - AI 聊天") @Tag(name = "管理后台 - AI 聊天")
@RestController @RestController
@RequestMapping("/ai/chat/conversation") @RequestMapping("/ai/chat/conversation")
@Validated @Validated
@ -29,30 +37,32 @@ public class AiChatConversationController {
@Resource @Resource
private AiChatConversationService chatConversationService; private AiChatConversationService chatConversationService;
@Resource
private AiChatMessageService chatMessageService;
@PostMapping("/create-my") @PostMapping("/create-my")
@Operation(summary = "创建【我的】聊天") @Operation(summary = "创建【我的】聊天")
public CommonResult<Long> createChatConversationMy(@RequestBody @Valid AiChatConversationCreateMyReqVO createReqVO) { public CommonResult<Long> createChatConversationMy(@RequestBody @Valid AiChatConversationCreateMyReqVO createReqVO) {
return success(chatConversationService.createChatConversationMy(createReqVO, getLoginUserId())); return success(chatConversationService.createChatConversationMy(createReqVO, getLoginUserId()));
} }
@PutMapping("/update-my") @PutMapping("/update-my")
@Operation(summary = "更新【我的】聊天") @Operation(summary = "更新【我的】聊天")
public CommonResult<Boolean> updateChatConversationMy(@RequestBody @Valid AiChatConversationUpdateMyReqVO updateReqVO) { public CommonResult<Boolean> updateChatConversationMy(@RequestBody @Valid AiChatConversationUpdateMyReqVO updateReqVO) {
chatConversationService.updateChatConversationMy(updateReqVO, getLoginUserId()); chatConversationService.updateChatConversationMy(updateReqVO, getLoginUserId());
return success(true); return success(true);
} }
@GetMapping("/my-list") @GetMapping("/my-list")
@Operation(summary = "获得【我的】聊天话列表") @Operation(summary = "获得【我的】聊天话列表")
public CommonResult<List<AiChatConversationRespVO>> getChatConversationMyList() { public CommonResult<List<AiChatConversationRespVO>> getChatConversationMyList() {
List<AiChatConversationDO> list = chatConversationService.getChatConversationListByUserId(getLoginUserId()); List<AiChatConversationDO> list = chatConversationService.getChatConversationListByUserId(getLoginUserId());
return success(BeanUtils.toBean(list, AiChatConversationRespVO.class)); return success(BeanUtils.toBean(list, AiChatConversationRespVO.class));
} }
@GetMapping("/get-my") @GetMapping("/get-my")
@Operation(summary = "获得【我的】聊天") @Operation(summary = "获得【我的】聊天")
@Parameter(name = "id", required = true, description = "话编号", example = "1024") @Parameter(name = "id", required = true, description = "话编号", example = "1024")
public CommonResult<AiChatConversationRespVO> getChatConversationMy(@RequestParam("id") Long id) { public CommonResult<AiChatConversationRespVO> getChatConversationMy(@RequestParam("id") Long id) {
AiChatConversationDO conversation = chatConversationService.getChatConversation(id); AiChatConversationDO conversation = chatConversationService.getChatConversation(id);
if (conversation != null && ObjUtil.notEqual(conversation.getUserId(), getLoginUserId())) { if (conversation != null && ObjUtil.notEqual(conversation.getUserId(), getLoginUserId())) {
@ -62,20 +72,36 @@ public class AiChatConversationController {
} }
@DeleteMapping("/delete-my") @DeleteMapping("/delete-my")
@Operation(summary = "删除聊天") @Operation(summary = "删除聊天")
@Parameter(name = "id", required = true, description = "话编号", example = "1024") @Parameter(name = "id", required = true, description = "话编号", example = "1024")
public CommonResult<Boolean> deleteChatConversationMy(@RequestParam("id") Long id) { public CommonResult<Boolean> deleteChatConversationMy(@RequestParam("id") Long id) {
chatConversationService.deleteChatConversationMy(id, getLoginUserId()); chatConversationService.deleteChatConversationMy(id, getLoginUserId());
return success(true); return success(true);
} }
// TODO 芋艿这个 url 可以改下
@DeleteMapping("/delete-my-all-except-pinned") @DeleteMapping("/delete-my-all-except-pinned")
@Operation(summary = "删除所有对话(置顶除外)") @Operation(summary = "删除所有对话(置顶除外)")
@Parameter(name = "id", required = true, description = "会话编号", example = "1024") public CommonResult<Boolean> deleteChatConversationMyByUnpinned() {
public CommonResult<Boolean> deleteMyAllExceptPinned() { chatConversationService.deleteChatConversationMyByUnpinned(getLoginUserId());
chatConversationService.deleteMyAllExceptPinned(getLoginUserId());
return success(true); return success(true);
} }
// ========== 会话管理 ==========
// ========== 对话管理 ==========
@GetMapping("/page")
@Operation(summary = "获得对话分页", description = "用于【对话管理】菜单")
@PreAuthorize("@ss.hasPermission('ai:chat-conversation:query')")
public CommonResult<PageResult<AiChatConversationRespVO>> getChatConversationPage(AiChatConversationPageReqVO pageReqVO) {
PageResult<AiChatConversationDO> pageResult = chatConversationService.getChatConversationPage(pageReqVO);
if (CollUtil.isEmpty(pageResult.getList())) {
return success(PageResult.empty());
}
// 拼接关联数据
Map<Long, Integer> messageCountMap = chatMessageService.getChatMessageCountMap(
convertList(pageResult.getList(), AiChatConversationDO::getId));
return success(BeanUtils.toBean(pageResult, AiChatConversationRespVO.class,
conversation -> conversation.setMessageCount(messageCountMap.getOrDefault(conversation.getId(), 0))));
}
} }

View File

@ -62,9 +62,9 @@ public class AiChatMessageController {
return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId()); return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
} }
@Operation(summary = "获得指定话的消息列表") @Operation(summary = "获得指定话的消息列表")
@GetMapping("/list-by-conversation-id") @GetMapping("/list-by-conversation-id")
@Parameter(name = "conversationId", required = true, description = "话编号", example = "1024") @Parameter(name = "conversationId", required = true, description = "话编号", example = "1024")
public CommonResult<List<AiChatMessageRespVO>> getChatMessageListByConversationId( public CommonResult<List<AiChatMessageRespVO>> getChatMessageListByConversationId(
@RequestParam("conversationId") Long conversationId) { @RequestParam("conversationId") Long conversationId) {
AiChatConversationDO conversation = chatConversationService.getChatConversation(conversationId); AiChatConversationDO conversation = chatConversationService.getChatConversation(conversationId);
@ -93,12 +93,16 @@ public class AiChatMessageController {
return success(true); return success(true);
} }
@Operation(summary = "删除指定话的消息") @Operation(summary = "删除指定话的消息")
@DeleteMapping("/delete-by-conversation-id") @DeleteMapping("/delete-by-conversation-id")
@Parameter(name = "conversationId", required = true, description = "话编号", example = "1024") @Parameter(name = "conversationId", required = true, description = "话编号", example = "1024")
public CommonResult<Boolean> deleteChatMessageByConversationId(@RequestParam("conversationId") Long conversationId) { public CommonResult<Boolean> deleteChatMessageByConversationId(@RequestParam("conversationId") Long conversationId) {
chatMessageService.deleteChatMessageByConversationId(conversationId, getLoginUserId()); chatMessageService.deleteChatMessageByConversationId(conversationId, getLoginUserId());
return success(true); return success(true);
} }
// ========== 对话管理 ==========
} }

View File

@ -3,7 +3,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
@Schema(description = "管理后台 - AI 聊天话创建【我的】 Request VO") @Schema(description = "管理后台 - AI 聊天话创建【我的】 Request VO")
@Data @Data
public class AiChatConversationCreateMyReqVO { public class AiChatConversationCreateMyReqVO {

View File

@ -0,0 +1,26 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.springframework.format.annotation.DateTimeFormat;
import java.time.LocalDateTime;
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND;
@Schema(description = "管理后台 - AI 聊天对话的分页 Request VO")
@Data
public class AiChatConversationPageReqVO extends PageParam {
@Schema(description = "用户编号", example = "1024")
private Long userId;
@Schema(description = "对话标题", example = "你好")
private String title;
@Schema(description = "创建时间")
@DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND)
private LocalDateTime[] createTime;
}

View File

@ -10,24 +10,24 @@ import lombok.Data;
import java.time.LocalDateTime; import java.time.LocalDateTime;
@Schema(description = "管理后台 - AI 聊天话 Response VO") @Schema(description = "管理后台 - AI 聊天话 Response VO")
@Data @Data
public class AiChatConversationRespVO implements VO { public class AiChatConversationRespVO implements VO {
@Schema(description = "话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") @Schema(description = "话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private Long id; private Long id;
@Schema(description = "用户编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2048") @Schema(description = "用户编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2048")
private Long userId; private Long userId;
@Schema(description = "话标题", requiredMode = Schema.RequiredMode.REQUIRED, example = "我是一个标题") @Schema(description = "话标题", requiredMode = Schema.RequiredMode.REQUIRED, example = "我是一个标题")
private String title; private String title;
@Schema(description = "是否置顶", requiredMode = Schema.RequiredMode.REQUIRED, example = "true") @Schema(description = "是否置顶", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
private Boolean pinned; private Boolean pinned;
@Schema(description = "角色编号", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1") @Schema(description = "角色编号", example = "1")
@Trans(type = TransType.SIMPLE, target = AiChatRoleDO.class, fields = "avatar", ref = "roleAvatar") @Trans(type = TransType.SIMPLE, target = AiChatRoleDO.class, fields = {"name", "avatar"}, refs = {"roleName", "roleAvatar"})
private Long roleId; private Long roleId;
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@ -52,12 +52,20 @@ public class AiChatConversationRespVO implements VO {
@Schema(description = "上下文的最大 Message 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "10") @Schema(description = "上下文的最大 Message 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "10")
private Integer maxContexts; private Integer maxContexts;
@Schema(description = "最后更新时间", requiredMode = Schema.RequiredMode.REQUIRED) @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime updateTime; private LocalDateTime createTime;
// ========== 关联 role 信息 ========== // ========== 关联 role 信息 ==========
@Schema(description = "角色头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://www.iocoder.cn/1.png") @Schema(description = "角色头像", example = "https://www.iocoder.cn/1.png")
private String roleAvatar; private String roleAvatar;
@Schema(description = "角色名字", example = "小黄")
private String roleName;
// ========== 仅在对话管理时加载 ==========
@Schema(description = "消息数量", example = "20")
private Integer messageCount;
} }

View File

@ -4,15 +4,15 @@ import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
@Schema(description = "管理后台 - AI 聊天话更新【我的】 Request VO") @Schema(description = "管理后台 - AI 聊天话更新【我的】 Request VO")
@Data @Data
public class AiChatConversationUpdateMyReqVO { public class AiChatConversationUpdateMyReqVO {
@Schema(description = "话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") @Schema(description = "话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
@NotNull(message = "话编号不能为空") @NotNull(message = "话编号不能为空")
private Long id; private Long id;
@Schema(description = "话标题", example = "我是一个标题") @Schema(description = "话标题", example = "我是一个标题")
private String title; private String title;
@Schema(description = "是否置顶", example = "true") @Schema(description = "是否置顶", example = "true")

View File

@ -12,7 +12,7 @@ public class AiChatMessageRespVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private Long id; private Long id;
@Schema(description = "话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2048") @Schema(description = "话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2048")
private Long conversationId; private Long conversationId;
@Schema(description = "消息类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "role") @Schema(description = "消息类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "role")

View File

@ -13,7 +13,7 @@ import java.time.LocalDateTime;
import java.util.Date; import java.util.Date;
/** /**
* AI Chat DO * AI Chat DO
* *
* 用户每次发起 Chat 聊天时会创建一个 {@link AiChatConversationDO} 对象将它的消息关联在一起 * 用户每次发起 Chat 聊天时会创建一个 {@link AiChatConversationDO} 对象将它的消息关联在一起
* *
@ -45,7 +45,7 @@ public class AiChatConversationDO extends BaseDO {
private Long userId; private Long userId;
/** /**
* 话标题 * 话标题
* *
* 默认由系统自动生成可用户手动修改 * 默认由系统自动生成可用户手动修改
*/ */
@ -79,7 +79,7 @@ public class AiChatConversationDO extends BaseDO {
*/ */
private String model; private String model;
// ========== 话配置 ========== // ========== 话配置 ==========
/** /**
* 角色设定 * 角色设定

View File

@ -32,7 +32,7 @@ public class AiChatMessageDO extends BaseDO {
private Long id; private Long id;
/** /**
* 话编号 * 话编号
* *
* 关联 {@link AiChatConversationDO#getId()} 字段 * 关联 {@link AiChatConversationDO#getId()} 字段
*/ */

View File

@ -62,7 +62,7 @@ public class AiChatModelDO extends BaseDO {
*/ */
private Integer status; private Integer status;
// ========== 话配置 ========== // ========== 话配置 ==========
/** /**
* 温度参数 * 温度参数

View File

@ -1,7 +1,9 @@
package cn.iocoder.yudao.module.ai.dal.mysql.chat; package cn.iocoder.yudao.module.ai.dal.mysql.chat;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@ -22,4 +24,18 @@ public interface AiChatConversationMapper extends BaseMapperX<AiChatConversation
.orderByAsc(AiChatConversationDO::getCreateTime)); .orderByAsc(AiChatConversationDO::getCreateTime));
} }
default List<AiChatConversationDO> selectListByUserIdAndPinned(Long userId, boolean pinned) {
return selectList(new LambdaQueryWrapperX<AiChatConversationDO>()
.eq(AiChatConversationDO::getUserId, userId)
.eq(AiChatConversationDO::getPinned, pinned));
}
default PageResult<AiChatConversationDO> selectChatConversationPage(AiChatConversationPageReqVO pageReqVO) {
return selectPage(pageReqVO, new LambdaQueryWrapperX<AiChatConversationDO>()
.eqIfPresent(AiChatConversationDO::getUserId, pageReqVO.getUserId())
.likeIfPresent(AiChatConversationDO::getTitle, pageReqVO.getTitle())
.betweenIfPresent(AiChatConversationDO::getCreateTime, pageReqVO.getCreateTime())
.orderByDesc(AiChatConversationDO::getId));
}
} }

View File

@ -1,11 +1,18 @@
package cn.iocoder.yudao.module.ai.dal.mysql.chat; package cn.iocoder.yudao.module.ai.dal.mysql.chat;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* AI 聊天对话 Mapper * AI 聊天对话 Mapper
@ -21,4 +28,19 @@ public interface AiChatMessageMapper extends BaseMapperX<AiChatMessageDO> {
.orderByAsc(AiChatMessageDO::getId)); .orderByAsc(AiChatMessageDO::getId));
} }
default Map<Long, Integer> selectCountMapByConversationId(Collection<Long> conversationIds) {
// SQL count 查询
List<Map<String, Object>> result = selectMaps(new QueryWrapper<AiChatMessageDO>()
.select("COUNT(id) AS count, conversation_id AS conversationId")
.in("conversation_id", conversationIds)
.groupBy("conversation_id"));
if (CollUtil.isEmpty(result)) {
return Collections.emptyMap();
}
// 转换数据
return CollectionUtils.convertMap(result,
record -> MapUtil.getLong(record, "conversationId"),
record -> MapUtil.getInt(record, "count" ));
}
} }

View File

@ -1,6 +1,8 @@
package cn.iocoder.yudao.module.ai.service.chat; package cn.iocoder.yudao.module.ai.service.chat;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
@ -14,7 +16,7 @@ import java.util.List;
public interface AiChatConversationService { public interface AiChatConversationService {
/** /**
* 创建我的聊天 * 创建我的聊天
* *
* @param createReqVO 创建信息 * @param createReqVO 创建信息
* @param userId 用户编号 * @param userId 用户编号
@ -23,7 +25,7 @@ public interface AiChatConversationService {
Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId); Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId);
/** /**
* 更新我的聊天 * 更新我的聊天
* *
* @param updateReqVO 更新信息 * @param updateReqVO 更新信息
* @param userId 用户编号 * @param userId 用户编号
@ -31,23 +33,23 @@ public interface AiChatConversationService {
void updateChatConversationMy(AiChatConversationUpdateMyReqVO updateReqVO, Long userId); void updateChatConversationMy(AiChatConversationUpdateMyReqVO updateReqVO, Long userId);
/** /**
* 获得我的聊天话列表 * 获得我的聊天话列表
* *
* @param userId 用户编号 * @param userId 用户编号
* @return 聊天话列表 * @return 聊天话列表
*/ */
List<AiChatConversationDO> getChatConversationListByUserId(Long userId); List<AiChatConversationDO> getChatConversationListByUserId(Long userId);
/** /**
* 获得聊天 * 获得聊天
* *
* @param id 编号 * @param id 编号
* @return 聊天 * @return 聊天
*/ */
AiChatConversationDO getChatConversation(Long id); AiChatConversationDO getChatConversation(Long id);
/** /**
* 删除我的聊天 * 删除我的聊天
* *
* @param id 编号 * @param id 编号
* @param userId 用户编号 * @param userId 用户编号
@ -55,17 +57,20 @@ public interface AiChatConversationService {
void deleteChatConversationMy(Long id, Long userId); void deleteChatConversationMy(Long id, Long userId);
/** /**
* 校验 - 是否存在 * 校验聊天对话是否存在
* *
* @param id * @param id 编号
* @return * @return 聊天对话
*/ */
AiChatConversationDO validateExists(Long id); AiChatConversationDO validateChatConversationExists(Long id);
/** /**
* 删除 - 所有对话置顶除外 * 删除我的 + 非置顶的聊天对话
* *
* @param loginUserId * @param userId 用户编号
*/ */
void deleteMyAllExceptPinned(Long loginUserId); void deleteChatConversationMyByUnpinned(Long userId);
PageResult<AiChatConversationDO> getChatConversationPage(AiChatConversationPageReqVO pageReqVO);
} }

View File

@ -1,11 +1,13 @@
package cn.iocoder.yudao.module.ai.service.chat; package cn.iocoder.yudao.module.ai.service.chat;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
@ -22,6 +24,8 @@ import java.time.LocalDateTime;
import java.util.List; import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_MODEL_ERROR; import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_MODEL_ERROR;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS; import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
@ -69,7 +73,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
@Override @Override
public void updateChatConversationMy(AiChatConversationUpdateMyReqVO updateReqVO, Long userId) { public void updateChatConversationMy(AiChatConversationUpdateMyReqVO updateReqVO, Long userId) {
// 1.1 校验对话是否存在 // 1.1 校验对话是否存在
AiChatConversationDO conversation = validateExists(updateReqVO.getId()); AiChatConversationDO conversation = validateChatConversationExists(updateReqVO.getId());
if (ObjUtil.notEqual(conversation.getUserId(), userId)) { if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS); throw exception(CHAT_CONVERSATION_NOT_EXISTS);
} }
@ -103,7 +107,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
@Override @Override
public void deleteChatConversationMy(Long id, Long userId) { public void deleteChatConversationMy(Long id, Long userId) {
// 1. 校验对话是否存在 // 1. 校验对话是否存在
AiChatConversationDO conversation = validateExists(id); AiChatConversationDO conversation = validateChatConversationExists(id);
if (ObjUtil.notEqual(conversation.getUserId(), userId)) { if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS); throw exception(CHAT_CONVERSATION_NOT_EXISTS);
} }
@ -119,7 +123,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
throw exception(CHAT_CONVERSATION_MODEL_ERROR); throw exception(CHAT_CONVERSATION_MODEL_ERROR);
} }
public AiChatConversationDO validateExists(Long id) { public AiChatConversationDO validateChatConversationExists(Long id) {
AiChatConversationDO conversation = chatConversationMapper.selectById(id); AiChatConversationDO conversation = chatConversationMapper.selectById(id);
if (conversation == null) { if (conversation == null) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS); throw exception(CHAT_CONVERSATION_NOT_EXISTS);
@ -128,12 +132,17 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
} }
@Override @Override
public void deleteMyAllExceptPinned(Long loginUserId) { public void deleteChatConversationMyByUnpinned(Long userId) {
chatConversationMapper.delete( List<AiChatConversationDO> list = chatConversationMapper.selectListByUserIdAndPinned(userId, false);
new LambdaQueryWrapperX<AiChatConversationDO>() if (CollUtil.isEmpty(list)) {
.eq(AiChatConversationDO::getUserId, loginUserId) return;
.eq(AiChatConversationDO::getPinned, false) }
); chatConversationMapper.deleteBatchIds(convertList(list, AiChatConversationDO::getId));
}
@Override
public PageResult<AiChatConversationDO> getChatConversationPage(AiChatConversationPageReqVO pageReqVO) {
return chatConversationMapper.selectChatConversationPage(pageReqVO);
} }
} }

View File

@ -5,7 +5,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* AI 聊天消息 Service 接口 * AI 聊天消息 Service 接口
@ -32,9 +34,9 @@ public interface AiChatMessageService {
Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId); Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId);
/** /**
* 获得指定话的消息列表 * 获得指定话的消息列表
* *
* @param conversationId 话编号 * @param conversationId 话编号
* @return 消息列表 * @return 消息列表
*/ */
List<AiChatMessageDO> getChatMessageListByConversationId(Long conversationId); List<AiChatMessageDO> getChatMessageListByConversationId(Long conversationId);
@ -48,11 +50,19 @@ public interface AiChatMessageService {
void deleteChatMessage(Long id, Long userId); void deleteChatMessage(Long id, Long userId);
/** /**
* 删除指定话的消息 * 删除指定话的消息
* *
* @param conversationId 话编号 * @param conversationId 话编号
* @param userId 用户编号 * @param userId 用户编号
*/ */
void deleteChatMessageByConversationId(Long conversationId, Long userId); void deleteChatMessageByConversationId(Long conversationId, Long userId);
/**
* 获得聊天对话的消息数量 Map
*
* @param conversationIds 对话编号数组
* @return 消息数量 Map
*/
Map<Long, Integer> getChatMessageCountMap(Collection<Long> conversationIds);
} }

View File

@ -5,7 +5,6 @@ import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
@ -15,7 +14,6 @@ import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.system.api.user.AdminUserApi; import cn.iocoder.yudao.module.system.api.user.AdminUserApi;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO; import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.reactivestreams.Publisher;
import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.messages.*;
@ -34,7 +32,6 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.*; import java.util.*;
@ -115,7 +112,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Override @Override
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) { public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
// 1.1 校验对话存在 // 1.1 校验对话存在
AiChatConversationDO conversation = chatConversationService.validateExists(sendReqVO.getConversationId()); AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
if (ObjUtil.notEqual(conversation.getUserId(), userId)) { if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿异常情况的对接 throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿异常情况的对接
} }
@ -189,7 +186,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
* n 指的是 user + assistant 形成一组 * n 指的是 user + assistant 形成一组
* *
* @param messages 消息列表 * @param messages 消息列表
* @param conversation * @param conversation
* @param sendReqVO 发送请求 * @param sendReqVO 发送请求
* @return 消息上下文 * @return 消息上下文
*/ */
@ -258,4 +255,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
chatMessageMapper.deleteBatchIds(convertList(messages, AiChatMessageDO::getId)); chatMessageMapper.deleteBatchIds(convertList(messages, AiChatMessageDO::getId));
} }
@Override
public Map<Long, Integer> getChatMessageCountMap(Collection<Long> conversationIds) {
return chatMessageMapper.selectCountMapByConversationId(conversationIds);
}
} }