【调整】调整AI聊天模块

This commit is contained in:
cherishsince 2024-05-07 10:45:37 +08:00
parent f86e24bb86
commit 424210066f
10 changed files with 139 additions and 81 deletions

View File

@ -30,6 +30,7 @@ public interface ErrorCodeConstants {
// role // role
ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "chatRole 不存在!"); ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "chatRole 不存在!");
ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
// modal // modal

View File

@ -32,7 +32,7 @@ public class AiChatMessageController {
@PostMapping("/send") @PostMapping("/send")
public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @ModelAttribute AiChatMessageSendReqVO sendReqVO) { public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @ModelAttribute AiChatMessageSendReqVO sendReqVO) {
// TODO @fan使用 static import这样就 success 就行了 // TODO @fan使用 static import这样就 success 就行了
return success(null); return success(chatService.chat(sendReqVO));
} }
// TODO @芋艿调用这个方法异常Unable to handle the Spring Security Exception because the response is already committed.可以再试试 // TODO @芋艿调用这个方法异常Unable to handle the Spring Security Exception because the response is already committed.可以再试试

View File

@ -21,6 +21,9 @@ public class AiChatConversationRespVO {
@Schema(description = "是否置顶", requiredMode = Schema.RequiredMode.REQUIRED, example = "true") @Schema(description = "是否置顶", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
private Boolean pinned; private Boolean pinned;
@Schema(description = "角色编号", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1")
private Long roleId;
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Long modelId; private Long modelId;

View File

@ -52,4 +52,11 @@ public interface AiChatModalService {
* @return * @return
*/ */
AiChatModalRes getChatModalOfValidate(Long modalId); AiChatModalRes getChatModalOfValidate(Long modalId);
/**
* 校验 - 校验是否可用
*
* @param chatModal
*/
void validateAvailable(AiChatModalRes chatModal);
} }

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.role.*; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.role.*;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
/** /**
* chat 角色 * chat 角色
@ -58,4 +59,19 @@ public interface AiChatRoleService {
* @return * @return
*/ */
AiChatRoleRes getChatRole(Long roleId); AiChatRoleRes getChatRole(Long roleId);
/**
* 校验 - 角色是否存在
*
* @param id
* @return
*/
AiChatRoleDO validateExists(Long id);
/**
* 校验 - 角色是否公开
*
* @param aiChatRoleDO
*/
void validateIsPublic(AiChatRoleDO aiChatRoleDO);
} }

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service; package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
/** /**
@ -15,10 +16,10 @@ public interface AiChatService {
/** /**
* chat * chat
* *
* @param req * @param sendReqVO
* @return * @return
*/ */
String chat(AiChatMessageSendReqVO req); AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
/** /**
* chat stream * chat stream

View File

@ -15,7 +15,6 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
import cn.iocoder.yudao.module.ai.service.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.AiChatModalService; import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import cn.iocoder.yudao.module.ai.service.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
@ -91,9 +90,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
// 获取模型信息并验证 // 获取模型信息并验证
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(updateReqVO.getModelId()); AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(updateReqVO.getModelId());
// 校验modal是否可用 // 校验modal是否可用
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) { aiChatModalService.validateAvailable(chatModal);
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
}
// 更新对话信息 // 更新对话信息
AiChatConversationDO updateAiChatConversationDO AiChatConversationDO updateAiChatConversationDO
= AiChatConversationConvert.INSTANCE.convertAiChatConversationDO(updateReqVO); = AiChatConversationConvert.INSTANCE.convertAiChatConversationDO(updateReqVO);

View File

@ -116,6 +116,14 @@ public class AiChatModalServiceImpl implements AiChatModalService {
return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO); return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
} }
@Override
public void validateAvailable(AiChatModalRes chatModal) {
// 对话模型是否可用
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
}
}
private AiChatModalDO validateChatModalExists(Long id) { private AiChatModalDO validateChatModalExists(Long id) {
AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id); AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id);
if (aiChatModalDO == null) { if (aiChatModalDO == null) {

View File

@ -70,7 +70,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
AiChatRoleClassifyEnum.valueOfClassify(req.getClassify()); AiChatRoleClassifyEnum.valueOfClassify(req.getClassify());
AiChatRoleEnableEnum.valueOfType(req.getEnable()); AiChatRoleEnableEnum.valueOfType(req.getEnable());
// 检查角色是否存在 // 检查角色是否存在
validateChatRoleExists(id); validateExists(id);
// 转换do // 转换do
AiChatRoleDO updateChatRole = AiChatRoleConvert.INSTANCE.convertAiChatRoleDO(req); AiChatRoleDO updateChatRole = AiChatRoleConvert.INSTANCE.convertAiChatRoleDO(req);
updateChatRole.setId(id); updateChatRole.setId(id);
@ -83,7 +83,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
// 转换enum并校验enum // 转换enum并校验enum
AiChatRoleEnableEnum.valueOfType(req.getEnable()); AiChatRoleEnableEnum.valueOfType(req.getEnable());
// 检查角色是否存在 // 检查角色是否存在
validateChatRoleExists(id); validateExists(id);
// 更新 // 更新
aiChatRoleMapper.updateById(new AiChatRoleDO() aiChatRoleMapper.updateById(new AiChatRoleDO()
.setId(id) .setId(id)
@ -94,7 +94,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Override @Override
public void delete(Long chatRoleId) { public void delete(Long chatRoleId) {
// 检查角色是否存在 // 检查角色是否存在
validateChatRoleExists(chatRoleId); validateExists(chatRoleId);
// 删除 // 删除
aiChatRoleMapper.deleteById(chatRoleId); aiChatRoleMapper.deleteById(chatRoleId);
} }
@ -102,15 +102,25 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Override @Override
public AiChatRoleRes getChatRole(Long roleId) { public AiChatRoleRes getChatRole(Long roleId) {
// 检查角色是否存在 // 检查角色是否存在
AiChatRoleDO aiChatRoleDO = validateChatRoleExists(roleId); AiChatRoleDO aiChatRoleDO = validateExists(roleId);
return AiChatRoleConvert.INSTANCE.convertAiChatRoleRes(aiChatRoleDO); return AiChatRoleConvert.INSTANCE.convertAiChatRoleRes(aiChatRoleDO);
} }
private AiChatRoleDO validateChatRoleExists(Long id) { public AiChatRoleDO validateExists(Long id) {
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(id); AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(id);
if (aiChatRoleDO == null) { if (aiChatRoleDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXIST); throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXIST);
} }
return aiChatRoleDO; return aiChatRoleDO;
} }
public void validateIsPublic(AiChatRoleDO aiChatRoleDO) {
if (aiChatRoleDO == null) {
return;
}
if (!aiChatRoleDO.getPublicStatus()) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_PUBLIC);
}
}
} }

View File

@ -10,14 +10,19 @@ import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper; import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.service.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.AiChatService; import cn.iocoder.yudao.module.ai.service.AiChatService;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@ -45,29 +50,39 @@ public class AiChatServiceImpl implements AiChatService {
private final AiChatMessageMapper aiChatMessageMapper; private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatConversationMapper aiChatConversationMapper; private final AiChatConversationMapper aiChatConversationMapper;
private final AiChatConversationService chatConversationService; private final AiChatConversationService chatConversationService;
private final AiChatModalService aiChatModalService;
private final AiChatRoleService aiChatRoleService;
/**
* chat
*
* @param req
* @return
*/
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public String chat(AiChatMessageSendReqVO req) { public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询对话
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
// 获取对话模型
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
// 对话模型是否可用
aiChatModalService.validateAvailable(chatModal);
// 获取角色信息
AiChatRoleDO aiChatRoleDO = null;
if (conversation.getRoleId() != null) {
aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
}
// 校验角色是否公开
aiChatRoleService.validateIsPublic(aiChatRoleDO);
// 获取 client 类型 // 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal()); AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
// 获取对话信息
AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
// 保存 chat message // 保存 chat message
saveChatMessage(req, conversationRes, loginUserId); insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
String content = null; String content = null;
try { try {
// 创建 chat 需要的 Prompt // 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt()); Prompt prompt = new Prompt(req.getContent());
req.setTopK(req.getTopK()); // TODO @芋艿 @范 看要不要支持这些
req.setTopP(req.getTopP()); // req.setTopK(req.getTopK());
req.setTemperature(req.getTemperature()); // req.setTopP(req.getTopP());
// req.setTemperature(req.getTemperature());
// 发送 call 调用 // 发送 call 调用
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum); ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
ChatResponse call = chatClient.call(prompt); ChatResponse call = chatClient.call(prompt);
@ -78,69 +93,66 @@ public class AiChatServiceImpl implements AiChatService {
content = ExceptionUtil.getMessage(e); content = ExceptionUtil.getMessage(e);
} finally { } finally {
// 保存 chat message // 保存 chat message
saveSystemChatMessage(req, conversationRes, loginUserId, content); insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
} }
return content; return new AiChatMessageRespVO().setContent(content);
} }
private void saveChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId) { private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
Long chatConversationId = conversationRes.getId(); String model, Long modelId, String content, Integer tokens, Double temperature,
Integer maxTokens, Integer maxContexts) {
AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
.setId(null)
.setConversationId(conversationId)
.setType(messageType.getValue())
.setUserId(loginUserId)
.setRoleId(roleId)
.setModel(model)
.setModelId(modelId)
.setContent(content)
.setTokens(tokens)
.setTemperature(temperature)
.setMaxTokens(maxTokens)
.setMaxContexts(maxContexts);
// 增加 chat message 记录 // 增加 chat message 记录
aiChatMessageMapper.insert( aiChatMessageMapper.insert(insertChatMessageDO);
new AiChatMessageDO()
.setId(null)
.setConversationId(chatConversationId)
.setUserId(loginUserId)
.setMessage(req.getPrompt())
.setMessageType(MessageType.USER.getValue())
.setTopK(req.getTopK())
.setTopP(req.getTopP())
.setTemperature(req.getTemperature())
);
// chat count +1 // chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); aiChatConversationMapper.updateIncrChatCount(conversationId);
return insertChatMessageDO;
} }
public void saveSystemChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId, String systemPrompts) {
Long chatConversationId = conversationRes.getId();
// 增加 chat message 记录
aiChatMessageMapper.insert(
new AiChatMessageDO()
.setId(null)
.setConversationId(chatConversationId)
.setUserId(loginUserId)
.setMessage(systemPrompts)
.setMessageType(MessageType.SYSTEM.getValue())
.setTopK(req.getTopK())
.setTopP(req.getTopP())
.setTemperature(req.getTemperature())
);
// chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
}
/**
* chat stream
*
* @param req
* @param sseEmitter
* @return
*/
@Override @Override
public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) { public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型 // 查询对话
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal()); AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
// 获取对话信息 // 获取对话模型
AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId()); AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
// 对话模型是否可用
aiChatModalService.validateAvailable(chatModal);
// 获取角色信息
AiChatRoleDO aiChatRoleDO = null;
if (conversation.getRoleId() != null) {
aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
}
// 校验角色是否公开
aiChatRoleService.validateIsPublic(aiChatRoleDO);
// 创建 chat 需要的 Prompt // 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt()); Prompt prompt = new Prompt(req.getContent());
req.setTopK(req.getTopK()); // req.setTopK(req.getTopK());
req.setTopP(req.getTopP()); // req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature()); // req.setTemperature(req.getTemperature());
// 保存 chat message // 保存 chat message
saveChatMessage(req, conversationRes, loginUserId); // 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt); Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
@ -168,7 +180,10 @@ public class AiChatServiceImpl implements AiChatService {
log.info("发送完成!"); log.info("发送完成!");
sseEmitter.complete(); sseEmitter.complete();
// 保存 chat message // 保存 chat message
saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString()); insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
} }
); );
} }