【调整】调整AI对话模块

This commit is contained in:
cherishsince 2024-05-07 10:04:49 +08:00
parent f9854273cc
commit f86e24bb86
11 changed files with 136 additions and 116 deletions

View File

@ -0,0 +1,16 @@
package cn.iocoder.yudao.module.ai;
/**
* ai 常用的常量
*
* @author fansili
* @time 2024/5/7 09:29
* @since 1.0
*/
public class AiCommonConstants {
/**
* 对话 - 默认 title
*/
public static final String CONVERSATION_DEFAULT_TITLE = "新增对话";
}

View File

@ -15,6 +15,10 @@ public interface ErrorCodeConstants {
ErrorCode AI_MODULE_NOT_SUPPORTED = new ErrorCode(1_022_000_000, "AI 模型暂不支持!");
ErrorCode AI_CHAT_ROLE_NOT_EXISTENT = new ErrorCode(1_022_000_001, "AI Role 不存在!");;
// conversation
ErrorCode AI_CONVERSATION_NOT_EXISTS = new ErrorCode(1_022_000_002, "AI 对话不存在!");;
ErrorCode AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL = new ErrorCode(1_022_000_002, "chat 继续对话,对话 id 不能为空!");;
ErrorCode AI_CHAT_CONTINUE_NOT_EXIST = new ErrorCode(1_022_000_020, "chat 对话不存在!");
ErrorCode AI_CHAT_CONVERSATION_NOT_YOURS = new ErrorCode(1_022_000_021, "这条 chat 对话不是你的!");

View File

@ -2,12 +2,15 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateReqVO;
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.web.bind.annotation.*;
@ -16,33 +19,36 @@ import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
@Slf4j
@Tag(name = "管理后台 - 聊天会话")
@RestController
@RequestMapping("/ai/chat/conversation")
@Slf4j
@AllArgsConstructor
public class AiChatConversationController {
// TODO @fan实现一下
private final AiChatConversationService aiChatConversationService;
// TODO done @fan实现一下
@PostMapping("/create")
@Operation(summary = "创建聊天会话")
@PreAuthorize("@ss.hasPermission('ai:chat-conversation:create')")
public CommonResult<Long> createConversation(@RequestBody @Valid AiChatConversationCreateReqVO createReqVO) {
return success(1L);
return success(aiChatConversationService.createConversation(createReqVO));
}
// TODO @fan实现一下
// TODO done @fan实现一下
@PutMapping("/update")
@Operation(summary = "更新聊天会话")
@PreAuthorize("@ss.hasPermission('ai:chat-conversation:create')")
public CommonResult<Boolean> updateConversation(@RequestBody @Valid AiChatConversationUpdateReqVO updateReqVO) {
return success(true);
return success(aiChatConversationService.updateConversation(updateReqVO));
}
// TODO @fan实现一下
// TODO done @fan实现一下
@GetMapping("/list")
@Operation(summary = "获得聊天会话列表")
public CommonResult<List<AiChatConversationRespVO>> getConversationList() {
return success(null);
public CommonResult<List<AiChatConversationRespVO>> getConversationList(@ModelAttribute AiChatConversationListReqVO listReqVO) {
return success(aiChatConversationService.listConversation(listReqVO));
}
// TODO @fan实现一下
@ -50,7 +56,7 @@ public class AiChatConversationController {
@Operation(summary = "获得聊天会话")
@Parameter(name = "id", required = true, description = "会话编号", example = "1024")
public CommonResult<AiChatConversationRespVO> getConversation(@RequestParam("id") Long id) {
return success(null);
return success(aiChatConversationService.getConversationOfValidate(id));
}
// TODO @fan实现一下
@ -58,7 +64,7 @@ public class AiChatConversationController {
@Operation(summary = "删除聊天会话")
@Parameter(name = "id", required = true, description = "会话编号", example = "1024")
public CommonResult<Boolean> deleteConversation(@RequestParam("id") Long id) {
return success(null);
return success(aiChatConversationService.deleteConversation(id));
}
}

View File

@ -0,0 +1,13 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@Schema(description = "管理后台 - AI 聊天会话 Response VO")
@Data
public class AiChatConversationListReqVO {
@Schema(description = "会话标题", requiredMode = Schema.RequiredMode.REQUIRED, example = "我是一个标题")
private String title;
}

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.convert;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import org.mapstruct.Mapper;
@ -34,4 +35,11 @@ public interface AiChatConversationConvert {
* @return
*/
AiChatConversationRespVO covnertChatConversationRes(AiChatConversationDO aiChatConversationDO);
/**
* 转换 - AiChatConversationDO
*
* @param updateReqVO
*/
AiChatConversationDO convertAiChatConversationDO(AiChatConversationUpdateReqVO updateReqVO);
}

View File

@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import org.apache.ibatis.annotations.Mapper;
import org.springframework.stereotype.Repository;

View File

@ -1,7 +1,9 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateReqVO;
import java.util.List;
@ -14,29 +16,20 @@ import java.util.List;
public interface AiChatConversationService {
/**
* 对话 - 创建普通对话
* 对话 - 创建对话
*
* @param req
* @return
*/
AiChatConversationRespVO createConversation(AiChatConversationCreateUserReq req);
Long createConversation(AiChatConversationCreateReqVO req);
/**
* 对话 - 创建role对话
* 对话 - 更新对话
*
* @param req
* @param updateReqVO
* @return
*/
AiChatConversationRespVO createRoleConversation(AiChatConversationCreateReqVO req);
/**
* 获取 - 对话
*
* @param id
* @return
*/
AiChatConversationRespVO getConversation(Long id);
Boolean updateConversation(AiChatConversationUpdateReqVO updateReqVO);
/**
* 获取 - 对话列表
@ -44,22 +37,21 @@ public interface AiChatConversationService {
* @param req
* @return
*/
List<AiChatConversationRespVO> listConversation(AiChatConversationListReq req);
List<AiChatConversationRespVO> listConversation(AiChatConversationListReqVO req);
/**
* 更新 - 更新模型
* 获取 - 对话
*
* @param id
* @param modalId
* @return
*/
void updateModal(Long id, Long modalId);
AiChatConversationRespVO getConversationOfValidate(Long id);
/**
* 删除 - 根据id
*
* @param id
*/
void delete(Long id);
Boolean deleteConversation(Long id);
}

View File

@ -51,5 +51,5 @@ public interface AiChatModalService {
* @param modalId
* @return
*/
AiChatModalRes getChatModal(Long modalId);
AiChatModalRes getChatModalOfValidate(Long modalId);
}

View File

@ -2,17 +2,20 @@ package cn.iocoder.yudao.module.ai.service.impl;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.AiCommonConstants;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.role.AiChatRoleRes;
import cn.iocoder.yudao.module.ai.convert.AiChatConversationConvert;
import cn.iocoder.yudao.module.ai.enums.AiChatConversationTypeEnum;
import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
@ -34,119 +37,95 @@ import java.util.List;
@AllArgsConstructor
public class AiChatConversationServiceImpl implements AiChatConversationService {
private final AiChatRoleMapper aiChatRoleMapper;
private final AiChatModalMapper aiChatModalMapper;
private final AiChatConversationMapper aiChatConversationMapper;
private final AiChatModalService aiChatModalService;
private final AiChatRoleService aiChatRoleService;
private final AiChatConversationMapper aiChatConversationMapper;
@Override
public AiChatConversationRespVO createConversation(AiChatConversationCreateUserReq req) {
public Long createConversation(AiChatConversationCreateReqVO req) {
// 获取用户id
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询最新的对话
AiChatConversationDO latestConversation = aiChatConversationMapper.selectLatestConversation(loginUserId);
// 如果有对话没有被使用过那就返回这个
if (latestConversation != null && latestConversation.getChatCount() <= 0) {
return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
}
// 获取第一个模型
// 默认使用 sort 排序第一个模型
AiChatModalDO aiChatModalDO = aiChatModalMapper.selectFirstModal();
// 创建新的 Conversation
AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
null, null, AiChatConversationTypeEnum.USER_CHAT,
aiChatModalDO.getId(), aiChatModalDO.getModal());
// 转换 res
return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
}
@Override
public AiChatConversationRespVO createRoleConversation(AiChatConversationCreateReqVO req) {
// 获取用户id
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询最新的对话
// AiChatConversationDO latestConversation = aiChatConversationMapper.selectLatestConversation(loginUserId);
// // 如果有对话没有被使用过那就返回这个
// if (latestConversation != null && latestConversation.getChatCount() <= 0) {
// return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
// }
// 查询角色
AiChatRoleRes chatRoleRes = aiChatRoleService.getChatRole(req.getRoleId());
// 获取第一个模型
AiChatModalDO aiChatModalDO = aiChatModalMapper.selectFirstModal();
AiChatRoleRes chatRoleRes = null;
if (req.getRoleId() != null) {
chatRoleRes = aiChatRoleService.getChatRole(req.getRoleId());
}
Long chatRoleId = chatRoleRes != null ? chatRoleRes.getId() : null;
// 创建新的 Conversation
AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
req.getRoleId(), chatRoleRes.getName(), AiChatConversationTypeEnum.ROLE_CHAT,
aiChatModalDO.getId(), aiChatModalDO.getModal());
// 转换 res
return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
AiChatConversationDO insertConversation = saveConversation(AiCommonConstants.CONVERSATION_DEFAULT_TITLE,
loginUserId, chatRoleId, aiChatModalDO.getId(), aiChatModalDO.getModel()
);
// 返回对话id
return insertConversation.getId();
}
private @NotNull AiChatConversationDO saveConversation(String title,
Long userId,
Long roleId,
String roleName,
AiChatConversationTypeEnum typeEnum,
Long modalId,
String modal) {
String model) {
AiChatConversationDO insertConversation = new AiChatConversationDO();
insertConversation.setId(null);
insertConversation.setUserId(userId);
insertConversation.setRoleId(roleId);
insertConversation.setRoleName(roleName);
insertConversation.setTitle(title);
insertConversation.setChatCount(0);
insertConversation.setType(typeEnum.getType());
insertConversation.setModalId(modalId);
insertConversation.setModal(modal);
insertConversation.setPinned(false);
insertConversation.setRoleId(roleId);
insertConversation.setModelId(modalId);
insertConversation.setModel(model);
insertConversation.setTemperature(null);
insertConversation.setMaxTokens(null);
insertConversation.setMaxContexts(null);
aiChatConversationMapper.insert(insertConversation);
return insertConversation;
}
@Override
public AiChatConversationRespVO getConversation(Long id) {
public Boolean updateConversation(AiChatConversationUpdateReqVO updateReqVO) {
// 校验对话是否存在
validateExists(updateReqVO.getId());
// 获取模型信息并验证
AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(updateReqVO.getModelId());
// 校验modal是否可用
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
}
// 更新对话信息
AiChatConversationDO updateAiChatConversationDO
= AiChatConversationConvert.INSTANCE.convertAiChatConversationDO(updateReqVO);
return aiChatConversationMapper.updateById(updateAiChatConversationDO) > 0;
}
@Override
public List<AiChatConversationRespVO> listConversation(AiChatConversationListReqVO listReqVO) {
// 获取用户id
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询前100对话
List<AiChatConversationDO> top100Conversation
= aiChatConversationMapper.selectTop100Conversation(loginUserId, listReqVO.getTitle());
return AiChatConversationConvert.INSTANCE.covnertChatConversationResList(top100Conversation);
}
@Override
public AiChatConversationRespVO getConversationOfValidate(Long id) {
AiChatConversationDO aiChatConversationDO = validateExists(id);
return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(aiChatConversationDO);
}
@Override
public Boolean deleteConversation(Long id) {
return aiChatConversationMapper.deleteById(id) > 0;
}
private @NotNull AiChatConversationDO validateExists(Long id) {
AiChatConversationDO aiChatConversationDO = aiChatConversationMapper.selectById(id);
if (aiChatConversationDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_CONTINUE_NOT_EXIST);
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CONVERSATION_NOT_EXISTS);
}
return aiChatConversationDO;
}
@Override
public List<AiChatConversationRespVO> listConversation(AiChatConversationListReq req) {
// 获取用户id
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询前100对话
List<AiChatConversationDO> top100Conversation
= aiChatConversationMapper.selectTop100Conversation(loginUserId, req.getSearch());
return AiChatConversationConvert.INSTANCE.covnertChatConversationResList(top100Conversation);
}
@Override
public void updateModal(Long id, Long modalId) {
// 校验对话是否存在
validateExists(id);
// 获取模型
AiChatModalRes chatModal = aiChatModalService.getChatModal(modalId);
// 判断模型是否禁用
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
}
// 更新对话
aiChatConversationMapper.updateById(new AiChatConversationDO()
.setId(id)
.setModalId(chatModal.getId())
.setModal(chatModal.getModal())
);
}
@Override
public void delete(Long id) {
aiChatConversationMapper.deleteById(id);
}
}

View File

@ -12,6 +12,7 @@ import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.convert.AiChatModalConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalChatConfigVO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalDallConfigVO;
@ -109,7 +110,7 @@ public class AiChatModalServiceImpl implements AiChatModalService {
}
@Override
public AiChatModalRes getChatModal(Long modalId) {
public AiChatModalRes getChatModalOfValidate(Long modalId) {
// 检查 modal 是否存在
AiChatModalDO aiChatModalDO = validateChatModalExists(modalId);
return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);

View File

@ -58,7 +58,7 @@ public class AiChatServiceImpl implements AiChatService {
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
// 获取对话信息
AiChatConversationRespVO conversationRes = chatConversationService.getConversation(req.getConversationId());
AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
// 保存 chat message
saveChatMessage(req, conversationRes, loginUserId);
String content = null;
@ -133,7 +133,7 @@ public class AiChatServiceImpl implements AiChatService {
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
// 获取对话信息
AiChatConversationRespVO conversationRes = chatConversationService.getConversation(req.getConversationId());
AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt());
req.setTopK(req.getTopK());