stream 保存聊天记录

This commit is contained in:
cherishsince 2024-04-18 16:09:45 +08:00
parent 905ce773e9
commit f84d25d3b7
3 changed files with 92 additions and 47 deletions

View File

@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.controller;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.service.ChatService;
import cn.iocoder.yudao.module.ai.vo.ChatReq;
@ -38,7 +37,6 @@ import java.util.function.Consumer;
public class ChatController {
@Autowired
private AiClient aiClient;
private final ChatService chatService;
@Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!")
@ -52,30 +50,7 @@ public class ChatController {
@GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) {
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
Flux<ChatResponse> streamResponse = chatService.chatStream(req);
streamResponse.subscribe(
new Consumer<ChatResponse>() {
@Override
public void accept(ChatResponse chatResponse) {
String content = chatResponse.getResults().get(0).getOutput().getContent();
try {
sseEmitter.send(content, MediaType.APPLICATION_JSON);
} catch (IOException e) {
log.error("发送异常{}", ExceptionUtil.getMessage(e));
// 如果不是因为关闭而抛出异常则重新连接
sseEmitter.completeWithError(e);
}
}
},
error -> {
//
log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
},
() -> {
log.info("发送完成!");
sseEmitter.complete();
}
);
chatService.chatStream(req, sseEmitter);
return sseEmitter;
}
}

View File

@ -1,9 +1,7 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.vo.ChatReq;
import reactor.core.publisher.Flux;
/**
* 聊天 chat
@ -26,7 +24,8 @@ public interface ChatService {
* chat stream
*
* @param req
* @param sseEmitter
* @return
*/
Flux<ChatResponse> chatStream(ChatReq req);
void chatStream(ChatReq req, Utf8SseEmitter sseEmitter);
}

View File

@ -8,6 +8,7 @@ import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.framework.common.exception.ServerException;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
@ -21,10 +22,14 @@ import cn.iocoder.yudao.module.ai.service.ChatService;
import cn.iocoder.yudao.module.ai.vo.ChatReq;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
import java.io.IOException;
import java.util.function.Consumer;
/**
* 聊天 service
*
@ -51,25 +56,17 @@ public class ChatServiceImpl implements ChatService {
*/
@Transactional(rollbackFor = Exception.class)
public String chat(ChatReq req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
// 获取 对话类型(新建还是继续)
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
AiChatConversationDO aiChatConversationDO;
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
// 创建一个新的对话
aiChatConversationDO = createNewChatConversation(req, loginUserId);
} else {
// 继续对话
if (req.getConversationId() == null) {
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
}
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
}
// 保存 chat message
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
String content;
String content = null;
try {
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt());
@ -81,13 +78,19 @@ public class ChatServiceImpl implements ChatService {
content = call.getResult().getOutput().getContent();
} catch (Exception e) {
content = ExceptionUtil.getMessage(e);
} finally {
// 保存 chat message
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content);
}
return content;
}
private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) {
// 增加 chat message 记录
aiChatMessageMapper.insert(
new AiChatMessageDO()
.setId(null)
.setChatConversationId(aiChatConversationDO.getId())
.setChatConversationId(chatConversationId)
.setUserId(loginUserId)
.setMessage(req.getPrompt())
.setMessageType(MessageType.USER.getValue())
@ -98,7 +101,39 @@ public class ChatServiceImpl implements ChatService {
// chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
return content;
}
public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) {
// 增加 chat message 记录
aiChatMessageMapper.insert(
new AiChatMessageDO()
.setId(null)
.setChatConversationId(chatConversationId)
.setUserId(loginUserId)
.setMessage(systemPrompts)
.setMessageType(MessageType.SYSTEM.getValue())
.setTopK(req.getTopK())
.setTopP(req.getTopP())
.setTemperature(req.getTemperature())
);
// chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
}
private AiChatConversationDO getChatConversationNoExistToCreate(ChatReq req, ChatConversationTypeEnum chatConversationTypeEnum, Long loginUserId) {
AiChatConversationDO aiChatConversationDO;
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
// 创建一个新的对话
aiChatConversationDO = createNewChatConversation(req, loginUserId);
} else {
// 继续对话
if (req.getConversationId() == null) {
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
}
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
}
return aiChatConversationDO;
}
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
@ -133,16 +168,52 @@ public class ChatServiceImpl implements ChatService {
* chat stream
*
* @param req
* @param sseEmitter
* @return
*/
@Override
public Flux<ChatResponse> chatStream(ChatReq req) {
public void chatStream(ChatReq req, Utf8SseEmitter sseEmitter) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
// 获取 对话类型(新建还是继续)
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt());
req.setTopK(req.getTopK());
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
return aiClient.stream(prompt, clientNameEnum.getName());
// 保存 chat message
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
StringBuffer contentBuffer = new StringBuffer();
streamResponse.subscribe(
new Consumer<ChatResponse>() {
@Override
public void accept(ChatResponse chatResponse) {
String content = chatResponse.getResults().get(0).getOutput().getContent();
try {
contentBuffer.append(content);
sseEmitter.send(content, MediaType.APPLICATION_JSON);
} catch (IOException e) {
log.error("发送异常{}", ExceptionUtil.getMessage(e));
// 如果不是因为关闭而抛出异常则重新连接
sseEmitter.completeWithError(e);
}
}
},
error -> {
//
log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
},
() -> {
log.info("发送完成!");
sseEmitter.complete();
// 保存 chat message
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString());
}
);
}
}