【调整】调整AI聊天模块

This commit is contained in:
cherishsince 2024-05-07 10:45:37 +08:00
parent f86e24bb86
commit 424210066f
10 changed files with 139 additions and 81 deletions

View File

@ -30,6 +30,7 @@ public interface ErrorCodeConstants {
// role
ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "chatRole 不存在!");
ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
// modal

View File

@ -32,7 +32,7 @@ public class AiChatMessageController {
@PostMapping("/send")
public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @ModelAttribute AiChatMessageSendReqVO sendReqVO) {
// TODO @fan使用 static import这样就 success 就行了
return success(null);
return success(chatService.chat(sendReqVO));
}
// TODO @芋艿调用这个方法异常Unable to handle the Spring Security Exception because the response is already committed.可以再试试

View File

@ -21,6 +21,9 @@ public class AiChatConversationRespVO {
@Schema(description = "是否置顶", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
private Boolean pinned;
@Schema(description = "角色编号", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1")
private Long roleId;
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Long modelId;

View File

@ -52,4 +52,11 @@ public interface AiChatModalService {
* @return
*/
AiChatModalRes getChatModalOfValidate(Long modalId);
/**
* 校验 - 校验是否可用
*
* @param chatModal
*/
void validateAvailable(AiChatModalRes chatModal);
}

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.role.*;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
/**
* chat 角色
@ -58,4 +59,19 @@ public interface AiChatRoleService {
* @return
*/
AiChatRoleRes getChatRole(Long roleId);
/**
* 校验 - 角色是否存在
*
* @param id
* @return
*/
AiChatRoleDO validateExists(Long id);
/**
* 校验 - 角色是否公开
*
* @param aiChatRoleDO
*/
void validateIsPublic(AiChatRoleDO aiChatRoleDO);
}

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
/**
@ -15,10 +16,10 @@ public interface AiChatService {
/**
* chat
*
* @param req
* @param sendReqVO
* @return
*/
String chat(AiChatMessageSendReqVO req);
AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
/**
* chat stream

View File

@ -15,7 +15,6 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
@ -91,9 +90,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
// 获取模型信息并验证
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(updateReqVO.getModelId());
// 校验modal是否可用
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
}
aiChatModalService.validateAvailable(chatModal);
// 更新对话信息
AiChatConversationDO updateAiChatConversationDO
= AiChatConversationConvert.INSTANCE.convertAiChatConversationDO(updateReqVO);

View File

@ -116,6 +116,14 @@ public class AiChatModalServiceImpl implements AiChatModalService {
return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
}
@Override
public void validateAvailable(AiChatModalRes chatModal) {
// 对话模型是否可用
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
}
}
private AiChatModalDO validateChatModalExists(Long id) {
AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id);
if (aiChatModalDO == null) {

View File

@ -70,7 +70,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
AiChatRoleClassifyEnum.valueOfClassify(req.getClassify());
AiChatRoleEnableEnum.valueOfType(req.getEnable());
// 检查角色是否存在
validateChatRoleExists(id);
validateExists(id);
// 转换do
AiChatRoleDO updateChatRole = AiChatRoleConvert.INSTANCE.convertAiChatRoleDO(req);
updateChatRole.setId(id);
@ -83,7 +83,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
// 转换enum并校验enum
AiChatRoleEnableEnum.valueOfType(req.getEnable());
// 检查角色是否存在
validateChatRoleExists(id);
validateExists(id);
// 更新
aiChatRoleMapper.updateById(new AiChatRoleDO()
.setId(id)
@ -94,7 +94,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Override
public void delete(Long chatRoleId) {
// 检查角色是否存在
validateChatRoleExists(chatRoleId);
validateExists(chatRoleId);
// 删除
aiChatRoleMapper.deleteById(chatRoleId);
}
@ -102,15 +102,25 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Override
public AiChatRoleRes getChatRole(Long roleId) {
// 检查角色是否存在
AiChatRoleDO aiChatRoleDO = validateChatRoleExists(roleId);
AiChatRoleDO aiChatRoleDO = validateExists(roleId);
return AiChatRoleConvert.INSTANCE.convertAiChatRoleRes(aiChatRoleDO);
}
private AiChatRoleDO validateChatRoleExists(Long id) {
public AiChatRoleDO validateExists(Long id) {
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(id);
if (aiChatRoleDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXIST);
}
return aiChatRoleDO;
}
public void validateIsPublic(AiChatRoleDO aiChatRoleDO) {
if (aiChatRoleDO == null) {
return;
}
if (!aiChatRoleDO.getPublicStatus()) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_PUBLIC);
}
}
}

View File

@ -10,14 +10,19 @@ import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.AiChatService;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
@ -45,29 +50,39 @@ public class AiChatServiceImpl implements AiChatService {
private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatConversationMapper aiChatConversationMapper;
private final AiChatConversationService chatConversationService;
private final AiChatModalService aiChatModalService;
private final AiChatRoleService aiChatRoleService;
/**
* chat
*
* @param req
* @return
*/
@Transactional(rollbackFor = Exception.class)
public String chat(AiChatMessageSendReqVO req) {
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询对话
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
// 获取对话模型
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
// 对话模型是否可用
aiChatModalService.validateAvailable(chatModal);
// 获取角色信息
AiChatRoleDO aiChatRoleDO = null;
if (conversation.getRoleId() != null) {
aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
}
// 校验角色是否公开
aiChatRoleService.validateIsPublic(aiChatRoleDO);
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
// 获取对话信息
AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
// 保存 chat message
saveChatMessage(req, conversationRes, loginUserId);
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
String content = null;
try {
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt());
req.setTopK(req.getTopK());
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
Prompt prompt = new Prompt(req.getContent());
// TODO @芋艿 @范 看要不要支持这些
// req.setTopK(req.getTopK());
// req.setTopP(req.getTopP());
// req.setTemperature(req.getTemperature());
// 发送 call 调用
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
ChatResponse call = chatClient.call(prompt);
@ -78,69 +93,66 @@ public class AiChatServiceImpl implements AiChatService {
content = ExceptionUtil.getMessage(e);
} finally {
// 保存 chat message
saveSystemChatMessage(req, conversationRes, loginUserId, content);
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
}
return content;
return new AiChatMessageRespVO().setContent(content);
}
private void saveChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId) {
Long chatConversationId = conversationRes.getId();
private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
String model, Long modelId, String content, Integer tokens, Double temperature,
Integer maxTokens, Integer maxContexts) {
AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
.setId(null)
.setConversationId(conversationId)
.setType(messageType.getValue())
.setUserId(loginUserId)
.setRoleId(roleId)
.setModel(model)
.setModelId(modelId)
.setContent(content)
.setTokens(tokens)
.setTemperature(temperature)
.setMaxTokens(maxTokens)
.setMaxContexts(maxContexts);
// 增加 chat message 记录
aiChatMessageMapper.insert(
new AiChatMessageDO()
.setId(null)
.setConversationId(chatConversationId)
.setUserId(loginUserId)
.setMessage(req.getPrompt())
.setMessageType(MessageType.USER.getValue())
.setTopK(req.getTopK())
.setTopP(req.getTopP())
.setTemperature(req.getTemperature())
);
aiChatMessageMapper.insert(insertChatMessageDO);
// chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
aiChatConversationMapper.updateIncrChatCount(conversationId);
return insertChatMessageDO;
}
public void saveSystemChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId, String systemPrompts) {
Long chatConversationId = conversationRes.getId();
// 增加 chat message 记录
aiChatMessageMapper.insert(
new AiChatMessageDO()
.setId(null)
.setConversationId(chatConversationId)
.setUserId(loginUserId)
.setMessage(systemPrompts)
.setMessageType(MessageType.SYSTEM.getValue())
.setTopK(req.getTopK())
.setTopP(req.getTopP())
.setTemperature(req.getTemperature())
);
// chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
}
/**
* chat stream
*
* @param req
* @param sseEmitter
* @return
*/
@Override
public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
// 获取对话信息
AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
// 查询对话
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
// 获取对话模型
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
// 对话模型是否可用
aiChatModalService.validateAvailable(chatModal);
// 获取角色信息
AiChatRoleDO aiChatRoleDO = null;
if (conversation.getRoleId() != null) {
aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
}
// 校验角色是否公开
aiChatRoleService.validateIsPublic(aiChatRoleDO);
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt());
req.setTopK(req.getTopK());
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
Prompt prompt = new Prompt(req.getContent());
// req.setTopK(req.getTopK());
// req.setTopP(req.getTopP());
// req.setTemperature(req.getTemperature());
// 保存 chat message
saveChatMessage(req, conversationRes, loginUserId);
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
@ -168,7 +180,10 @@ public class AiChatServiceImpl implements AiChatService {
log.info("发送完成!");
sseEmitter.complete();
// 保存 chat message
saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString());
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
}
);
}