mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-23 07:41:53 +08:00
【调整】调整AI聊天模块
This commit is contained in:
parent
f86e24bb86
commit
424210066f
@ -30,6 +30,7 @@ public interface ErrorCodeConstants {
|
||||
// role
|
||||
|
||||
ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "chatRole 不存在!");
|
||||
ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
|
||||
|
||||
// modal
|
||||
|
||||
|
@ -32,7 +32,7 @@ public class AiChatMessageController {
|
||||
@PostMapping("/send")
|
||||
public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @ModelAttribute AiChatMessageSendReqVO sendReqVO) {
|
||||
// TODO @fan:使用 static import;这样就 success 就行了;
|
||||
return success(null);
|
||||
return success(chatService.chat(sendReqVO));
|
||||
}
|
||||
|
||||
// TODO @芋艿:调用这个方法异常,Unable to handle the Spring Security Exception because the response is already committed.;可以再试试
|
||||
|
@ -21,6 +21,9 @@ public class AiChatConversationRespVO {
|
||||
@Schema(description = "是否置顶", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
|
||||
private Boolean pinned;
|
||||
|
||||
@Schema(description = "角色编号", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1")
|
||||
private Long roleId;
|
||||
|
||||
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
private Long modelId;
|
||||
|
||||
|
@ -52,4 +52,11 @@ public interface AiChatModalService {
|
||||
* @return
|
||||
*/
|
||||
AiChatModalRes getChatModalOfValidate(Long modalId);
|
||||
|
||||
/**
|
||||
* 校验 - 校验是否可用
|
||||
*
|
||||
* @param chatModal
|
||||
*/
|
||||
void validateAvailable(AiChatModalRes chatModal);
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.role.*;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
|
||||
/**
|
||||
* chat 角色
|
||||
@ -58,4 +59,19 @@ public interface AiChatRoleService {
|
||||
* @return
|
||||
*/
|
||||
AiChatRoleRes getChatRole(Long roleId);
|
||||
|
||||
/**
|
||||
* 校验 - 角色是否存在
|
||||
*
|
||||
* @param id
|
||||
* @return
|
||||
*/
|
||||
AiChatRoleDO validateExists(Long id);
|
||||
|
||||
/**
|
||||
* 校验 - 角色是否公开
|
||||
*
|
||||
* @param aiChatRoleDO
|
||||
*/
|
||||
void validateIsPublic(AiChatRoleDO aiChatRoleDO);
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.service;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||
|
||||
/**
|
||||
@ -15,10 +16,10 @@ public interface AiChatService {
|
||||
/**
|
||||
* chat
|
||||
*
|
||||
* @param req
|
||||
* @param sendReqVO
|
||||
* @return
|
||||
*/
|
||||
String chat(AiChatMessageSendReqVO req);
|
||||
AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
|
||||
|
||||
/**
|
||||
* chat stream
|
||||
|
@ -15,7 +15,6 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
|
||||
import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
||||
@ -91,9 +90,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
|
||||
// 获取模型信息并验证
|
||||
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(updateReqVO.getModelId());
|
||||
// 校验modal是否可用
|
||||
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
|
||||
}
|
||||
aiChatModalService.validateAvailable(chatModal);
|
||||
// 更新对话信息
|
||||
AiChatConversationDO updateAiChatConversationDO
|
||||
= AiChatConversationConvert.INSTANCE.convertAiChatConversationDO(updateReqVO);
|
||||
|
@ -116,6 +116,14 @@ public class AiChatModalServiceImpl implements AiChatModalService {
|
||||
return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validateAvailable(AiChatModalRes chatModal) {
|
||||
// 对话模型是否可用
|
||||
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
|
||||
}
|
||||
}
|
||||
|
||||
private AiChatModalDO validateChatModalExists(Long id) {
|
||||
AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id);
|
||||
if (aiChatModalDO == null) {
|
||||
|
@ -70,7 +70,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
|
||||
AiChatRoleClassifyEnum.valueOfClassify(req.getClassify());
|
||||
AiChatRoleEnableEnum.valueOfType(req.getEnable());
|
||||
// 检查角色是否存在
|
||||
validateChatRoleExists(id);
|
||||
validateExists(id);
|
||||
// 转换do
|
||||
AiChatRoleDO updateChatRole = AiChatRoleConvert.INSTANCE.convertAiChatRoleDO(req);
|
||||
updateChatRole.setId(id);
|
||||
@ -83,7 +83,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
|
||||
// 转换enum,并校验enum
|
||||
AiChatRoleEnableEnum.valueOfType(req.getEnable());
|
||||
// 检查角色是否存在
|
||||
validateChatRoleExists(id);
|
||||
validateExists(id);
|
||||
// 更新
|
||||
aiChatRoleMapper.updateById(new AiChatRoleDO()
|
||||
.setId(id)
|
||||
@ -94,7 +94,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
|
||||
@Override
|
||||
public void delete(Long chatRoleId) {
|
||||
// 检查角色是否存在
|
||||
validateChatRoleExists(chatRoleId);
|
||||
validateExists(chatRoleId);
|
||||
// 删除
|
||||
aiChatRoleMapper.deleteById(chatRoleId);
|
||||
}
|
||||
@ -102,15 +102,25 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
|
||||
@Override
|
||||
public AiChatRoleRes getChatRole(Long roleId) {
|
||||
// 检查角色是否存在
|
||||
AiChatRoleDO aiChatRoleDO = validateChatRoleExists(roleId);
|
||||
AiChatRoleDO aiChatRoleDO = validateExists(roleId);
|
||||
return AiChatRoleConvert.INSTANCE.convertAiChatRoleRes(aiChatRoleDO);
|
||||
}
|
||||
|
||||
private AiChatRoleDO validateChatRoleExists(Long id) {
|
||||
public AiChatRoleDO validateExists(Long id) {
|
||||
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(id);
|
||||
if (aiChatRoleDO == null) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXIST);
|
||||
}
|
||||
return aiChatRoleDO;
|
||||
}
|
||||
|
||||
public void validateIsPublic(AiChatRoleDO aiChatRoleDO) {
|
||||
if (aiChatRoleDO == null) {
|
||||
return;
|
||||
}
|
||||
if (!aiChatRoleDO.getPublicStatus()) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_PUBLIC);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,14 +10,19 @@ import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
||||
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.MediaType;
|
||||
@ -45,29 +50,39 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
private final AiChatMessageMapper aiChatMessageMapper;
|
||||
private final AiChatConversationMapper aiChatConversationMapper;
|
||||
private final AiChatConversationService chatConversationService;
|
||||
private final AiChatModalService aiChatModalService;
|
||||
private final AiChatRoleService aiChatRoleService;
|
||||
|
||||
/**
|
||||
* chat
|
||||
*
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public String chat(AiChatMessageSendReqVO req) {
|
||||
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
// 查询对话
|
||||
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
||||
// 获取对话模型
|
||||
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
|
||||
// 对话模型是否可用
|
||||
aiChatModalService.validateAvailable(chatModal);
|
||||
// 获取角色信息
|
||||
AiChatRoleDO aiChatRoleDO = null;
|
||||
if (conversation.getRoleId() != null) {
|
||||
aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
|
||||
}
|
||||
// 校验角色是否公开
|
||||
aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
||||
// 获取 client 类型
|
||||
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
|
||||
// 获取对话信息
|
||||
AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
|
||||
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
|
||||
// 保存 chat message
|
||||
saveChatMessage(req, conversationRes, loginUserId);
|
||||
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
||||
chatModal.getModal(), chatModal.getId(), req.getContent(),
|
||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||
String content = null;
|
||||
try {
|
||||
// 创建 chat 需要的 Prompt
|
||||
Prompt prompt = new Prompt(req.getPrompt());
|
||||
req.setTopK(req.getTopK());
|
||||
req.setTopP(req.getTopP());
|
||||
req.setTemperature(req.getTemperature());
|
||||
Prompt prompt = new Prompt(req.getContent());
|
||||
// TODO @芋艿 @范 看要不要支持这些
|
||||
// req.setTopK(req.getTopK());
|
||||
// req.setTopP(req.getTopP());
|
||||
// req.setTemperature(req.getTemperature());
|
||||
// 发送 call 调用
|
||||
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
|
||||
ChatResponse call = chatClient.call(prompt);
|
||||
@ -78,69 +93,66 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
content = ExceptionUtil.getMessage(e);
|
||||
} finally {
|
||||
// 保存 chat message
|
||||
saveSystemChatMessage(req, conversationRes, loginUserId, content);
|
||||
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
||||
chatModal.getModal(), chatModal.getId(), req.getContent(),
|
||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||
}
|
||||
return content;
|
||||
return new AiChatMessageRespVO().setContent(content);
|
||||
}
|
||||
|
||||
private void saveChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId) {
|
||||
Long chatConversationId = conversationRes.getId();
|
||||
private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
|
||||
String model, Long modelId, String content, Integer tokens, Double temperature,
|
||||
Integer maxTokens, Integer maxContexts) {
|
||||
AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
|
||||
.setId(null)
|
||||
.setConversationId(conversationId)
|
||||
.setType(messageType.getValue())
|
||||
.setUserId(loginUserId)
|
||||
.setRoleId(roleId)
|
||||
.setModel(model)
|
||||
.setModelId(modelId)
|
||||
.setContent(content)
|
||||
.setTokens(tokens)
|
||||
|
||||
.setTemperature(temperature)
|
||||
.setMaxTokens(maxTokens)
|
||||
.setMaxContexts(maxContexts);
|
||||
// 增加 chat message 记录
|
||||
aiChatMessageMapper.insert(
|
||||
new AiChatMessageDO()
|
||||
.setId(null)
|
||||
.setConversationId(chatConversationId)
|
||||
.setUserId(loginUserId)
|
||||
.setMessage(req.getPrompt())
|
||||
.setMessageType(MessageType.USER.getValue())
|
||||
.setTopK(req.getTopK())
|
||||
.setTopP(req.getTopP())
|
||||
.setTemperature(req.getTemperature())
|
||||
);
|
||||
aiChatMessageMapper.insert(insertChatMessageDO);
|
||||
// chat count 先+1
|
||||
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
||||
aiChatConversationMapper.updateIncrChatCount(conversationId);
|
||||
return insertChatMessageDO;
|
||||
}
|
||||
|
||||
public void saveSystemChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId, String systemPrompts) {
|
||||
Long chatConversationId = conversationRes.getId();
|
||||
// 增加 chat message 记录
|
||||
aiChatMessageMapper.insert(
|
||||
new AiChatMessageDO()
|
||||
.setId(null)
|
||||
.setConversationId(chatConversationId)
|
||||
.setUserId(loginUserId)
|
||||
.setMessage(systemPrompts)
|
||||
.setMessageType(MessageType.SYSTEM.getValue())
|
||||
.setTopK(req.getTopK())
|
||||
.setTopP(req.getTopP())
|
||||
.setTemperature(req.getTemperature())
|
||||
);
|
||||
|
||||
// chat count 先+1
|
||||
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
||||
}
|
||||
|
||||
/**
|
||||
* chat stream
|
||||
*
|
||||
* @param req
|
||||
* @param sseEmitter
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) {
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
// 获取 client 类型
|
||||
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
|
||||
// 获取对话信息
|
||||
AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
|
||||
// 查询对话
|
||||
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
||||
// 获取对话模型
|
||||
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
|
||||
// 对话模型是否可用
|
||||
aiChatModalService.validateAvailable(chatModal);
|
||||
// 获取角色信息
|
||||
AiChatRoleDO aiChatRoleDO = null;
|
||||
if (conversation.getRoleId() != null) {
|
||||
aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
|
||||
}
|
||||
// 校验角色是否公开
|
||||
aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
||||
// 创建 chat 需要的 Prompt
|
||||
Prompt prompt = new Prompt(req.getPrompt());
|
||||
req.setTopK(req.getTopK());
|
||||
req.setTopP(req.getTopP());
|
||||
req.setTemperature(req.getTemperature());
|
||||
Prompt prompt = new Prompt(req.getContent());
|
||||
// req.setTopK(req.getTopK());
|
||||
// req.setTopP(req.getTopP());
|
||||
// req.setTemperature(req.getTemperature());
|
||||
// 保存 chat message
|
||||
saveChatMessage(req, conversationRes, loginUserId);
|
||||
// 保存 chat message
|
||||
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
||||
chatModal.getModal(), chatModal.getId(), req.getContent(),
|
||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||
|
||||
// 获取 client 类型
|
||||
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
|
||||
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
||||
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
||||
|
||||
@ -168,7 +180,10 @@ public class AiChatServiceImpl implements AiChatService {
|
||||
log.info("发送完成!");
|
||||
sseEmitter.complete();
|
||||
// 保存 chat message
|
||||
saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString());
|
||||
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
||||
chatModal.getModal(), chatModal.getId(), req.getContent(),
|
||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||
|
||||
}
|
||||
);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user