mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2025-01-19 03:30:06 +08:00
stream 保存聊天记录
This commit is contained in:
parent
905ce773e9
commit
f84d25d3b7
@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.controller;
|
|||||||
|
|
||||||
import cn.hutool.core.exceptions.ExceptionUtil;
|
import cn.hutool.core.exceptions.ExceptionUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
||||||
import cn.iocoder.yudao.framework.ai.config.AiClient;
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.module.ai.service.ChatService;
|
import cn.iocoder.yudao.module.ai.service.ChatService;
|
||||||
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
||||||
@ -38,7 +37,6 @@ import java.util.function.Consumer;
|
|||||||
public class ChatController {
|
public class ChatController {
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private AiClient aiClient;
|
|
||||||
private final ChatService chatService;
|
private final ChatService chatService;
|
||||||
|
|
||||||
@Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!")
|
@Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!")
|
||||||
@ -52,30 +50,7 @@ public class ChatController {
|
|||||||
@GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
@GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) {
|
public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) {
|
||||||
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
|
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
|
||||||
Flux<ChatResponse> streamResponse = chatService.chatStream(req);
|
chatService.chatStream(req, sseEmitter);
|
||||||
streamResponse.subscribe(
|
|
||||||
new Consumer<ChatResponse>() {
|
|
||||||
@Override
|
|
||||||
public void accept(ChatResponse chatResponse) {
|
|
||||||
String content = chatResponse.getResults().get(0).getOutput().getContent();
|
|
||||||
try {
|
|
||||||
sseEmitter.send(content, MediaType.APPLICATION_JSON);
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("发送异常{}", ExceptionUtil.getMessage(e));
|
|
||||||
// 如果不是因为关闭而抛出异常,则重新连接
|
|
||||||
sseEmitter.completeWithError(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
error -> {
|
|
||||||
//
|
|
||||||
log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
|
|
||||||
},
|
|
||||||
() -> {
|
|
||||||
log.info("发送完成!");
|
|
||||||
sseEmitter.complete();
|
|
||||||
}
|
|
||||||
);
|
|
||||||
return sseEmitter;
|
return sseEmitter;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
package cn.iocoder.yudao.module.ai.service;
|
package cn.iocoder.yudao.module.ai.service;
|
||||||
|
|
||||||
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
||||||
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
|
|
||||||
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
||||||
import reactor.core.publisher.Flux;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 聊天 chat
|
* 聊天 chat
|
||||||
@ -26,7 +24,8 @@ public interface ChatService {
|
|||||||
* chat stream
|
* chat stream
|
||||||
*
|
*
|
||||||
* @param req
|
* @param req
|
||||||
|
* @param sseEmitter
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
Flux<ChatResponse> chatStream(ChatReq req);
|
void chatStream(ChatReq req, Utf8SseEmitter sseEmitter);
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import cn.iocoder.yudao.framework.ai.config.AiClient;
|
|||||||
import cn.iocoder.yudao.framework.common.exception.ServerException;
|
import cn.iocoder.yudao.framework.common.exception.ServerException;
|
||||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||||
|
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
||||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
|
import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
|
||||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
|
import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
|
||||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
|
import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
|
||||||
@ -21,10 +22,14 @@ import cn.iocoder.yudao.module.ai.service.ChatService;
|
|||||||
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 聊天 service
|
* 聊天 service
|
||||||
*
|
*
|
||||||
@ -51,25 +56,17 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
*/
|
*/
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public String chat(ChatReq req) {
|
public String chat(ChatReq req) {
|
||||||
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||||
// 获取 client 类型
|
// 获取 client 类型
|
||||||
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
||||||
// 获取 对话类型(新建还是继续)
|
// 获取 对话类型(新建还是继续)
|
||||||
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
|
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
|
||||||
|
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
|
||||||
|
|
||||||
AiChatConversationDO aiChatConversationDO;
|
// 保存 chat message
|
||||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
|
||||||
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
|
|
||||||
// 创建一个新的对话
|
|
||||||
aiChatConversationDO = createNewChatConversation(req, loginUserId);
|
|
||||||
} else {
|
|
||||||
// 继续对话
|
|
||||||
if (req.getConversationId() == null) {
|
|
||||||
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
|
|
||||||
}
|
|
||||||
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
|
|
||||||
}
|
|
||||||
|
|
||||||
String content;
|
String content = null;
|
||||||
try {
|
try {
|
||||||
// 创建 chat 需要的 Prompt
|
// 创建 chat 需要的 Prompt
|
||||||
Prompt prompt = new Prompt(req.getPrompt());
|
Prompt prompt = new Prompt(req.getPrompt());
|
||||||
@ -81,13 +78,19 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
content = call.getResult().getOutput().getContent();
|
content = call.getResult().getOutput().getContent();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
content = ExceptionUtil.getMessage(e);
|
content = ExceptionUtil.getMessage(e);
|
||||||
|
} finally {
|
||||||
|
// 保存 chat message
|
||||||
|
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content);
|
||||||
}
|
}
|
||||||
|
return content;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) {
|
||||||
// 增加 chat message 记录
|
// 增加 chat message 记录
|
||||||
aiChatMessageMapper.insert(
|
aiChatMessageMapper.insert(
|
||||||
new AiChatMessageDO()
|
new AiChatMessageDO()
|
||||||
.setId(null)
|
.setId(null)
|
||||||
.setChatConversationId(aiChatConversationDO.getId())
|
.setChatConversationId(chatConversationId)
|
||||||
.setUserId(loginUserId)
|
.setUserId(loginUserId)
|
||||||
.setMessage(req.getPrompt())
|
.setMessage(req.getPrompt())
|
||||||
.setMessageType(MessageType.USER.getValue())
|
.setMessageType(MessageType.USER.getValue())
|
||||||
@ -98,7 +101,39 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
|
|
||||||
// chat count 先+1
|
// chat count 先+1
|
||||||
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
||||||
return content;
|
}
|
||||||
|
|
||||||
|
public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) {
|
||||||
|
// 增加 chat message 记录
|
||||||
|
aiChatMessageMapper.insert(
|
||||||
|
new AiChatMessageDO()
|
||||||
|
.setId(null)
|
||||||
|
.setChatConversationId(chatConversationId)
|
||||||
|
.setUserId(loginUserId)
|
||||||
|
.setMessage(systemPrompts)
|
||||||
|
.setMessageType(MessageType.SYSTEM.getValue())
|
||||||
|
.setTopK(req.getTopK())
|
||||||
|
.setTopP(req.getTopP())
|
||||||
|
.setTemperature(req.getTemperature())
|
||||||
|
);
|
||||||
|
|
||||||
|
// chat count 先+1
|
||||||
|
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
||||||
|
}
|
||||||
|
|
||||||
|
private AiChatConversationDO getChatConversationNoExistToCreate(ChatReq req, ChatConversationTypeEnum chatConversationTypeEnum, Long loginUserId) {
|
||||||
|
AiChatConversationDO aiChatConversationDO;
|
||||||
|
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
|
||||||
|
// 创建一个新的对话
|
||||||
|
aiChatConversationDO = createNewChatConversation(req, loginUserId);
|
||||||
|
} else {
|
||||||
|
// 继续对话
|
||||||
|
if (req.getConversationId() == null) {
|
||||||
|
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
|
||||||
|
}
|
||||||
|
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
|
||||||
|
}
|
||||||
|
return aiChatConversationDO;
|
||||||
}
|
}
|
||||||
|
|
||||||
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
|
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
|
||||||
@ -133,16 +168,52 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
* chat stream
|
* chat stream
|
||||||
*
|
*
|
||||||
* @param req
|
* @param req
|
||||||
|
* @param sseEmitter
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Flux<ChatResponse> chatStream(ChatReq req) {
|
public void chatStream(ChatReq req, Utf8SseEmitter sseEmitter) {
|
||||||
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||||
|
// 获取 client 类型
|
||||||
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
||||||
|
// 获取 对话类型(新建还是继续)
|
||||||
|
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
|
||||||
|
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
|
||||||
// 创建 chat 需要的 Prompt
|
// 创建 chat 需要的 Prompt
|
||||||
Prompt prompt = new Prompt(req.getPrompt());
|
Prompt prompt = new Prompt(req.getPrompt());
|
||||||
req.setTopK(req.getTopK());
|
req.setTopK(req.getTopK());
|
||||||
req.setTopP(req.getTopP());
|
req.setTopP(req.getTopP());
|
||||||
req.setTemperature(req.getTemperature());
|
req.setTemperature(req.getTemperature());
|
||||||
return aiClient.stream(prompt, clientNameEnum.getName());
|
// 保存 chat message
|
||||||
|
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
|
||||||
|
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
|
||||||
|
|
||||||
|
StringBuffer contentBuffer = new StringBuffer();
|
||||||
|
streamResponse.subscribe(
|
||||||
|
new Consumer<ChatResponse>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ChatResponse chatResponse) {
|
||||||
|
String content = chatResponse.getResults().get(0).getOutput().getContent();
|
||||||
|
try {
|
||||||
|
contentBuffer.append(content);
|
||||||
|
sseEmitter.send(content, MediaType.APPLICATION_JSON);
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("发送异常{}", ExceptionUtil.getMessage(e));
|
||||||
|
// 如果不是因为关闭而抛出异常,则重新连接
|
||||||
|
sseEmitter.completeWithError(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
error -> {
|
||||||
|
//
|
||||||
|
log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
|
||||||
|
},
|
||||||
|
() -> {
|
||||||
|
log.info("发送完成!");
|
||||||
|
sseEmitter.complete();
|
||||||
|
// 保存 chat message
|
||||||
|
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString());
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user