【优化】优化 chat event stream 模式交互,增加 add message 优先记录

This commit is contained in:
cherishsince 2024-05-12 19:04:58 +08:00
parent aaf1599f7a
commit 5a4162cdc1
13 changed files with 181 additions and 259 deletions

View File

@ -36,5 +36,9 @@ public interface ErrorCodeConstants {
ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "AI 角色不存在!");
ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
// chat
ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_022_000_100, "AI 提问的 MessageId 不存在!");
}

View File

@ -1,8 +1,7 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
import cn.iocoder.yudao.module.ai.service.AiChatService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
@ -38,10 +37,16 @@ public class AiChatMessageController {
// TODO @fan要不要使用 Flux 来返回可以使用 Flux<AiChatMessageRespVO>
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) {
return chatService.chatStream(sendReqVO);
}
@Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染")
@PostMapping(value = "/add")
public CommonResult<AiChatMessageRespVO> add(@Validated @RequestBody AiChatMessageAddReqVO req) {
return success(chatService.add(req));
}
@Operation(summary = "获得指定会话的消息列表")
@GetMapping("/list-by-conversation-id")
@Parameter(name = "conversationId", required = true, description = "会话编号", example = "1024")

View File

@ -0,0 +1,20 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
@Data
public class AiChatMessageAddReqVO {
@Schema(description = "聊天对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
@NotNull(message = "聊天对话编号不能为空")
private Long conversationId;
@Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "帮我写个 Java 算法")
@NotEmpty(message = "聊天内容不能为空")
private String content;
}

View File

@ -0,0 +1,17 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
@Schema(description = "管理后台 - AI 聊天消息 Add Response VO")
@Data
public class AiChatMessageAddRespVO {
@Schema(description = "用户信息")
private AiChatMessageRespVO userMessage;
@Schema(description = "系统信息")
private AiChatMessageRespVO systemMessage;
}

View File

@ -0,0 +1,16 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
@Data
public class AiChatMessageSendStreamReqVO {
@Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
@NotNull(message = "提问的 messageId 不能为空")
private Long id;
}

View File

@ -26,4 +26,13 @@ public interface AiChatMessageConvert {
* @return
*/
List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
/**
* 转换 - aiChatMessageDO
*
* @param aiChatMessageDO
* @return
*/
AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO);
}

View File

@ -11,7 +11,6 @@ import org.apache.ibatis.annotations.Mapper;
import java.util.Collection;
import java.util.List;
import java.util.Set;
/**
* API 聊天模型 Mapper

View File

@ -1,78 +0,0 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import java.util.List;
import java.util.Set;
/**
* ai modal
*
* @author fansili
* @time 2024/4/24 19:42
* @since 1.0
*/
public interface AiChatModelService {
/**
* ai modal - 列表
*
* @param req
* @return
*/
PageResult<AiChatModelListRespVO> list(AiChatModelListReqVO req);
/**
* ai modal - 添加
*
* @param req
*/
void add(AiChatModelAddReqVO req);
/**
* ai modal - 更新
*
* @param req
*/
void update(AiChatModelUpdateReqVO req);
/**
* ai modal - 删除
*
* @param id
*/
void delete(Long id);
/**
* 获取 - 获取 modal
*
* @param modalId
* @return
*/
AiChatModalRespVO getChatModalOfValidate(Long modalId);
/**
* 校验 - 是否存在
*
* @param id
* @return
*/
AiChatModelDO validateExists(Long id);
/**
* 校验 - 校验是否可用
*
* @param chatModal
*/
void validateAvailable(AiChatModalRespVO chatModal);
/**
* 获取 - 根据 ids 批量获取
*
* @param modalIds
* @return
*/
List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
}

View File

@ -1,7 +1,6 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
import reactor.core.publisher.Flux;
import java.util.List;
@ -29,7 +28,15 @@ public interface AiChatService {
* @param sendReqVO
* @return
*/
Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO sendReqVO);
Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO sendReqVO);
/**
* 添加 - message
*
* @param sendReqVO
* @return
*/
AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO);
/**
* 获取 - 获取对话 message list

View File

@ -1,132 +0,0 @@
package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.validation.ValidationUtil;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*;
import cn.iocoder.yudao.module.ai.convert.AiChatModelConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModelMapper;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
import cn.iocoder.yudao.module.ai.service.AiChatModelService;
import jakarta.validation.ConstraintViolation;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Set;
/**
* ai 模型
*
* @author fansili
* @time 2024/4/24 19:42
* @since 1.0
*/
@AllArgsConstructor
@Service
@Slf4j
public class AiChatModalServiceImpl implements AiChatModelService {
private final AiChatModelMapper aiChatModelMapper;
@Override
public PageResult<AiChatModelListRespVO> list(AiChatModelListReqVO req) {
LambdaQueryWrapperX<AiChatModelDO> queryWrapperX = new LambdaQueryWrapperX<>();
// 查询的都是未禁用的模型
queryWrapperX.eq(AiChatModelDO::getStatus, CommonStatusEnum.ENABLE.getStatus());
// search
if (!StrUtil.isBlank(req.getSearch())) {
queryWrapperX.like(AiChatModelDO::getName, req.getSearch().trim());
}
// 默认排序
queryWrapperX.orderByAsc(AiChatModelDO::getSort);
// 查询
PageResult<AiChatModelDO> aiChatModalDOPageResult = aiChatModelMapper.selectPage(req, queryWrapperX);
// 转换 res
List<AiChatModelListRespVO> resList = AiChatModelConvert.INSTANCE.convertAiChatModalListRes(aiChatModalDOPageResult.getList());
return new PageResult<>(resList, aiChatModalDOPageResult.getTotal());
}
@Override
public void add(AiChatModelAddReqVO req) {
// 校验 platformtype
validatePlatform(req.getPlatform());
// 转换 do
AiChatModelDO insertChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req);
// 设置默认属性
insertChatModalDO.setStatus(CommonStatusEnum.ENABLE.getStatus());
// 保存数据库
aiChatModelMapper.insert(insertChatModalDO);
}
@Override
public void update(AiChatModelUpdateReqVO req) {
// 校验 platform
validatePlatform(req.getPlatform());
// 校验模型是否存在
validateExists(req.getId());
// 转换 updateChatModalDO
AiChatModelDO updateChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req);
updateChatModalDO.setId(req.getId());
// 更新数据库
aiChatModelMapper.updateById(updateChatModalDO);
}
@Override
public void delete(Long id) {
// 检查 modal 是否存在
validateExists(id);
// 删除 delete
aiChatModelMapper.deleteById(id);
}
@Override
public AiChatModalRespVO getChatModalOfValidate(Long modalId) {
// 检查 modal 是否存在
AiChatModelDO aiChatModalDO = validateExists(modalId);
return AiChatModelConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
}
@Override
public void validateAvailable(AiChatModalRespVO chatModal) {
// 对话模型是否可用
if (!CommonStatusEnum.ENABLE.getStatus().equals(chatModal.getStatus())) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
}
}
@Override
public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
return aiChatModelMapper.selectByIds(modalIds);
}
public AiChatModelDO validateExists(Long id) {
AiChatModelDO aiChatModalDO = aiChatModelMapper.selectById(id);
if (aiChatModalDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_EXIST);
}
return aiChatModalDO;
}
private void validatePlatform(String platform) {
try {
AiPlatformEnum.valueOfPlatform(platform);
} catch (IllegalArgumentException e) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_PLATFORM_PARAMS_INCORRECT, e.getMessage());
}
}
private void validateModalConfig(AiChatModalConfigVO aiChatModalConfigVO) {
Set<ConstraintViolation<AiChatModalConfigVO>> validate = ValidationUtil.validate(aiChatModalConfigVO);
for (ConstraintViolation<AiChatModalConfigVO> constraintViolation : validate) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, constraintViolation.getMessage());
}
}
}

View File

@ -7,11 +7,15 @@ import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO;
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
@ -19,11 +23,12 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.AiChatService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
@ -53,6 +58,7 @@ public class AiChatServiceImpl implements AiChatService {
private final AiChatConversationService chatConversationService;
private final AiChatModelService aiChatModalService;
private final AiChatRoleService aiChatRoleService;
private final HttpMessageConverters messageConverters;
@Transactional(rollbackFor = Exception.class)
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
@ -124,7 +130,75 @@ public class AiChatServiceImpl implements AiChatService {
return insertChatMessageDO;
}
public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO req) {
public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询提问的 message
AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId());
if (aiChatMessageDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST);
}
// 查询对话
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(aiChatMessageDO.getConversationId());
// 获取对话模型
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
// 获取角色信息
AiChatRoleDO aiChatRoleDO = null;
if (conversation.getRoleId() != null) {
aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
}
// 校验角色是否公开
aiChatRoleService.validateIsPublic(aiChatRoleDO);
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(aiChatMessageDO.getContent());
// 提前创建一个 system message
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), "",
0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// req.setTopK(req.getTopK());
// req.setTopP(req.getTopP());
// req.setTemperature(req.getTemperature());
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
// 转换 flex AiChatMessageRespVO
StringBuffer contentBuffer = new StringBuffer();
AtomicInteger tokens = new AtomicInteger(0);
return streamResponse.map(res -> {
AiChatMessageRespVO aiChatMessageRespVO =
AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage);
aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
contentBuffer.append(res.getResult().getOutput().getContent());
tokens.incrementAndGet();
return aiChatMessageRespVO;
}
).doOnComplete(new Runnable() {
@Override
public void run() {
log.info("发送完成!");
// 保存 chat message
aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId())
.setContent(contentBuffer.toString())
.setTokens(tokens.get())
);
}
}).doOnError(new Consumer<Throwable>() {
@Override
public void accept(Throwable throwable) {
log.error("发送错误 {}!", throwable.getMessage());
// 更新错误信息
aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId())
.setContent(throwable.getMessage())
.setTokens(tokens.get())
);
}
});
}
@Override
public AiChatMessageRespVO add(AiChatMessageAddReqVO req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询对话
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
@ -137,48 +211,10 @@ public class AiChatServiceImpl implements AiChatService {
}
// 校验角色是否公开
aiChatRoleService.validateIsPublic(aiChatRoleDO);
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getContent());
// req.setTopK(req.getTopK());
// req.setTopP(req.getTopP());
// req.setTemperature(req.getTemperature());
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
// 转换 flex AiChatMessageRespVO
StringBuffer contentBuffer = new StringBuffer();
AtomicInteger tokens = new AtomicInteger(0);
return streamResponse.map(res -> {
AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO();
aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
contentBuffer.append(res.getResult().getOutput().getContent());
tokens.incrementAndGet();
return aiChatMessageRespVO;
}
).doOnComplete(new Runnable() {
@Override
public void run() {
log.info("发送完成!");
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), contentBuffer.toString(),
tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
}
}).doOnError(new Consumer<Throwable>() {
@Override
public void accept(Throwable throwable) {
log.error("发送错误 {}!", throwable.getMessage());
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), throwable.getMessage(),
tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
}
});
return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage);
}
@Override
@ -207,4 +243,5 @@ public class AiChatServiceImpl implements AiChatService {
public Boolean deleteMessage(Long id) {
return aiChatMessageMapper.deleteById(id) > 0;
}
}

View File

@ -6,6 +6,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatMode
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import jakarta.validation.Valid;
import java.util.List;
import java.util.Set;
/**
* AI 聊天模型 Service 接口
*
@ -60,4 +63,11 @@ public interface AiChatModelService {
*/
AiChatModelDO validateChatModel(Long id);
/**
* 获取 - 根据多个 ids 获取
*
* @param modalIds
* @return
*/
List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
}

View File

@ -12,6 +12,9 @@ import jakarta.annotation.Resource;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import java.util.List;
import java.util.Set;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*;
@ -89,4 +92,9 @@ public class AiChatModelServiceImpl implements AiChatModelService {
return model;
}
@Override
public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
return chatModelMapper.selectByIds(modalIds);
}
}