【新增】AI:聊天接入知识库

This commit is contained in:
xiaoxin 2024-09-22 18:13:21 +08:00
parent 5cd870748d
commit 0700c3f15e
6 changed files with 66 additions and 10 deletions

View File

@ -34,7 +34,12 @@ public enum AiChatRoleEnum {
### 支付宝 ### 支付宝
### 微信 ### 微信
除此之外不要任何解释性语句 除此之外不要任何解释性语句
"""); """),
AI_KNOWLEDGE_ROLE("知识库助手", """
给你提供一些数据参考{info},请回答我的问题
请你跟进数据参考与工具返回结果回复用户的请求
""");
/** /**
* 角色名 * 角色名

View File

@ -10,4 +10,7 @@ public class AiChatConversationCreateMyReqVO {
@Schema(description = "聊天角色编号", example = "666") @Schema(description = "聊天角色编号", example = "666")
private Long roleId; private Long roleId;
@Schema(description = "知识库编号", example = "1204")
private Long knowledgeId;
} }

View File

@ -21,6 +21,9 @@ public class AiChatConversationUpdateMyReqVO {
@Schema(description = "模型编号", example = "1") @Schema(description = "模型编号", example = "1")
private Long modelId; private Long modelId;
@Schema(description = "知识库编号", example = "1")
private Long knowledgeId;
@Schema(description = "角色设定", example = "一个快乐的程序员") @Schema(description = "角色设定", example = "一个快乐的程序员")
private String systemMessage; private String systemMessage;

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat; package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
@ -64,6 +65,13 @@ public class AiChatConversationDO extends BaseDO {
*/ */
private Long roleId; private Long roleId;
/**
* 知识库编号
* <p>
* 关联 {@link AiKnowledgeDO#getId()}
*/
private Long knowledgeId;
/** /**
* 模型编号 * 模型编号
* *

View File

@ -13,6 +13,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
@ -22,6 +23,7 @@ import org.springframework.validation.annotation.Validated;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.List; import java.util.List;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@ -45,6 +47,8 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
private AiChatModelService chatModalService; private AiChatModelService chatModalService;
@Resource @Resource
private AiChatRoleService chatRoleService; private AiChatRoleService chatRoleService;
@Resource
private AiKnowledgeService knowledgeService;
@Override @Override
public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) { public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) {
@ -56,9 +60,14 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
Assert.notNull(model, "必须找到默认模型"); Assert.notNull(model, "必须找到默认模型");
validateChatModel(model); validateChatModel(model);
// 1.3 校验知识库
if (Objects.nonNull(createReqVO.getKnowledgeId())) {
knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
}
// 2. 创建 AiChatConversationDO 聊天对话 // 2. 创建 AiChatConversationDO 聊天对话
AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false) AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false)
.setModelId(model.getId()).setModel(model.getModel()) .setModelId(model.getId()).setModel(model.getModel()).setKnowledgeId(createReqVO.getKnowledgeId())
.setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts()); .setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts());
if (role != null) { if (role != null) {
conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage()); conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage());
@ -82,6 +91,11 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
model = chatModalService.validateChatModel(updateReqVO.getModelId()); model = chatModalService.validateChatModel(updateReqVO.getModelId());
} }
// 1.3 校验知识库是否存在
if (updateReqVO.getKnowledgeId() != null) {
knowledgeService.validateKnowledgeExists(updateReqVO.getKnowledgeId());
}
// 2. 更新对话信息 // 2. 更新对话信息
AiChatConversationDO updateObj = BeanUtils.toBean(updateReqVO, AiChatConversationDO.class); AiChatConversationDO updateObj = BeanUtils.toBean(updateReqVO, AiChatConversationDO.class);
if (Boolean.TRUE.equals(updateReqVO.getPinned())) { if (Boolean.TRUE.equals(updateReqVO.getPinned())) {

View File

@ -12,21 +12,29 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -59,6 +67,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
private AiChatModelService chatModalService; private AiChatModelService chatModalService;
@Resource @Resource
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@ -141,14 +151,27 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
// 1. 构建 Prompt Message 列表 // 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
// 1.1 system context 角色设定
// 1.1 知识库召回
if (Objects.nonNull(conversation.getKnowledgeId())) {
List<AiKnowledgeSegmentDO> segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(conversation.getKnowledgeId()).setContent(sendReqVO.getContent()));
if (CollUtil.isNotEmpty(segmentList)) {
PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
StringBuilder infoBuilder = StrUtil.builder();
segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
chatMessages.add(message);
}
}
// 1.2 system context 角色设定
if (StrUtil.isNotBlank(conversation.getSystemMessage())) { if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
chatMessages.add(new SystemMessage(conversation.getSystemMessage())); chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
} }
// 1.2 history message 历史消息 // 1.3 history message 历史消息
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO); List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
// 1.3 user message 新发送消息 // 1.4 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象 // 2. 构建 ChatOptions 对象
@ -160,12 +183,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
/** /**
* 从历史消息中获得倒序的 n 组消息作为消息上下文 * 从历史消息中获得倒序的 n 组消息作为消息上下文
* * <p>
* n 指的是 user + assistant 形成一组 * n 指的是 user + assistant 形成一组
* *
* @param messages 消息列表 * @param messages 消息列表
* @param conversation 对话 * @param conversation 对话
* @param sendReqVO 发送请求 * @param sendReqVO 发送请求
* @return 消息上下文 * @return 消息上下文
*/ */
private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
@ -182,7 +205,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
AiChatMessageDO userMessage = CollUtil.get(messages, i - 1); AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId()) if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
|| StrUtil.isEmpty(assistantMessage.getContent())) { || StrUtil.isEmpty(assistantMessage.getContent())) {
continue; continue;
} }
// 由于后续要 reverse 反转所以先添加 assistantMessage // 由于后续要 reverse 反转所以先添加 assistantMessage