调整 chat 和 steamChat 获取对话逻辑

This commit is contained in:
cherishsince 2024-04-18 17:10:08 +08:00
parent 97d4e56c81
commit 25523fd53e

View File

@ -18,7 +18,9 @@ import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.service.ChatConversationService;
import cn.iocoder.yudao.module.ai.service.ChatService;
import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
import cn.iocoder.yudao.module.ai.vo.ChatReq;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@ -46,6 +48,7 @@ public class ChatServiceImpl implements ChatService {
private final AiChatRoleMapper aiChatRoleMapper;
private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatConversationMapper aiChatConversationMapper;
private final ChatConversationService chatConversationService;
/**
@ -59,13 +62,10 @@ public class ChatServiceImpl implements ChatService {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
// 获取 对话类型(新建还是继续)
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
// 获取对话信息
ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
// 保存 chat message
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
saveChatMessage(req, conversationRes.getId(), loginUserId);
String content = null;
try {
// 创建 chat 需要的 Prompt
@ -80,7 +80,7 @@ public class ChatServiceImpl implements ChatService {
content = ExceptionUtil.getMessage(e);
} finally {
// 保存 chat message
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content);
saveSystemChatMessage(req, conversationRes.getId(), loginUserId, content);
}
return content;
}
@ -176,16 +176,15 @@ public class ChatServiceImpl implements ChatService {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
// 获取 对话类型(新建还是继续)
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
// 获取对话信息
ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt());
req.setTopK(req.getTopK());
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
// 保存 chat message
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
saveChatMessage(req, conversationRes.getId(), loginUserId);
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
StringBuffer contentBuffer = new StringBuffer();
@ -212,7 +211,7 @@ public class ChatServiceImpl implements ChatService {
log.info("发送完成!");
sseEmitter.complete();
// 保存 chat message
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString());
saveSystemChatMessage(req, conversationRes.getId(), loginUserId, contentBuffer.toString());
}
);
}