mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-27 01:32:03 +08:00
【优化】优化 chat event stream 模式交互,增加 add message 优先记录
This commit is contained in:
parent
aaf1599f7a
commit
5a4162cdc1
@ -36,5 +36,9 @@ public interface ErrorCodeConstants {
|
||||
ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "AI 角色不存在!");
|
||||
ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
|
||||
|
||||
// chat
|
||||
|
||||
ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_022_000_100, "AI 提问的 MessageId 不存在!");
|
||||
|
||||
|
||||
}
|
||||
|
@ -1,8 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.chat;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
@ -38,10 +37,16 @@ public class AiChatMessageController {
|
||||
// TODO @fan:要不要使用 Flux 来返回;可以使用 Flux<AiChatMessageRespVO>
|
||||
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
|
||||
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
||||
public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) {
|
||||
return chatService.chatStream(sendReqVO);
|
||||
}
|
||||
|
||||
@Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染")
|
||||
@PostMapping(value = "/add")
|
||||
public CommonResult<AiChatMessageRespVO> add(@Validated @RequestBody AiChatMessageAddReqVO req) {
|
||||
return success(chatService.add(req));
|
||||
}
|
||||
|
||||
@Operation(summary = "获得指定会话的消息列表")
|
||||
@GetMapping("/list-by-conversation-id")
|
||||
@Parameter(name = "conversationId", required = true, description = "会话编号", example = "1024")
|
||||
|
@ -0,0 +1,20 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
|
||||
@Data
|
||||
public class AiChatMessageAddReqVO {
|
||||
|
||||
@Schema(description = "聊天对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
|
||||
@NotNull(message = "聊天对话编号不能为空")
|
||||
private Long conversationId;
|
||||
|
||||
@Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "帮我写个 Java 算法")
|
||||
@NotEmpty(message = "聊天内容不能为空")
|
||||
private String content;
|
||||
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Schema(description = "管理后台 - AI 聊天消息 Add Response VO")
|
||||
@Data
|
||||
public class AiChatMessageAddRespVO {
|
||||
|
||||
@Schema(description = "用户信息")
|
||||
private AiChatMessageRespVO userMessage;
|
||||
|
||||
@Schema(description = "系统信息")
|
||||
private AiChatMessageRespVO systemMessage;
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
|
||||
@Data
|
||||
public class AiChatMessageSendStreamReqVO {
|
||||
|
||||
@Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
|
||||
@NotNull(message = "提问的 messageId 不能为空")
|
||||
private Long id;
|
||||
|
||||
}
|
@ -26,4 +26,13 @@ public interface AiChatMessageConvert {
|
||||
* @return
|
||||
*/
|
||||
List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
|
||||
|
||||
/**
|
||||
* 转换 - aiChatMessageDO
|
||||
*
|
||||
* @param aiChatMessageDO
|
||||
* @return
|
||||
*/
|
||||
AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO);
|
||||
|
||||
}
|
||||
|
@ -11,7 +11,6 @@ import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* API 聊天模型 Mapper
|
||||
|
@ -1,78 +0,0 @@
|
||||
package cn.iocoder.yudao.module.ai.service;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* ai modal
|
||||
*
|
||||
* @author fansili
|
||||
* @time 2024/4/24 19:42
|
||||
* @since 1.0
|
||||
*/
|
||||
public interface AiChatModelService {
|
||||
|
||||
/**
|
||||
* ai modal - 列表
|
||||
*
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
PageResult<AiChatModelListRespVO> list(AiChatModelListReqVO req);
|
||||
|
||||
/**
|
||||
* ai modal - 添加
|
||||
*
|
||||
* @param req
|
||||
*/
|
||||
void add(AiChatModelAddReqVO req);
|
||||
|
||||
/**
|
||||
* ai modal - 更新
|
||||
*
|
||||
* @param req
|
||||
*/
|
||||
void update(AiChatModelUpdateReqVO req);
|
||||
|
||||
/**
|
||||
* ai modal - 删除
|
||||
*
|
||||
* @param id
|
||||
*/
|
||||
void delete(Long id);
|
||||
|
||||
/**
|
||||
* 获取 - 获取 modal
|
||||
*
|
||||
* @param modalId
|
||||
* @return
|
||||
*/
|
||||
AiChatModalRespVO getChatModalOfValidate(Long modalId);
|
||||
|
||||
/**
|
||||
* 校验 - 是否存在
|
||||
*
|
||||
* @param id
|
||||
* @return
|
||||
*/
|
||||
AiChatModelDO validateExists(Long id);
|
||||
|
||||
/**
|
||||
* 校验 - 校验是否可用
|
||||
*
|
||||
* @param chatModal
|
||||
*/
|
||||
void validateAvailable(AiChatModalRespVO chatModal);
|
||||
|
||||
/**
|
||||
* 获取 - 根据 ids 批量获取
|
||||
*
|
||||
* @param modalIds
|
||||
* @return
|
||||
*/
|
||||
List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package cn.iocoder.yudao.module.ai.service;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
@ -29,7 +28,15 @@ public interface AiChatService {
|
||||
* @param sendReqVO
|
||||
* @return
|
||||
*/
|
||||
Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO sendReqVO);
|
||||
Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO sendReqVO);
|
||||
|
||||
/**
|
||||
* 添加 - message
|
||||
*
|
||||
* @param sendReqVO
|
||||
* @return
|
||||
*/
|
||||
AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO);
|
||||
|
||||
/**
|
||||
* 获取 - 获取对话 message list
|
||||
|
@ -1,132 +0,0 @@
|
||||
package cn.iocoder.yudao.module.ai.service.impl;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.extra.validation.ValidationUtil;
|
||||
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*;
|
||||
import cn.iocoder.yudao.module.ai.convert.AiChatModelConvert;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModelMapper;
|
||||
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatModelService;
|
||||
import jakarta.validation.ConstraintViolation;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* ai 模型
|
||||
*
|
||||
* @author fansili
|
||||
* @time 2024/4/24 19:42
|
||||
* @since 1.0
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@Service
|
||||
@Slf4j
|
||||
public class AiChatModalServiceImpl implements AiChatModelService {
|
||||
|
||||
private final AiChatModelMapper aiChatModelMapper;
|
||||
|
||||
@Override
|
||||
public PageResult<AiChatModelListRespVO> list(AiChatModelListReqVO req) {
|
||||
LambdaQueryWrapperX<AiChatModelDO> queryWrapperX = new LambdaQueryWrapperX<>();
|
||||
// 查询的都是未禁用的模型
|
||||
queryWrapperX.eq(AiChatModelDO::getStatus, CommonStatusEnum.ENABLE.getStatus());
|
||||
// search
|
||||
if (!StrUtil.isBlank(req.getSearch())) {
|
||||
queryWrapperX.like(AiChatModelDO::getName, req.getSearch().trim());
|
||||
}
|
||||
// 默认排序
|
||||
queryWrapperX.orderByAsc(AiChatModelDO::getSort);
|
||||
// 查询
|
||||
PageResult<AiChatModelDO> aiChatModalDOPageResult = aiChatModelMapper.selectPage(req, queryWrapperX);
|
||||
// 转换 res
|
||||
List<AiChatModelListRespVO> resList = AiChatModelConvert.INSTANCE.convertAiChatModalListRes(aiChatModalDOPageResult.getList());
|
||||
return new PageResult<>(resList, aiChatModalDOPageResult.getTotal());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(AiChatModelAddReqVO req) {
|
||||
// 校验 platform、type
|
||||
validatePlatform(req.getPlatform());
|
||||
// 转换 do
|
||||
AiChatModelDO insertChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req);
|
||||
// 设置默认属性
|
||||
insertChatModalDO.setStatus(CommonStatusEnum.ENABLE.getStatus());
|
||||
// 保存数据库
|
||||
aiChatModelMapper.insert(insertChatModalDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void update(AiChatModelUpdateReqVO req) {
|
||||
// 校验 platform
|
||||
validatePlatform(req.getPlatform());
|
||||
// 校验模型是否存在
|
||||
validateExists(req.getId());
|
||||
// 转换 updateChatModalDO
|
||||
AiChatModelDO updateChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req);
|
||||
updateChatModalDO.setId(req.getId());
|
||||
// 更新数据库
|
||||
aiChatModelMapper.updateById(updateChatModalDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(Long id) {
|
||||
// 检查 modal 是否存在
|
||||
validateExists(id);
|
||||
// 删除 delete
|
||||
aiChatModelMapper.deleteById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiChatModalRespVO getChatModalOfValidate(Long modalId) {
|
||||
// 检查 modal 是否存在
|
||||
AiChatModelDO aiChatModalDO = validateExists(modalId);
|
||||
return AiChatModelConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validateAvailable(AiChatModalRespVO chatModal) {
|
||||
// 对话模型是否可用
|
||||
if (!CommonStatusEnum.ENABLE.getStatus().equals(chatModal.getStatus())) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
|
||||
return aiChatModelMapper.selectByIds(modalIds);
|
||||
}
|
||||
|
||||
public AiChatModelDO validateExists(Long id) {
|
||||
AiChatModelDO aiChatModalDO = aiChatModelMapper.selectById(id);
|
||||
if (aiChatModalDO == null) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_EXIST);
|
||||
}
|
||||
return aiChatModalDO;
|
||||
}
|
||||
|
||||
private void validatePlatform(String platform) {
|
||||
try {
|
||||
AiPlatformEnum.valueOfPlatform(platform);
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_PLATFORM_PARAMS_INCORRECT, e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private void validateModalConfig(AiChatModalConfigVO aiChatModalConfigVO) {
|
||||
Set<ConstraintViolation<AiChatModalConfigVO>> validate = ValidationUtil.validate(aiChatModalConfigVO);
|
||||
for (ConstraintViolation<AiChatModalConfigVO> constraintViolation : validate) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, constraintViolation.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
@ -7,11 +7,15 @@ import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
||||
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
|
||||
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
|
||||
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO;
|
||||
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
@ -19,11 +23,12 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import reactor.core.publisher.Flux;
|
||||
@ -53,6 +58,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
private final AiChatConversationService chatConversationService;
|
||||
private final AiChatModelService aiChatModalService;
|
||||
private final AiChatRoleService aiChatRoleService;
|
||||
private final HttpMessageConverters messageConverters;
|
||||
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
|
||||
@ -124,7 +130,75 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
return insertChatMessageDO;
|
||||
}
|
||||
|
||||
public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO req) {
|
||||
public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO req) {
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
// 查询提问的 message
|
||||
AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId());
|
||||
if (aiChatMessageDO == null) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST);
|
||||
}
|
||||
// 查询对话
|
||||
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(aiChatMessageDO.getConversationId());
|
||||
// 获取对话模型
|
||||
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
||||
// 获取角色信息
|
||||
AiChatRoleDO aiChatRoleDO = null;
|
||||
if (conversation.getRoleId() != null) {
|
||||
aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
|
||||
}
|
||||
// 校验角色是否公开
|
||||
aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
||||
// 创建 chat 需要的 Prompt
|
||||
Prompt prompt = new Prompt(aiChatMessageDO.getContent());
|
||||
// 提前创建一个 system message
|
||||
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
||||
chatModel.getModel(), chatModel.getId(), "",
|
||||
0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||
// req.setTopK(req.getTopK());
|
||||
// req.setTopP(req.getTopP());
|
||||
// req.setTemperature(req.getTemperature());
|
||||
// 获取 client 类型
|
||||
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
||||
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
||||
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
||||
// 转换 flex AiChatMessageRespVO
|
||||
StringBuffer contentBuffer = new StringBuffer();
|
||||
AtomicInteger tokens = new AtomicInteger(0);
|
||||
return streamResponse.map(res -> {
|
||||
AiChatMessageRespVO aiChatMessageRespVO =
|
||||
AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage);
|
||||
aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
|
||||
contentBuffer.append(res.getResult().getOutput().getContent());
|
||||
tokens.incrementAndGet();
|
||||
return aiChatMessageRespVO;
|
||||
}
|
||||
).doOnComplete(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
log.info("发送完成!");
|
||||
// 保存 chat message
|
||||
aiChatMessageMapper.updateById(new AiChatMessageDO()
|
||||
.setId(systemMessage.getId())
|
||||
.setContent(contentBuffer.toString())
|
||||
.setTokens(tokens.get())
|
||||
);
|
||||
}
|
||||
}).doOnError(new Consumer<Throwable>() {
|
||||
@Override
|
||||
public void accept(Throwable throwable) {
|
||||
log.error("发送错误 {}!", throwable.getMessage());
|
||||
// 更新错误信息
|
||||
aiChatMessageMapper.updateById(new AiChatMessageDO()
|
||||
.setId(systemMessage.getId())
|
||||
.setContent(throwable.getMessage())
|
||||
.setTokens(tokens.get())
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiChatMessageRespVO add(AiChatMessageAddReqVO req) {
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
// 查询对话
|
||||
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
||||
@ -137,48 +211,10 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
}
|
||||
// 校验角色是否公开
|
||||
aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
||||
// 创建 chat 需要的 Prompt
|
||||
Prompt prompt = new Prompt(req.getContent());
|
||||
// req.setTopK(req.getTopK());
|
||||
// req.setTopP(req.getTopP());
|
||||
// req.setTemperature(req.getTemperature());
|
||||
// 保存 chat message
|
||||
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
||||
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
||||
chatModel.getModel(), chatModel.getId(), req.getContent(),
|
||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||
// 获取 client 类型
|
||||
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
||||
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
||||
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
||||
// 转换 flex AiChatMessageRespVO
|
||||
StringBuffer contentBuffer = new StringBuffer();
|
||||
AtomicInteger tokens = new AtomicInteger(0);
|
||||
return streamResponse.map(res -> {
|
||||
AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO();
|
||||
aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
|
||||
contentBuffer.append(res.getResult().getOutput().getContent());
|
||||
tokens.incrementAndGet();
|
||||
return aiChatMessageRespVO;
|
||||
}
|
||||
).doOnComplete(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
log.info("发送完成!");
|
||||
// 保存 chat message
|
||||
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
||||
chatModel.getModel(), chatModel.getId(), contentBuffer.toString(),
|
||||
tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||
}
|
||||
}).doOnError(new Consumer<Throwable>() {
|
||||
@Override
|
||||
public void accept(Throwable throwable) {
|
||||
log.error("发送错误 {}!", throwable.getMessage());
|
||||
// 保存 chat message
|
||||
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
||||
chatModel.getModel(), chatModel.getId(), throwable.getMessage(),
|
||||
tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||
}
|
||||
});
|
||||
return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -207,4 +243,5 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
public Boolean deleteMessage(Long id) {
|
||||
return aiChatMessageMapper.deleteById(id) > 0;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -6,6 +6,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatMode
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import jakarta.validation.Valid;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* AI 聊天模型 Service 接口
|
||||
*
|
||||
@ -60,4 +63,11 @@ public interface AiChatModelService {
|
||||
*/
|
||||
AiChatModelDO validateChatModel(Long id);
|
||||
|
||||
/**
|
||||
* 获取 - 根据多个 ids 获取
|
||||
*
|
||||
* @param modalIds
|
||||
* @return
|
||||
*/
|
||||
List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
|
||||
}
|
||||
|
@ -12,6 +12,9 @@ import jakarta.annotation.Resource;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*;
|
||||
|
||||
@ -89,4 +92,9 @@ public class AiChatModelServiceImpl implements AiChatModelService {
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
|
||||
return chatModelMapper.selectByIds(modalIds);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user