【优化】AI:聊天消息的 Service 实现

This commit is contained in:
YunaiV 2024-05-22 13:12:54 +08:00
parent 3b02bcf4f8
commit a482876113
13 changed files with 198 additions and 227 deletions

View File

@ -30,8 +30,8 @@ public interface ErrorCodeConstants {
ErrorCode CHAT_CONVERSATION_UPDATE_MAX_TOKENS_ERROR = new ErrorCode(1_040_003_002, "更新对话失败,最大 Token 超过上限"); ErrorCode CHAT_CONVERSATION_UPDATE_MAX_TOKENS_ERROR = new ErrorCode(1_040_003_002, "更新对话失败,最大 Token 超过上限");
ErrorCode CHAT_CONVERSATION_UPDATE_MAX_CONTEXTS_ERROR = new ErrorCode(1_040_003_002, "更新对话失败,最大 Context 超过上限"); ErrorCode CHAT_CONVERSATION_UPDATE_MAX_CONTEXTS_ERROR = new ErrorCode(1_040_003_002, "更新对话失败,最大 Context 超过上限");
// chat // ========== API 聊天消息 1-040-004-000 ==========
ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_022_000_100, "提问的 MessageId 不存在!"); ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
// midjourney // midjourney

View File

@ -1,8 +1,19 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat; package cn.iocoder.yudao.module.ai.controller.admin.chat;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.util.collection.MapUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
import cn.iocoder.yudao.module.ai.service.AiChatService; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.chat.AiChatMessageService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.system.api.user.AdminUserApi;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
@ -14,9 +25,12 @@ import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
@Tag(name = "管理后台 - 聊天消息") @Tag(name = "管理后台 - 聊天消息")
@ -26,39 +40,65 @@ import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUti
public class AiChatMessageController { public class AiChatMessageController {
@Resource @Resource
private AiChatService chatService; private AiChatMessageService chatMessageService;
@Resource
private AiChatConversationService chatConversationService;
@Resource
private AiChatRoleService chatRoleService;
@Resource
private AdminUserApi adminUserApi;
@Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢") @Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢")
@PostMapping("/send") @PostMapping("/send")
public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
return success(chatService.chat(sendReqVO)); return success(chatMessageService.sendMessage(sendReqVO));
} }
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题 @PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题
public Flux<AiChatMessageSendRespVO> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { public Flux<AiChatMessageSendRespVO> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
return chatService.sendChatMessageStream(sendReqVO, getLoginUserId()); return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
} }
@Operation(summary = "获得指定会话的消息列表") @Operation(summary = "获得指定会话的消息列表")
@GetMapping("/list-by-conversation-id") @GetMapping("/list-by-conversation-id")
@Parameter(name = "conversationId", required = true, description = "会话编号", example = "1024") @Parameter(name = "conversationId", required = true, description = "会话编号", example = "1024")
public CommonResult<List<AiChatMessageRespVO>> getMessageListByConversationId(@RequestParam("conversationId") Long conversationId) { public CommonResult<List<AiChatMessageRespVO>> getChatMessageListByConversationId(
return success(chatService.getMessageListByConversationId(conversationId)); @RequestParam("conversationId") Long conversationId) {
AiChatConversationDO conversation = chatConversationService.getChatConversation(conversationId);
if (conversation == null || ObjUtil.notEqual(conversation.getUserId(), getLoginUserId())) {
return success(Collections.emptyList());
}
List<AiChatMessageDO> messageList = chatMessageService.getChatMessageListByConversationId(conversationId);
if (CollUtil.isEmpty(messageList)) {
return success(Collections.emptyList());
}
// 拼接数据
Map<Long, AiChatRoleDO> roleMap = chatRoleService.getChatRoleMap(convertSet(messageList, AiChatMessageDO::getRoleId));
AdminUserRespDTO user = adminUserApi.getUser(getLoginUserId());
return success(BeanUtils.toBean(messageList, AiChatMessageRespVO.class, respVO -> {
MapUtils.findAndThen(roleMap, respVO.getRoleId(), role -> respVO.setRoleAvatar(role.getAvatar()));
respVO.setUserAvatar(user.getAvatar());
}));
} }
@Operation(summary = "删除消息") @Operation(summary = "删除消息")
@DeleteMapping("/delete") @DeleteMapping("/delete")
@Parameter(name = "id", required = true, description = "消息编号", example = "1024") @Parameter(name = "id", required = true, description = "消息编号", example = "1024")
public CommonResult<Boolean> deleteMessage(@RequestParam("id") Long id) { public CommonResult<Boolean> deleteChatMessage(@RequestParam("id") Long id) {
return success(chatService.deleteMessage(id)); chatMessageService.deleteMessage(id, getLoginUserId());
return success(true);
} }
@Operation(summary = "删除消息-对于对话全部消息") @Operation(summary = "删除指定会话的消息")
@DeleteMapping("/delete-by-conversation-id") @DeleteMapping("/delete-by-conversation-id")
@Parameter(name = "id", required = true, description = "消息编号", example = "1024") @Parameter(name = "conversationId", required = true, description = "会话编号", example = "1024")
public CommonResult<Boolean> deleteByConversationId(@RequestParam("conversationId") Long conversationId) { public CommonResult<Boolean> deleteChatMessageByConversationId(@RequestParam("conversationId") Long conversationId) {
return success(chatService.deleteByConversationId(conversationId)); chatMessageService.deleteChatMessageByConversationId(conversationId, getLoginUserId());
return success(true);
} }
} }

View File

@ -31,13 +31,14 @@ public class AiChatMessageSendRespVO {
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED) @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime createTime; private LocalDateTime createTime;
// ========= 扩展字段 // ========== 扩展字段 ==========
@Schema(description = "用户头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "http://xxx") @Schema(description = "用户头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/1.png")
private String userAvatar; private String userAvatar;
@Schema(description = "角色头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "http://xxx") @Schema(description = "角色头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/2.png")
private String roleAvatar; private String roleAvatar;
} }
} }

View File

@ -1,30 +0,0 @@
package cn.iocoder.yudao.module.ai.convert;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import org.mapstruct.Mapper;
import org.mapstruct.factory.Mappers;
import java.util.List;
/**
* 聊天 对话 convert
*
* @author fansili
* @time 2024/4/18 16:39
* @since 1.0
*/
@Mapper
public interface AiChatMessageConvert {
AiChatMessageConvert INSTANCE = Mappers.getMapper(AiChatMessageConvert.class);
/**
* 转换 AiChatMessageRespVO
*
* @param aiChatMessageDOList
* @return
*/
List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
}

View File

@ -1,43 +0,0 @@
package cn.iocoder.yudao.module.ai.dal.mysql;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import org.apache.ibatis.annotations.Mapper;
import org.springframework.stereotype.Repository;
import java.util.List;
/**
* message mapper
*
* @fansili
* @since v1.0
*/
@Repository
@Mapper
public interface AiChatMessageMapper extends BaseMapperX<AiChatMessageDO> {
/**
* 查询 - 根据 对话id查询
*
* @param conversationId
*/
default List<AiChatMessageDO> selectByConversationId(Long conversationId) {
return this.selectList(
new LambdaQueryWrapperX<AiChatMessageDO>()
.eq(AiChatMessageDO::getConversationId, conversationId)
.orderByAsc(AiChatMessageDO::getId)
);
}
/**
* 删除 - 根据 conversationId
*
* @param conversationId
*/
default int deleteByConversationId(Long conversationId) {
return this.delete(new LambdaQueryWrapperX<AiChatMessageDO>().eq(AiChatMessageDO::getConversationId, conversationId));
}
}

View File

@ -1,18 +1,13 @@
package cn.iocoder.yudao.module.ai.dal.mysql; package cn.iocoder.yudao.module.ai.dal.mysql;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
@ -42,7 +37,6 @@ public interface AiChatModelMapper extends BaseMapperX<AiChatModelDO> {
return this.selectList(new LambdaQueryWrapperX<AiChatModelDO>().eq(AiChatModelDO::getId, modalIds)); return this.selectList(new LambdaQueryWrapperX<AiChatModelDO>().eq(AiChatModelDO::getId, modalIds));
} }
default PageResult<AiChatModelDO> selectPage(AiChatModelPageReqVO reqVO) { default PageResult<AiChatModelDO> selectPage(AiChatModelPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiChatModelDO>() return selectPage(reqVO, new LambdaQueryWrapperX<AiChatModelDO>()
.likeIfPresent(AiChatModelDO::getName, reqVO.getName()) .likeIfPresent(AiChatModelDO::getName, reqVO.getName())

View File

@ -13,7 +13,6 @@ import org.springframework.stereotype.Repository;
* @time 2024/4/28 14:01 * @time 2024/4/28 14:01
* @since 1.0 * @since 1.0
*/ */
@Repository
@Mapper @Mapper
public interface AiImageMapper extends BaseMapperX<AiImageDO> { public interface AiImageMapper extends BaseMapperX<AiImageDO> {

View File

@ -0,0 +1,24 @@
package cn.iocoder.yudao.module.ai.dal.mysql.chat;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
/**
* AI 聊天对话 Mapper
*
* @author fansili
*/
@Mapper
public interface AiChatMessageMapper extends BaseMapperX<AiChatMessageDO> {
default List<AiChatMessageDO> selectListByConversationId(Long conversationId) {
return selectList(new LambdaQueryWrapperX<AiChatMessageDO>()
.eq(AiChatMessageDO::getConversationId, conversationId)
.orderByAsc(AiChatMessageDO::getId));
}
}

View File

@ -1,57 +0,0 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
import reactor.core.publisher.Flux;
import java.util.List;
/**
* 聊天 chat
*
* @author fansili
* @time 2024/4/14 15:55
* @since 1.0
*/
public interface AiChatService {
/**
* chat
*
* @param sendReqVO
* @return
*/
AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
/**
* 获取 - 获取对话 message list
*
* @param conversationId
* @return
*/
List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId);
/**
* 删除 - 删除message
*
* @param id
* @return
*/
Boolean deleteMessage(Long id);
/**
* 发送消息
*
* @param sendReqVO
* @param userId
* @return
*/
Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId);
/**
* 删除消息-对于对话全部消息
*
* @param conversationId
* @return
*/
Boolean deleteByConversationId(Long conversationId);
}

View File

@ -0,0 +1,57 @@
package cn.iocoder.yudao.module.ai.service.chat;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import reactor.core.publisher.Flux;
import java.util.List;
/**
* AI 聊天消息 Service 接口
*
* @author fansili
*/
public interface AiChatMessageService {
/**
* 发送消息
*
* @param sendReqVO 发送信息
* @return 发送结果
*/
AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO sendReqVO);
/**
* 发送消息
*
* @param sendReqVO 发送信息
* @param userId 用户编号
* @return 发送结果
*/
Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId);
/**
* 获得指定会话的消息列表
*
* @param conversationId 会话编号
* @return 消息列表
*/
List<AiChatMessageDO> getChatMessageListByConversationId(Long conversationId);
/**
* 删除消息
*
* @param id 消息编号
* @param userId 用户编号
*/
void deleteMessage(Long id, Long userId);
/**
* 删除指定会话的消息
*
* @param conversationId 会话编号
* @param userId 用户编号
*/
void deleteChatMessageByConversationId(Long conversationId, Long userId);
}

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.service.impl; package cn.iocoder.yudao.module.ai.service.chat;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
@ -20,13 +20,10 @@ import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.service.AiChatService;
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -37,21 +34,20 @@ import reactor.core.scheduler.Schedulers;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS; import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
/** /**
* 聊天 service * AI 聊天消息 Service 实现类
* *
* @author fansili * @author fansili
* @time 2024/4/14 15:55
* @since 1.0
*/ */
@Slf4j
@Service @Service
public class AiChatServiceImpl implements AiChatService { @Slf4j
public class AiChatMessageServiceImpl implements AiChatMessageService {
@Resource @Resource
private AiChatMessageMapper chatMessageMapper; private AiChatMessageMapper chatMessageMapper;
@ -72,7 +68,7 @@ public class AiChatServiceImpl implements AiChatService {
private AdminUserApi adminUserApi; private AdminUserApi adminUserApi;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) { public AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO req) {
return null; // TODO 芋艿一起改 return null; // TODO 芋艿一起改
// Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); // Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// // 查询对话 // // 查询对话
@ -117,10 +113,13 @@ public class AiChatServiceImpl implements AiChatService {
if (ObjUtil.notEqual(conversation.getUserId(), userId)) { if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿异常情况的对接 throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿异常情况的对接
} }
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectByConversationId(conversation.getId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
// 1.3 获取用户头像角色头像
AdminUserRespDTO user = adminUserApi.getUser(SecurityFrameworkUtils.getLoginUserId());
AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null;
// 2. 插入 user 发送消息 // 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -136,19 +135,17 @@ public class AiChatServiceImpl implements AiChatService {
// 3.3 流式返回 // 3.3 流式返回
// 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 // 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
// 3.4 获取用户头像角色头像
AdminUserRespDTO user = adminUserApi.getUser(SecurityFrameworkUtils.getLoginUserId());
AiChatRoleDO chatRole = chatRoleService.getChatRole(assistantMessage.getRoleId());
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.publishOn(Schedulers.single()).map(chunk -> { return streamResponse.publishOn(Schedulers.single()).map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 情况 newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 情况
contentBuffer.append(newContent); contentBuffer.append(newContent);
// 响应结果 // 响应结果
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class).setUserAvatar(user.getAvatar())) AiChatMessageSendRespVO.Message send = BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class,
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent).setRoleAvatar(chatRole == null ? null : chatRole.getAvatar())); o -> o.setUserAvatar(user.getAvatar()));
AiChatMessageSendRespVO.Message receive = BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class,
o -> o.setRoleAvatar(role != null ? role.getAvatar() : null)).setContent(newContent);
return new AiChatMessageSendRespVO().setSend(send).setReceive(receive);
}).doOnComplete(() -> { }).doOnComplete(() -> {
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString()));
}).doOnError(throwable -> { }).doOnError(throwable -> {
@ -157,11 +154,6 @@ public class AiChatServiceImpl implements AiChatService {
}); });
} }
@Override
public Boolean deleteByConversationId(Long conversationId) {
return chatMessageMapper.deleteByConversationId(conversationId) > 0;
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
// 1. 构建 Prompt Message 列表 // 1. 构建 Prompt Message 列表
@ -174,12 +166,11 @@ public class AiChatServiceImpl implements AiChatService {
// 1.3 user message 新发送消息 // 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象 TODO 芋艿临时注释掉等文心一言兼容了 // 2. 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = clientFactory.buildChatOptions(platform, model.getModel(), ChatOptions chatOptions = clientFactory.buildChatOptions(platform, model.getModel(),
conversation.getTemperature(), conversation.getMaxTokens()); conversation.getTemperature(), conversation.getMaxTokens());
return new Prompt(chatMessages, chatOptions); return new Prompt(chatMessages, chatOptions);
// return new Prompt(chatMessages);
} }
/** /**
@ -231,42 +222,30 @@ public class AiChatServiceImpl implements AiChatService {
} }
@Override @Override
public List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId) { public List<AiChatMessageDO> getChatMessageListByConversationId(Long conversationId) {
// 校验对话是否存在 return chatMessageMapper.selectListByConversationId(conversationId);
chatConversationService.validateExists(conversationId);
// 获取对话所有 message
List<AiChatMessageDO> aiChatMessageDOList = chatMessageMapper.selectByConversationId(conversationId);
// 获取模型信息
Set<Long> roleIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getRoleId).collect(Collectors.toSet());
List<AiChatRoleDO> roleList;
if (!CollUtil.isEmpty(roleIds)) {
roleList = chatRoleService.getChatRoles(roleIds);
} else {
roleList = Collections.emptyList();
}
Map<Long, AiChatRoleDO> roleMap = roleList.stream().collect(Collectors.toMap(AiChatRoleDO::getId, o -> o));
// 转换 AiChatMessageRespVO
List<AiChatMessageRespVO> aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);
// 获取用户信息
AdminUserRespDTO user = adminUserApi.getUser(SecurityFrameworkUtils.getLoginUserId());
// 设置用户头像 模型头像
return aiChatMessageRespList.stream().map(item -> {
// 设置 role 头像
if (roleMap.containsKey(item.getRoleId())) {
AiChatRoleDO role = roleMap.get(item.getRoleId());
item.setRoleAvatar(role.getAvatar());
}
// 设置 user 头像
if (user != null) {
item.setUserAvatar(user.getAvatar());
}
return item;
}).collect(Collectors.toList());
} }
@Override @Override
public Boolean deleteMessage(Long id) { public void deleteMessage(Long id, Long userId) {
return chatMessageMapper.deleteById(id) > 0; // 1. 校验消息存在
AiChatMessageDO message = chatMessageMapper.selectById(id);
if (message == null || ObjUtil.notEqual(message.getUserId(), userId)) {
throw exception(AI_CHAT_MESSAGE_NOT_EXIST);
}
// 2. 执行删除
chatMessageMapper.deleteById(id);
}
@Override
public void deleteChatMessageByConversationId(Long conversationId, Long userId) {
// 1. 校验消息存在
List<AiChatMessageDO> messages = chatMessageMapper.selectListByConversationId(conversationId);
if (CollUtil.isEmpty(messages) || ObjUtil.notEqual(messages.get(0).getUserId(), userId)) {
throw exception(AI_CHAT_MESSAGE_NOT_EXIST);
}
// 2. 执行删除
chatMessageMapper.deleteBatchIds(convertList(messages, AiChatMessageDO::getId));
} }
} }

View File

@ -4,12 +4,14 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleSaveMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleSaveMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleSaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Map;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
/** /**
* AI 聊天角色 Service 接口 * AI 聊天角色 Service 接口
@ -74,12 +76,16 @@ public interface AiChatRoleService {
AiChatRoleDO getChatRole(Long id); AiChatRoleDO getChatRole(Long id);
/** /**
* 获得聊天角色 - 根据 ids * 获得聊天角色列表
* *
* @param roleIds * @param ids 编号数组
* @return * @return 聊天角色列表
*/ */
List<AiChatRoleDO> getChatRoles(Set<Long> roleIds); List<AiChatRoleDO> getChatRoleList(Collection<Long> ids);
default Map<Long, AiChatRoleDO> getChatRoleMap(Collection<Long> ids) {
return convertMap(getChatRoleList(ids), AiChatRoleDO::getId);
}
/** /**
* 校验聊天角色是否合法 * 校验聊天角色是否合法

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.service.model; package cn.iocoder.yudao.module.ai.service.model;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@ -14,12 +15,9 @@ import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@ -107,8 +105,11 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
} }
@Override @Override
public List<AiChatRoleDO> getChatRoles(Set<Long> roleIds) { public List<AiChatRoleDO> getChatRoleList(Collection<Long> ids) {
return chatRoleMapper.selectBatchIds(roleIds); if (CollUtil.isEmpty(ids)) {
return Collections.emptyList();
}
return chatRoleMapper.selectBatchIds(ids);
} }
@Override @Override