【新增】AI:流式发送消息的微调,统一成单接口

This commit is contained in:
YunaiV 2024-05-15 23:06:18 +08:00
parent 20657ccaf3
commit b31e919d52
9 changed files with 124 additions and 136 deletions

View File

@ -10,13 +10,13 @@ Authorization: {{token}}
} }
### chat call ### 发送消息(流式)
POST {{baseUrl}}/admin-api/ai/chat/message/send-stream POST {{baseUrl}}/ai/chat/message/send-stream
Content-Type: application/json Content-Type: application/json
Authorization: {{token}} Authorization: {{token}}
{ {
"conversationId": "1781604279872581649", "conversationId": "1781604279872581651",
"content": "苹果是什么颜色?" "content": "苹果是什么颜色?"
} }

View File

@ -17,6 +17,7 @@ import reactor.core.publisher.Flux;
import java.util.List; import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
@Tag(name = "管理后台 - 聊天消息") @Tag(name = "管理后台 - 聊天消息")
@RestController @RestController
@ -36,14 +37,8 @@ public class AiChatMessageController {
// Streaming "send message" endpoint, shown as a side-by-side diff: on two-column lines the removed
// text comes first and the added text follows; single-column lines below belong to the removed code.
// Added version: accepts an AiChatMessageSendReqVO, returns Flux<AiChatMessageSendRespVO>, and passes
// getLoginUserId() to chatService.sendChatMessageStream(...). The separate "/add" endpoint is removed.
// NOTE(review): the endpoint is annotated @PermitAll, so getLoginUserId() may be null for an
// unauthenticated caller — confirm the service handles a null userId as intended.
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@PermitAll // Works around the final SSE response being intercepted with Access Denied @PermitAll // Works around the final SSE response being intercepted with Access Denied
public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) { public Flux<AiChatMessageSendRespVO> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
return chatService.chatStream(sendReqVO); return chatService.sendChatMessageStream(sendReqVO, getLoginUserId());
}
@Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染")
@PostMapping(value = "/add")
public CommonResult<AiChatMessageRespVO> add(@Validated @RequestBody AiChatMessageAddReqVO req) {
return success(chatService.add(req));
} }
@Operation(summary = "获得指定会话的消息列表") @Operation(summary = "获得指定会话的消息列表")

View File

@ -44,4 +44,5 @@ public class AiChatMessageRespVO {
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
private LocalDateTime createTime; private LocalDateTime createTime;
} }

View File

@ -0,0 +1,36 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
/**
 * Response VO for sending a chat message: carries the message that was sent together with the
 * message that was received.
 */
@Schema(description = "管理后台 - AI 聊天消息发送 Response VO")
@Data
public class AiChatMessageSendRespVO {
// The message that was sent.
@Schema(description = "发送消息", requiredMode = Schema.RequiredMode.REQUIRED)
private Message send;
// The message that was received.
@Schema(description = "接收消息", requiredMode = Schema.RequiredMode.REQUIRED)
private Message receive;
/**
 * A single chat message as exposed in this response: id, type, content and creation time.
 */
@Schema(description = "消息")
@Data
public static class Message {
// Message id.
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
private Long id;
// Message type, kept as a String.
@Schema(description = "消息类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "role")
private String type; // see the MessageType enum
// Text content of the message.
@Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊")
private String content;
// Creation time of the message.
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
private LocalDateTime createTime;
}
}

View File

@ -1,16 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
/**
 * Request VO for the streaming send endpoint (this file is deleted by the commit).
 * It identifies the question message by id only.
 */
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
@Data
public class AiChatMessageSendStreamReqVO {
// Id of the question message; validation rejects null.
@Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
@NotNull(message = "提问的 messageId 不能为空")
private Long id;
}

View File

@ -27,12 +27,4 @@ public interface AiChatMessageConvert {
*/ */
List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList); List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
/**
 * Converts a single AiChatMessageDO into an AiChatMessageRespVO (declaration removed by this commit).
 *
 * @param aiChatMessageDO the chat message data object to convert
 * @return the corresponding response VO
 */
AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO);
} }

View File

@ -22,22 +22,6 @@ public interface AiChatService {
*/ */
AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO); AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
/**
 * Streaming chat (declaration removed by this commit).
 *
 * @param sendReqVO the streaming send request
 * @return a Flux of AiChatMessageRespVO
 */
Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO sendReqVO);
/**
 * Adds a message (declaration removed by this commit).
 *
 * @param sendReqVO the add-message request
 * @return the added message as an AiChatMessageRespVO
 */
AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO);
/** /**
* 获取 - 获取对话 message list * 获取 - 获取对话 message list
* *
@ -54,4 +38,13 @@ public interface AiChatService {
*/ */
Boolean deleteMessage(Long id); Boolean deleteMessage(Long id);
/**
 * Sends a chat message and returns the result as a stream.
 *
 * @param sendReqVO the send request
 * @param userId id of the user on whose behalf the message is sent
 * @return a Flux of AiChatMessageSendRespVO
 */
Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId);
} }

View File

@ -1,27 +1,24 @@
package cn.iocoder.yudao.module.ai.service.impl; package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.exceptions.ExceptionUtil; import cn.hutool.core.exceptions.ExceptionUtil;
import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO;
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert; import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
@ -33,13 +30,16 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.time.LocalDateTime;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
/** /**
* 聊天 service * 聊天 service
* *
@ -52,11 +52,11 @@ import java.util.stream.Collectors;
@AllArgsConstructor @AllArgsConstructor
public class AiChatServiceImpl implements AiChatService { public class AiChatServiceImpl implements AiChatService {
private final AiChatClientFactory aiChatClientFactory; private final AiChatClientFactory chatClientFactory;
private final AiChatMessageMapper aiChatMessageMapper; private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatConversationService chatConversationService; private final AiChatConversationService chatConversationService;
private final AiChatModelService aiChatModalService; private final AiChatModelService chatModalService;
private final AiChatRoleService chatRoleService; private final AiChatRoleService chatRoleService;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
@ -65,7 +65,7 @@ public class AiChatServiceImpl implements AiChatService {
// 查询对话 // 查询对话
AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
// 获取对话模型 // 获取对话模型
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
// 获取角色信息 // 获取角色信息
AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
// 获取 client 类型 // 获取 client 类型
@ -84,7 +84,7 @@ public class AiChatServiceImpl implements AiChatService {
// req.setTopP(req.getTopP()); // req.setTopP(req.getTopP());
// req.setTemperature(req.getTemperature()); // req.setTemperature(req.getTemperature());
// 发送 call 调用 // 发送 call 调用
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum); ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
ChatResponse call = chatClient.call(prompt); ChatResponse call = chatClient.call(prompt);
content = call.getResult().getOutput().getContent(); content = call.getResult().getOutput().getContent();
tokens = call.getResults().size(); tokens = call.getResults().size();
@ -113,55 +113,56 @@ public class AiChatServiceImpl implements AiChatService {
.setModelId(modelId) .setModelId(modelId)
.setContent(content) .setContent(content)
.setTokens(tokens) .setTokens(tokens)
.setTemperature(temperature) .setTemperature(temperature)
.setMaxTokens(maxTokens) .setMaxTokens(maxTokens)
.setMaxContexts(maxContexts); .setMaxContexts(maxContexts);
insertChatMessageDO.setCreateTime(LocalDateTime.now());
// 增加 chat message 记录 // 增加 chat message 记录
aiChatMessageMapper.insert(insertChatMessageDO); aiChatMessageMapper.insert(insertChatMessageDO);
return insertChatMessageDO; return insertChatMessageDO;
} }
/*
 * Streaming send-message implementation, shown as a side-by-side diff: on two-column lines the removed
 * chatStream(...) text comes first and the added sendChatMessageStream(...) text follows.
 *
 * Added version, as far as the visible lines show:
 *  - validates the conversation and throws CHAT_CONVERSATION_NOT_EXISTS when its userId differs from userId;
 *  - validates the model, then resolves the platform and the streaming chat client;
 *  - inserts a USER message holding the request content, and a SYSTEM message seeded with
 *    conversation.getSystemMessage();
 *  - builds the Prompt from the request content only and maps every streamed ChatResponse to an
 *    AiChatMessageSendRespVO (send = stored user message, receive = the current chunk's content under the
 *    system message id);
 *  - on completion writes the accumulated content and the counter value via updateById (the setId line
 *    falls inside the hunk gap — presumably the system message id); on error writes
 *    throwable.getMessage() as the system message content.
 *
 * NOTE(review): both insertChatMessage calls in the added version pass conversation.getId() in the
 * argument slot where the removed version passed chatModel.getId(); if that slot is the model id,
 * model.getId() looks intended — confirm against insertChatMessage's parameter order.
 * NOTE(review): `tokens` is incremented once per streamed chunk, so it is a chunk count rather than a
 * model token count (the code carries its own TODO about this).
 */
public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO req) { @Override
Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); public Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
// Look up the question message // 1.1 Verify the conversation exists
AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId()); AiChatConversationDO conversation = chatConversationService.validateExists(sendReqVO.getConversationId());
if (aiChatMessageDO == null) { if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST); throw exception(CHAT_CONVERSATION_NOT_EXISTS);
} }
// Look up the conversation // 1.2 Validate the model
AiChatConversationDO conversation = chatConversationService.validateExists(aiChatMessageDO.getConversationId()); AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
// Get the conversation's model AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
// Get the role information
AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; // 2. Insert the user's sent message; TODO token calculation
// Create the Prompt needed for chat AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(),
Prompt prompt = new Prompt(aiChatMessageDO.getContent()); conversation.getModel(), conversation.getId(), sendReqVO.getContent(),
// Create a system message in advance null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), "", // 3.1 Insert the system (received) message
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(),
conversation.getModel(), conversation.getId(), conversation.getSystemMessage(),
0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); 0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// req.setTopK(req.getTopK()); // 3.2 Create the Prompt needed for chat
// req.setTopP(req.getTopP()); // TODO message context
// req.setTemperature(req.getTemperature()); Prompt prompt = new Prompt(sendReqVO.getContent());
// Get the client type // ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build()
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); // 3.3 Convert the Flux to AiChatMessageRespVO
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
// Convert the Flux to AiChatMessageRespVO
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
AtomicInteger tokens = new AtomicInteger(0); AtomicInteger tokens = new AtomicInteger(0); // TODO token calculation is incorrect
return streamResponse.map(res -> { return streamResponse.map(res -> {
AiChatMessageRespVO aiChatMessageRespVO =
AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage);
aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
contentBuffer.append(res.getResult().getOutput().getContent()); contentBuffer.append(res.getResult().getOutput().getContent());
tokens.incrementAndGet(); tokens.incrementAndGet();
return aiChatMessageRespVO;
} AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId())
).doOnComplete(new Runnable() { .setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime())
@Override .setContent(sendReqVO.getContent());
public void run() { AiChatMessageSendRespVO.Message receive = new AiChatMessageSendRespVO.Message().setId(systemMessage.getId())
.setType(MessageType.SYSTEM.getValue()).setCreateTime(systemMessage.getCreateTime())
.setContent(res.getResult().getOutput().getContent());
return new AiChatMessageSendRespVO().setSend(send).setReceive(receive);
}).doOnComplete(() -> {
log.info("发送完成!"); log.info("发送完成!");
// Save the chat message // Save the chat message
aiChatMessageMapper.updateById(new AiChatMessageDO() aiChatMessageMapper.updateById(new AiChatMessageDO()
@ -169,34 +170,17 @@ public class AiChatServiceImpl implements AiChatService {
.setContent(contentBuffer.toString()) .setContent(contentBuffer.toString())
.setTokens(tokens.get()) .setTokens(tokens.get())
); );
} }).doOnError(throwable -> {
}).doOnError(new Consumer<Throwable>() {
@Override
public void accept(Throwable throwable) {
log.error("发送错误 {}!", throwable.getMessage()); log.error("发送错误 {}!", throwable.getMessage());
// Update the error information // Update the error information; TODO it seems the exception should not be written here
aiChatMessageMapper.updateById(new AiChatMessageDO() aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId()) .setId(systemMessage.getId())
.setContent(throwable.getMessage()) .setContent(throwable.getMessage())
.setTokens(tokens.get()) .setTokens(tokens.get())
); );
}
}); });
}
/**
 * Stores a USER message for the request's conversation and returns it as a response VO
 * (method removed by this commit).
 *
 * @param req the add-message request; its conversationId and content are read
 * @return the inserted user message converted to AiChatMessageRespVO
 */
@Override
public AiChatMessageRespVO add(AiChatMessageAddReqVO req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// Look up the conversation
AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
// Get the conversation's model
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
// tokens is passed as null; temperature, maxTokens and maxContexts are copied from the conversation
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage);
}
@Override @Override
public List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId) { public List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId) {
// 校验对话是否存在 // 校验对话是否存在
@ -205,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService {
List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId); List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId);
// 获取模型信息 // 获取模型信息
Set<Long> modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet()); Set<Long> modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet());
List<AiChatModelDO> modalList = aiChatModalService.getModalByIds(modalIds); List<AiChatModelDO> modalList = chatModalService.getModalByIds(modalIds);
Map<Long, AiChatModelDO> modalIdMap = modalList.stream().collect(Collectors.toMap(AiChatModelDO::getId, o -> o)); Map<Long, AiChatModelDO> modalIdMap = modalList.stream().collect(Collectors.toMap(AiChatModelDO::getId, o -> o));
// 转换 AiChatMessageRespVO // 转换 AiChatMessageRespVO
List<AiChatMessageRespVO> aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); List<AiChatMessageRespVO> aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);

View File

@ -94,7 +94,10 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
String a = ";"; String a = ";";
} }
}); });
return response.map(res -> new ChatResponse(List.of(new Generation(res.getResult())))); return response.map(res -> {
// TODO @fan这里缺少了 usage 的封装
return new ChatResponse(List.of(new Generation(res.getResult())));
});
} }
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) { private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {