调整 chat 和 steamChat 获取对话逻辑

This commit is contained in:
cherishsince 2024-04-18 17:10:08 +08:00
parent 97d4e56c81
commit 25523fd53e

View File

@ -18,7 +18,9 @@ import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.service.ChatConversationService;
import cn.iocoder.yudao.module.ai.service.ChatService; import cn.iocoder.yudao.module.ai.service.ChatService;
import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
import cn.iocoder.yudao.module.ai.vo.ChatReq; import cn.iocoder.yudao.module.ai.vo.ChatReq;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -46,6 +48,7 @@ public class ChatServiceImpl implements ChatService {
private final AiChatRoleMapper aiChatRoleMapper; private final AiChatRoleMapper aiChatRoleMapper;
private final AiChatMessageMapper aiChatMessageMapper; private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatConversationMapper aiChatConversationMapper; private final AiChatConversationMapper aiChatConversationMapper;
private final ChatConversationService chatConversationService;
/** /**
@ -59,13 +62,10 @@ public class ChatServiceImpl implements ChatService {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型 // 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
// 获取 对话类型(新建还是继续) // 获取对话信息
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
// 保存 chat message // 保存 chat message
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); saveChatMessage(req, conversationRes.getId(), loginUserId);
String content = null; String content = null;
try { try {
// 创建 chat 需要的 Prompt // 创建 chat 需要的 Prompt
@ -80,7 +80,7 @@ public class ChatServiceImpl implements ChatService {
content = ExceptionUtil.getMessage(e); content = ExceptionUtil.getMessage(e);
} finally { } finally {
// 保存 chat message // 保存 chat message
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content); saveSystemChatMessage(req, conversationRes.getId(), loginUserId, content);
} }
return content; return content;
} }
@ -176,16 +176,15 @@ public class ChatServiceImpl implements ChatService {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型 // 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
// 获取 对话类型(新建还是继续) // 获取对话信息
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
// 创建 chat 需要的 Prompt // 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt()); Prompt prompt = new Prompt(req.getPrompt());
req.setTopK(req.getTopK()); req.setTopK(req.getTopK());
req.setTopP(req.getTopP()); req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature()); req.setTemperature(req.getTemperature());
// 保存 chat message // 保存 chat message
saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); saveChatMessage(req, conversationRes.getId(), loginUserId);
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName()); Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
@ -212,7 +211,7 @@ public class ChatServiceImpl implements ChatService {
log.info("发送完成!"); log.info("发送完成!");
sseEmitter.complete(); sseEmitter.complete();
// 保存 chat message // 保存 chat message
saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString()); saveSystemChatMessage(req, conversationRes.getId(), loginUserId, contentBuffer.toString());
} }
); );
} }