增加创建role对话

This commit is contained in:
cherishsince 2024-04-23 16:29:26 +08:00
parent dc44fe1cf9
commit 1ab1538afe
9 changed files with 125 additions and 69 deletions

View File

@ -15,8 +15,9 @@ import lombok.Getter;
@Getter @Getter
public enum ChatConversationTypeEnum { public enum ChatConversationTypeEnum {
NEW("new", "新建对话"), // roleChatuserChat
CONTINUE("continue", "继续对话"), ROLE_CHAT("roleChat", "角色对话"),
USER_CHAT("userChat", "用户对话"),
; ;

View File

@ -2,7 +2,8 @@ package cn.iocoder.yudao.module.ai.controller;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.service.ChatConversationService; import cn.iocoder.yudao.module.ai.service.ChatConversationService;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq; import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq; import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationRes; import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
@ -30,10 +31,16 @@ public class ChatConversationController {
private final ChatConversationService chatConversationService; private final ChatConversationService chatConversationService;
@Operation(summary = "创建 - 对话") @Operation(summary = "创建 - 对话普通对话")
@PostMapping("/create") @PostMapping("/createConversation")
public CommonResult<ChatConversationRes> create(@RequestBody @Validated ChatConversationCreateReq req) { public CommonResult<ChatConversationRes> createConversation(@RequestBody @Validated ChatConversationCreateUserReq req) {
return CommonResult.success(chatConversationService.create(req)); return CommonResult.success(chatConversationService.createConversation(req));
}
@Operation(summary = "创建 - 对话角色对话")
@PostMapping("/createRoleConversation")
public CommonResult<ChatConversationRes> createRoleConversation(@RequestBody @Validated ChatConversationCreateRoleReq req) {
return CommonResult.success(chatConversationService.createRoleConversation(req));
} }
@Operation(summary = "获取 - 获取对话") @Operation(summary = "获取 - 获取对话")

View File

@ -24,7 +24,11 @@ import java.util.List;
@Mapper @Mapper
public interface AiChatConversationMapper extends BaseMapperX<AiChatConversationDO> { public interface AiChatConversationMapper extends BaseMapperX<AiChatConversationDO> {
/**
* 更新 - chat count
*
* @param id
*/
@Update("update ai_chat_conversation set chat_count = chat_count + 1 where id = #{id}") @Update("update ai_chat_conversation set chat_count = chat_count + 1 where id = #{id}")
void updateIncrChatCount(@Param("id") Long id); void updateIncrChatCount(@Param("id") Long id);

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service; package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq; import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq; import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationRes; import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
@ -15,12 +16,21 @@ import java.util.List;
public interface ChatConversationService { public interface ChatConversationService {
/** /**
* 对话 - 创建 * 对话 - 创建普通对话
* *
* @param req * @param req
* @return * @return
*/ */
ChatConversationRes create(ChatConversationCreateReq req); ChatConversationRes createConversation(ChatConversationCreateUserReq req);
/**
* 对话 - 创建role对话
*
* @param req
* @return
*/
ChatConversationRes createRoleConversation(ChatConversationCreateRoleReq req);
/** /**
* 获取 - 对话 * 获取 - 对话
@ -44,4 +54,5 @@ public interface ChatConversationService {
* @param id * @param id
*/ */
void delete(Long id); void delete(Long id);
} }

View File

@ -5,13 +5,18 @@ import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.convert.ChatConversationConvert; import cn.iocoder.yudao.module.ai.convert.ChatConversationConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.service.ChatConversationService; import cn.iocoder.yudao.module.ai.service.ChatConversationService;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq; import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq; import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationRes; import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
@ -27,10 +32,11 @@ import java.util.List;
@AllArgsConstructor @AllArgsConstructor
public class ChatConversationServiceImpl implements ChatConversationService { public class ChatConversationServiceImpl implements ChatConversationService {
private final AiChatRoleMapper aiChatRoleMapper;
private final AiChatConversationMapper aiChatConversationMapper; private final AiChatConversationMapper aiChatConversationMapper;
@Override @Override
public ChatConversationRes create(ChatConversationCreateReq req) { public ChatConversationRes createConversation(ChatConversationCreateUserReq req) {
// 获取用户id // 获取用户id
Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询最新的对话 // 查询最新的对话
@ -40,19 +46,47 @@ public class ChatConversationServiceImpl implements ChatConversationService {
return ChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation); return ChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
} }
// 创建新的 Conversation // 创建新的 Conversation
AiChatConversationDO insertConversation = new AiChatConversationDO(); AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
insertConversation.setId(null); null, null, ChatConversationTypeEnum.USER_CHAT);
insertConversation.setUserId(loginUserId);
insertConversation.setChatRoleId(null);
insertConversation.setChatRoleName(null);
insertConversation.setTitle(null);
insertConversation.setChatCount(0);
insertConversation.setType(req.getChatType());
aiChatConversationMapper.insert(insertConversation);
// 转换 res // 转换 res
return ChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation); return ChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
} }
@Override
public ChatConversationRes createRoleConversation(ChatConversationCreateRoleReq req) {
// 获取用户id
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询最新的对话
AiChatConversationDO latestConversation = aiChatConversationMapper.selectLatestConversation(loginUserId);
// 如果有对话没有被使用过那就返回这个
if (latestConversation != null && latestConversation.getChatCount() <= 0) {
return ChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
}
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(req.getChatRoleId());
// 创建新的 Conversation
AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
req.getChatRoleId(), aiChatRoleDO.getRoleName(), ChatConversationTypeEnum.ROLE_CHAT);
// 转换 res
return ChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
}
private @NotNull AiChatConversationDO saveConversation(String title,
Long userId,
Long chatRoleId,
String chatRoleName,
ChatConversationTypeEnum typeEnum) {
AiChatConversationDO insertConversation = new AiChatConversationDO();
insertConversation.setId(null);
insertConversation.setUserId(userId);
insertConversation.setChatRoleId(chatRoleId);
insertConversation.setChatRoleName(chatRoleName);
insertConversation.setTitle(title);
insertConversation.setChatCount(0);
insertConversation.setType(typeEnum.getType());
aiChatConversationMapper.insert(insertConversation);
return insertConversation;
}
@Override @Override
public ChatConversationRes getConversation(Long id) { public ChatConversationRes getConversation(Long id) {
AiChatConversationDO aiChatConversationDO = aiChatConversationMapper.selectById(id); AiChatConversationDO aiChatConversationDO = aiChatConversationMapper.selectById(id);

View File

@ -5,15 +5,10 @@ import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.config.AiClient; import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.framework.common.exception.ServerException;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum; import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper; import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
@ -49,7 +44,6 @@ public class ChatServiceImpl implements ChatService {
private final AiChatConversationMapper aiChatConversationMapper; private final AiChatConversationMapper aiChatConversationMapper;
private final ChatConversationService chatConversationService; private final ChatConversationService chatConversationService;
/** /**
* chat * chat
* *
@ -64,7 +58,7 @@ public class ChatServiceImpl implements ChatService {
// 获取对话信息 // 获取对话信息
ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId()); ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
// 保存 chat message // 保存 chat message
saveChatMessage(req, conversationRes.getId(), loginUserId); saveChatMessage(req, conversationRes, loginUserId);
String content = null; String content = null;
try { try {
// 创建 chat 需要的 Prompt // 创建 chat 需要的 Prompt
@ -75,16 +69,19 @@ public class ChatServiceImpl implements ChatService {
// 发送 call 调用 // 发送 call 调用
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName()); ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
content = call.getResult().getOutput().getContent(); content = call.getResult().getOutput().getContent();
// 更新 conversation
} catch (Exception e) { } catch (Exception e) {
content = ExceptionUtil.getMessage(e); content = ExceptionUtil.getMessage(e);
} finally { } finally {
// 保存 chat message // 保存 chat message
saveSystemChatMessage(req, conversationRes.getId(), loginUserId, content); saveSystemChatMessage(req, conversationRes, loginUserId, content);
} }
return content; return content;
} }
private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) { private void saveChatMessage(ChatReq req, ChatConversationRes conversationRes, Long loginUserId) {
Long chatConversationId = conversationRes.getId();
// 增加 chat message 记录 // 增加 chat message 记录
aiChatMessageMapper.insert( aiChatMessageMapper.insert(
new AiChatMessageDO() new AiChatMessageDO()
@ -97,12 +94,12 @@ public class ChatServiceImpl implements ChatService {
.setTopP(req.getTopP()) .setTopP(req.getTopP())
.setTemperature(req.getTemperature()) .setTemperature(req.getTemperature())
); );
// chat count +1 // chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
} }
public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) { public void saveSystemChatMessage(ChatReq req, ChatConversationRes conversationRes, Long loginUserId, String systemPrompts) {
Long chatConversationId = conversationRes.getId();
// 增加 chat message 记录 // 增加 chat message 记录
aiChatMessageMapper.insert( aiChatMessageMapper.insert(
new AiChatMessageDO() new AiChatMessageDO()
@ -120,34 +117,6 @@ public class ChatServiceImpl implements ChatService {
aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
} }
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
// 获取 chat 角色
String chatRoleName = null;
ChatTypeEnum chatTypeEnum = null;
Long chatRoleId = req.getChatRoleId();
if (req.getChatRoleId() != null) {
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId);
if (aiChatRoleDO == null) {
throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT);
}
chatTypeEnum = ChatTypeEnum.ROLE_CHAT;
chatRoleName = aiChatRoleDO.getRoleName();
} else {
chatTypeEnum = ChatTypeEnum.USER_CHAT;
}
//
AiChatConversationDO insertChatConversation = new AiChatConversationDO()
.setId(null)
.setUserId(loginUserId)
.setChatRoleId(req.getChatRoleId())
.setChatRoleName(chatRoleName)
.setType(chatTypeEnum.getType())
.setChatCount(1)
.setTitle(req.getPrompt().substring(0, 20) + "...");
aiChatConversationMapper.insert(insertChatConversation);
return insertChatConversation;
}
/** /**
* chat stream * chat stream
* *
@ -168,7 +137,7 @@ public class ChatServiceImpl implements ChatService {
req.setTopP(req.getTopP()); req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature()); req.setTemperature(req.getTemperature());
// 保存 chat message // 保存 chat message
saveChatMessage(req, conversationRes.getId(), loginUserId); saveChatMessage(req, conversationRes, loginUserId);
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName()); Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
@ -195,7 +164,7 @@ public class ChatServiceImpl implements ChatService {
log.info("发送完成!"); log.info("发送完成!");
sseEmitter.complete(); sseEmitter.complete();
// 保存 chat message // 保存 chat message
saveSystemChatMessage(req, conversationRes.getId(), loginUserId, contentBuffer.toString()); saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString());
} }
); );
} }

View File

@ -0,0 +1,26 @@
package cn.iocoder.yudao.module.ai.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* 聊天对话
*
* @author fansili
* @time 2024/4/18 16:24
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class ChatConversationCreateRoleReq {
@Schema(description = "chat角色Id")
@NotNull(message = "聊天角色id不能为空!")
private Long chatRoleId;
@Schema(description = "标题(有程序自动生成)")
@NotNull(message = "标题不能为空!")
private String title;
}

View File

@ -14,10 +14,9 @@ import lombok.experimental.Accessors;
*/ */
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public class ChatConversationCreateReq { public class ChatConversationCreateUserReq {
@Schema(description = "对话类型(roleChat、userChat)")
@NotNull(message = "聊天类型不能为空!")
private String chatType;
@Schema(description = "对话标题")
@NotNull(message = "标题不能为空!")
private String title;
} }

View File

@ -16,7 +16,12 @@ GET {{baseUrl}}/ai/chat/conversation/1781604279872581644
Authorization: {{token}} Authorization: {{token}}
### 对话 - id获取 ### 对话 - list
GET {{baseUrl}}/ai/chat/conversation/list
Authorization: {{token}}
### 对话 - 删除
DELETE {{baseUrl}}/ai/chat/conversation/1781604279872581644 DELETE {{baseUrl}}/ai/chat/conversation/1781604279872581644
Authorization: {{token}} Authorization: {{token}}