mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2025-01-18 19:20:05 +08:00
stream 保存聊天记录
This commit is contained in:
parent
905ce773e9
commit
f84d25d3b7
@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.controller;
|
||||
|
||||
import cn.hutool.core.exceptions.ExceptionUtil;
|
||||
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
||||
import cn.iocoder.yudao.framework.ai.config.AiClient;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.module.ai.service.ChatService;
|
||||
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
||||
@ -38,7 +37,6 @@ import java.util.function.Consumer;
|
||||
public class ChatController {
|
||||
|
||||
@Autowired
|
||||
private AiClient aiClient;
|
||||
private final ChatService chatService;
|
||||
|
||||
@Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!")
|
||||
@ -52,30 +50,7 @@ public class ChatController {
|
||||
@GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) {
|
||||
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
|
||||
Flux<ChatResponse> streamResponse = chatService.chatStream(req);
|
||||
streamResponse.subscribe(
|
||||
new Consumer<ChatResponse>() {
|
||||
@Override
|
||||
public void accept(ChatResponse chatResponse) {
|
||||
String content = chatResponse.getResults().get(0).getOutput().getContent();
|
||||
try {
|
||||
sseEmitter.send(content, MediaType.APPLICATION_JSON);
|
||||
} catch (IOException e) {
|
||||
log.error("发送异常{}", ExceptionUtil.getMessage(e));
|
||||
// 如果不是因为关闭而抛出异常,则重新连接
|
||||
sseEmitter.completeWithError(e);
|
||||
}
|
||||
}
|
||||
},
|
||||
error -> {
|
||||
//
|
||||
log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
|
||||
},
|
||||
() -> {
|
||||
log.info("发送完成!");
|
||||
sseEmitter.complete();
|
||||
}
|
||||
);
|
||||
chatService.chatStream(req, sseEmitter);
|
||||
return sseEmitter;
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.service;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
||||
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
|
||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
||||
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* 聊天 chat
|
||||
@ -26,7 +24,8 @@ public interface ChatService {
|
||||
* chat stream
|
||||
*
|
||||
* @param req
|
||||
* @param sseEmitter
|
||||
* @return
|
||||
*/
|
||||
Flux<ChatResponse> chatStream(ChatReq req);
|
||||
void chatStream(ChatReq req, Utf8SseEmitter sseEmitter);
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import cn.iocoder.yudao.framework.ai.config.AiClient;
|
||||
import cn.iocoder.yudao.framework.common.exception.ServerException;
|
||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
|
||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
|
||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
|
||||
@ -21,10 +22,14 @@ import cn.iocoder.yudao.module.ai.service.ChatService;
|
||||
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* 聊天 service
|
||||
*
|
||||
@ -51,25 +56,17 @@ public class ChatServiceImpl implements ChatService {
|
||||
*/
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public String chat(ChatReq req) {
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
// 获取 client 类型
|
||||
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
||||
// 获取 对话类型(新建还是继续)
|
||||
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
|
||||
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
|
||||
|
||||
AiChatConversationDO aiChatConversationDO;
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
|
||||
// 创建一个新的对话
|
||||
aiChatConversationDO = createNewChatConversation(req, loginUserId);
|
||||
} else {
|
||||
// 继续对话
|
||||
if (req.getConversationId() == null) {
|
||||
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
|
||||
}
|
||||
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
|
||||
}
|
||||
// 保存 chat message
|
||||
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
|
||||
|
||||
String content;
|
||||
String content = null;
|
||||
try {
|
||||
// 创建 chat 需要的 Prompt
|
||||
Prompt prompt = new Prompt(req.getPrompt());
|
||||
@ -81,13 +78,19 @@ public class ChatServiceImpl implements ChatService {
|
||||
content = call.getResult().getOutput().getContent();
|
||||
} catch (Exception e) {
|
||||
content = ExceptionUtil.getMessage(e);
|
||||
} finally {
|
||||
// 保存 chat message
|
||||
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content);
|
||||
}
|
||||
return content;
|
||||
}
|
||||
|
||||
private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) {
|
||||
// 增加 chat message 记录
|
||||
aiChatMessageMapper.insert(
|
||||
new AiChatMessageDO()
|
||||
.setId(null)
|
||||
.setChatConversationId(aiChatConversationDO.getId())
|
||||
.setChatConversationId(chatConversationId)
|
||||
.setUserId(loginUserId)
|
||||
.setMessage(req.getPrompt())
|
||||
.setMessageType(MessageType.USER.getValue())
|
||||
@ -98,7 +101,39 @@ public class ChatServiceImpl implements ChatService {
|
||||
|
||||
// chat count 先+1
|
||||
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
||||
return content;
|
||||
}
|
||||
|
||||
public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) {
|
||||
// 增加 chat message 记录
|
||||
aiChatMessageMapper.insert(
|
||||
new AiChatMessageDO()
|
||||
.setId(null)
|
||||
.setChatConversationId(chatConversationId)
|
||||
.setUserId(loginUserId)
|
||||
.setMessage(systemPrompts)
|
||||
.setMessageType(MessageType.SYSTEM.getValue())
|
||||
.setTopK(req.getTopK())
|
||||
.setTopP(req.getTopP())
|
||||
.setTemperature(req.getTemperature())
|
||||
);
|
||||
|
||||
// chat count 先+1
|
||||
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
||||
}
|
||||
|
||||
private AiChatConversationDO getChatConversationNoExistToCreate(ChatReq req, ChatConversationTypeEnum chatConversationTypeEnum, Long loginUserId) {
|
||||
AiChatConversationDO aiChatConversationDO;
|
||||
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
|
||||
// 创建一个新的对话
|
||||
aiChatConversationDO = createNewChatConversation(req, loginUserId);
|
||||
} else {
|
||||
// 继续对话
|
||||
if (req.getConversationId() == null) {
|
||||
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
|
||||
}
|
||||
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
|
||||
}
|
||||
return aiChatConversationDO;
|
||||
}
|
||||
|
||||
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
|
||||
@ -133,16 +168,52 @@ public class ChatServiceImpl implements ChatService {
|
||||
* chat stream
|
||||
*
|
||||
* @param req
|
||||
* @param sseEmitter
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Flux<ChatResponse> chatStream(ChatReq req) {
|
||||
public void chatStream(ChatReq req, Utf8SseEmitter sseEmitter) {
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
// 获取 client 类型
|
||||
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
||||
// 获取 对话类型(新建还是继续)
|
||||
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
|
||||
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
|
||||
// 创建 chat 需要的 Prompt
|
||||
Prompt prompt = new Prompt(req.getPrompt());
|
||||
req.setTopK(req.getTopK());
|
||||
req.setTopP(req.getTopP());
|
||||
req.setTemperature(req.getTemperature());
|
||||
return aiClient.stream(prompt, clientNameEnum.getName());
|
||||
// 保存 chat message
|
||||
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
|
||||
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
|
||||
|
||||
StringBuffer contentBuffer = new StringBuffer();
|
||||
streamResponse.subscribe(
|
||||
new Consumer<ChatResponse>() {
|
||||
@Override
|
||||
public void accept(ChatResponse chatResponse) {
|
||||
String content = chatResponse.getResults().get(0).getOutput().getContent();
|
||||
try {
|
||||
contentBuffer.append(content);
|
||||
sseEmitter.send(content, MediaType.APPLICATION_JSON);
|
||||
} catch (IOException e) {
|
||||
log.error("发送异常{}", ExceptionUtil.getMessage(e));
|
||||
// 如果不是因为关闭而抛出异常,则重新连接
|
||||
sseEmitter.completeWithError(e);
|
||||
}
|
||||
}
|
||||
},
|
||||
error -> {
|
||||
//
|
||||
log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
|
||||
},
|
||||
() -> {
|
||||
log.info("发送完成!");
|
||||
sseEmitter.complete();
|
||||
// 保存 chat message
|
||||
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString());
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user