【新增】AI:conversation 发送消息时,增加上下文

This commit is contained in:
YunaiV 2024-05-18 00:16:56 +08:00
parent 275d1fb627
commit 276ef98ff1
4 changed files with 98 additions and 98 deletions

View File

@ -19,10 +19,10 @@ public class AiChatMessageRespVO {
private String type; // 参见 MessageType 枚举类
@Schema(description = "用户编号", example = "4096")
private Long userId; // 仅当 user 发送时非空
private Long userId;
@Schema(description = "角色编号", example = "888")
private Long roleId; // 仅当 assistant 回复时非空
private Long roleId;
@Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo")
private String model; // 参见 AiOpenAiModelEnum 枚举类

View File

@ -47,16 +47,12 @@ public class AiChatMessageDO extends BaseDO {
/**
* 用户编号
*
* 仅当 user 发送时非空
*
* 关联 AdminUserDO userId 字段
*/
private Long userId;
/**
* 角色编号
*
* 仅当 assistant 回复时非空
*
* 关联 {@link AiChatRoleDO#getId()} 字段
*/
private Long roleId;

View File

@ -1,23 +1,22 @@
package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
@ -30,10 +29,7 @@ import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.*;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -53,64 +49,49 @@ public class AiChatServiceImpl implements AiChatService {
private final AiChatClientFactory chatClientFactory;
private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatMessageMapper chatMessageMapper;
private final AiChatConversationService chatConversationService;
private final AiChatModelService chatModalService;
private final AiChatRoleService chatRoleService;
@Transactional(rollbackFor = Exception.class)
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询对话
AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
// 获取对话模型
AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
// 获取角色信息
AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), req.getContent());
String content = null;
int tokens = 0;
try {
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getContent());
// TODO @芋艿 @范 看要不要支持这些
// req.setTopK(req.getTopK());
// req.setTopP(req.getTopP());
// req.setTemperature(req.getTemperature());
// 发送 call 调用
ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
ChatResponse call = chatClient.call(prompt);
content = call.getResult().getOutput().getContent();
tokens = call.getResults().size();
// 更新 conversation
} catch (Exception e) {
content = ExceptionUtil.getMessage(e);
} finally {
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), content);
}
return new AiChatMessageRespVO().setContent(content);
}
private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
String model, Long modelId, String content) {
AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
.setConversationId(conversationId)
.setType(messageType.getValue())
.setUserId(loginUserId)
.setRoleId(roleId)
.setModel(model)
.setModelId(modelId)
.setContent(content);
insertChatMessageDO.setCreateTime(LocalDateTime.now());
// 增加 chat message 记录
aiChatMessageMapper.insert(insertChatMessageDO);
return insertChatMessageDO;
return null; // TODO 芋艿一起改
// Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// // 查询对话
// AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
// // 获取对话模型
// AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
// // 获取角色信息
// AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
// // 获取 client 类型
// AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
// // 保存 chat message
// createChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
// chatModel.getModel(), chatModel.getId(), req.getContent());
// String content = null;
// int tokens = 0;
// try {
// // 创建 chat 需要的 Prompt
// Prompt prompt = new Prompt(req.getContent());
// // TODO @芋艿 @范 看要不要支持这些
//// req.setTopK(req.getTopK());
//// req.setTopP(req.getTopP());
//// req.setTemperature(req.getTemperature());
// // 发送 call 调用
// ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
// ChatResponse call = chatClient.call(prompt);
// content = call.getResult().getOutput().getContent();
// // 更新 conversation
// } catch (Exception e) {
// content = ExceptionUtil.getMessage(e);
// } finally {
// // 保存 chat message
// createChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
// chatModel.getModel(), chatModel.getId(), content);
// }
// return new AiChatMessageRespVO().setContent(content);
}
@Override
@ -120,55 +101,78 @@ public class AiChatServiceImpl implements AiChatService {
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿异常情况的对接
}
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectByConversationId(conversation.getId());
// 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
// 2. 插入 user 发送消息
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(),
conversation.getModel(), conversation.getId(), sendReqVO.getContent());
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), model,
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent());
// 3.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "");
// 3.1 插入 system 接收消息
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(),
conversation.getModel(), conversation.getId(), conversation.getSystemMessage());
// 3.2 创建 chat 需要的 Prompt
// TODO 消息上下文
Prompt prompt = new Prompt(sendReqVO.getContent());
// ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build()
Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 3.3 转换 flex AiChatMessageRespVO
// 3.3 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(res -> {
contentBuffer.append(res.getResult().getOutput().getContent());
AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId())
.setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime())
.setContent(sendReqVO.getContent());
AiChatMessageSendRespVO.Message receive = new AiChatMessageSendRespVO.Message().setId(systemMessage.getId())
.setType(MessageType.SYSTEM.getValue()).setCreateTime(systemMessage.getCreateTime())
.setContent(res.getResult().getOutput().getContent());
return new AiChatMessageSendRespVO().setSend(send).setReceive(receive);
return streamResponse.map(response -> {
String newContent = response.getResult().getOutput().getContent();
contentBuffer.append(newContent);
// 响应结果
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
.setReceive(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
}).doOnComplete(() -> {
log.info("发送完成!");
// 保存 chat message
aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId())
.setContent(contentBuffer.toString()));
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString()));
}).doOnError(throwable -> {
log.error("发送错误 {}!", throwable.getMessage());
// 更新错误信息 TODO 貌似不应该更新异常
aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId())
.setContent(throwable.getMessage()));
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()));
});
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, AiChatMessageSendReqVO sendReqVO) {
// TODO 芋艿1保留 n 个上下文2每一轮 token 数量
// if (conversation.getMaxContexts() != null && messages.size() > conversation.getMaxContexts()) {
//
// }
// 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>();
// 1.1 system context 角色设定
chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
// 1.2 history message 历史消息
messages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
// 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象
ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build();
return new Prompt(chatMessages, chatOptions);
}
private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model,
Long userId, Long roleId,
MessageType messageType, String content) {
AiChatMessageDO message = new AiChatMessageDO()
.setConversationId(conversationId).setModel(model.getModel()).setModelId(model.getId())
.setUserId(userId).setRoleId(roleId)
.setType(messageType.getValue()).setContent(content);
message.setCreateTime(LocalDateTime.now());
chatMessageMapper.insert(message);
return message;
}
@Override
public List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId) {
// 校验对话是否存在
chatConversationService.validateExists(conversationId);
// 获取对话所有 message
List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId);
List<AiChatMessageDO> aiChatMessageDOList = chatMessageMapper.selectByConversationId(conversationId);
// 获取模型信息
Set<Long> modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet());
List<AiChatModelDO> modalList = chatModalService.getModalByIds(modalIds);
@ -187,7 +191,7 @@ public class AiChatServiceImpl implements AiChatService {
@Override
public Boolean deleteMessage(Long id) {
return aiChatMessageMapper.deleteById(id) > 0;
return chatMessageMapper.deleteById(id) > 0;
}
}

View File

@ -15,13 +15,13 @@ import lombok.Getter;
public enum AiPlatformEnum {
OPENAI("OpenAI", "OpenAI"),
OLLAMA("dall", "dall"),
OLLAMA("Ollama", "Ollama"),
YI_YAN("yiyan", "一言"),
QIAN_WEN("qianwen", "千问"),
XING_HUO("xinghuo", "星火"),
OPEN_AI_DALL("dall", "dall"),
MIDJOURNEY("Ollama", "Ollama"),
MIDJOURNEY("midjourney", "midjourney"),
;