mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2025-01-19 03:30:06 +08:00
聊天对话,增加 创建对话、还是继续对话逻辑
This commit is contained in:
parent
a2bd9b710e
commit
7794992225
@ -0,0 +1,34 @@
|
|||||||
|
package cn.iocoder.yudao.module.ai.enums;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 聊天类型
|
||||||
|
*
|
||||||
|
* @author fansili
|
||||||
|
* @time 2024/4/14 17:58
|
||||||
|
* @since 1.0
|
||||||
|
*/
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Getter
|
||||||
|
public enum ChatTypeEnum {
|
||||||
|
|
||||||
|
ROLE_CHAT("roleChat", "角色模板聊天"),
|
||||||
|
USER_CHAT("userChat", "用户普通聊天"),
|
||||||
|
|
||||||
|
;
|
||||||
|
|
||||||
|
private String type;
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
public static ChatTypeEnum valueOfType(String type) {
|
||||||
|
for (ChatTypeEnum itemEnum : ChatTypeEnum.values()) {
|
||||||
|
if (itemEnum.getType().equals(type)) {
|
||||||
|
return itemEnum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("Invalid MessageType value: " + type);
|
||||||
|
}
|
||||||
|
}
|
@ -23,12 +23,12 @@ public class AiChatMessageDO {
|
|||||||
/**
|
/**
|
||||||
* 聊天ID,关联到特定的会话或对话
|
* 聊天ID,关联到特定的会话或对话
|
||||||
*/
|
*/
|
||||||
private Long chatId;
|
private Long chatConversationId;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 角色ID,用于标识发送消息的用户或系统的身份
|
* 角色ID,用于标识发送消息的用户或系统的身份
|
||||||
*/
|
*/
|
||||||
private String userId;
|
private Long userId;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 消息具体内容,存储用户的发言或者系统响应的文字信息
|
* 消息具体内容,存储用户的发言或者系统响应的文字信息
|
||||||
@ -38,7 +38,7 @@ public class AiChatMessageDO {
|
|||||||
/**
|
/**
|
||||||
* 消息类型,枚举值可能包括'system'(系统消息)、'user'(用户消息)和'assistant'(助手消息)
|
* 消息类型,枚举值可能包括'system'(系统消息)、'user'(用户消息)和'assistant'(助手消息)
|
||||||
*/
|
*/
|
||||||
private Double messageType;
|
private String messageType;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 在生成消息时采用的Top-K采样大小,
|
* 在生成消息时采用的Top-K采样大小,
|
||||||
|
@ -1,14 +1,28 @@
|
|||||||
package cn.iocoder.yudao.module.ai.service.impl;
|
package cn.iocoder.yudao.module.ai.service.impl;
|
||||||
|
|
||||||
|
import cn.hutool.core.exceptions.ExceptionUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
|
||||||
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||||
import cn.iocoder.yudao.framework.ai.config.AiClient;
|
import cn.iocoder.yudao.framework.ai.config.AiClient;
|
||||||
|
import cn.iocoder.yudao.framework.common.exception.ServerException;
|
||||||
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||||
|
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||||
|
import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
|
||||||
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
|
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
|
||||||
|
import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum;
|
||||||
|
import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
|
||||||
|
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
|
||||||
|
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
|
||||||
|
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
|
||||||
import cn.iocoder.yudao.module.ai.service.ChatService;
|
import cn.iocoder.yudao.module.ai.service.ChatService;
|
||||||
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -24,6 +38,10 @@ import reactor.core.publisher.Flux;
|
|||||||
public class ChatServiceImpl implements ChatService {
|
public class ChatServiceImpl implements ChatService {
|
||||||
|
|
||||||
private final AiClient aiClient;
|
private final AiClient aiClient;
|
||||||
|
private final AiChatRoleMapper aiChatRoleMapper;
|
||||||
|
private final AiChatMessageMapper aiChatMessageMapper;
|
||||||
|
private final AiChatConversationMapper aiChatConversationMapper;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* chat
|
* chat
|
||||||
@ -31,16 +49,84 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
* @param req
|
* @param req
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public String chat(ChatReq req) {
|
public String chat(ChatReq req) {
|
||||||
|
// 获取 client 类型
|
||||||
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
||||||
// 创建 chat 需要的 Prompt
|
// 获取 对话类型(新建还是继续)
|
||||||
Prompt prompt = new Prompt(req.getPrompt());
|
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
|
||||||
req.setTopK(req.getTopK());
|
|
||||||
req.setTopP(req.getTopP());
|
AiChatConversationDO aiChatConversationDO;
|
||||||
req.setTemperature(req.getTemperature());
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||||
// 发送 call 调用
|
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
|
||||||
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
|
// 创建一个新的对话
|
||||||
return call.getResult().getOutput().getContent();
|
aiChatConversationDO = createNewChatConversation(req, loginUserId);
|
||||||
|
} else {
|
||||||
|
// 继续对话
|
||||||
|
if (req.getConversationId() == null) {
|
||||||
|
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
|
||||||
|
}
|
||||||
|
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
|
||||||
|
}
|
||||||
|
|
||||||
|
String content;
|
||||||
|
try {
|
||||||
|
// 创建 chat 需要的 Prompt
|
||||||
|
Prompt prompt = new Prompt(req.getPrompt());
|
||||||
|
req.setTopK(req.getTopK());
|
||||||
|
req.setTopP(req.getTopP());
|
||||||
|
req.setTemperature(req.getTemperature());
|
||||||
|
// 发送 call 调用
|
||||||
|
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
|
||||||
|
content = call.getResult().getOutput().getContent();
|
||||||
|
} catch (Exception e) {
|
||||||
|
content = ExceptionUtil.getMessage(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加 chat message 记录
|
||||||
|
aiChatMessageMapper.insert(
|
||||||
|
new AiChatMessageDO()
|
||||||
|
.setId(null)
|
||||||
|
.setChatConversationId(aiChatConversationDO.getId())
|
||||||
|
.setUserId(loginUserId)
|
||||||
|
.setMessage(req.getPrompt())
|
||||||
|
.setMessageType(MessageType.USER.getValue())
|
||||||
|
.setTopK(req.getTopK())
|
||||||
|
.setTopP(req.getTopP())
|
||||||
|
.setTemperature(req.getTemperature())
|
||||||
|
);
|
||||||
|
|
||||||
|
// chat count 先+1
|
||||||
|
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
||||||
|
return content;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
|
||||||
|
// 获取 chat 角色
|
||||||
|
String chatRoleName = null;
|
||||||
|
ChatTypeEnum chatTypeEnum = null;
|
||||||
|
Long chatRoleId = req.getChatRoleId();
|
||||||
|
if (req.getChatRoleId() != null) {
|
||||||
|
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId);
|
||||||
|
if (aiChatRoleDO == null) {
|
||||||
|
throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT);
|
||||||
|
}
|
||||||
|
chatTypeEnum = ChatTypeEnum.ROLE_CHAT;
|
||||||
|
chatRoleName = aiChatRoleDO.getRoleName();
|
||||||
|
} else {
|
||||||
|
chatTypeEnum = ChatTypeEnum.USER_CHAT;
|
||||||
|
}
|
||||||
|
//
|
||||||
|
AiChatConversationDO insertChatConversation = new AiChatConversationDO()
|
||||||
|
.setId(null)
|
||||||
|
.setUserId(loginUserId)
|
||||||
|
.setChatRoleId(req.getChatRoleId())
|
||||||
|
.setChatRoleName(chatRoleName)
|
||||||
|
.setChatType(chatTypeEnum.getType())
|
||||||
|
.setChatCount(1)
|
||||||
|
.setChatTitle(req.getPrompt().substring(0, 20) + "...");
|
||||||
|
aiChatConversationMapper.insert(insertChatConversation);
|
||||||
|
return insertChatConversation;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -24,19 +24,29 @@ public class ChatReq {
|
|||||||
@Schema(description = "填入固定值,1 issues, 2 pr")
|
@Schema(description = "填入固定值,1 issues, 2 pr")
|
||||||
private String prompt;
|
private String prompt;
|
||||||
|
|
||||||
|
@Schema(description = "chat角色模板")
|
||||||
|
private Long chatRoleId;
|
||||||
|
|
||||||
@Schema(description = "用于控制随机性和多样性的温度参数")
|
@Schema(description = "用于控制随机性和多样性的温度参数")
|
||||||
private Float temperature;
|
private Double temperature;
|
||||||
|
|
||||||
@Schema(description = "生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,\n" +
|
@Schema(description = "生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,\n" +
|
||||||
" * 作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。\n" +
|
" * 作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。\n" +
|
||||||
" * 默认值为0.8。注意,取值不要大于等于1\n")
|
" * 默认值为0.8。注意,取值不要大于等于1\n")
|
||||||
private Float topP;
|
private Double topP;
|
||||||
|
|
||||||
@Schema(description = "在生成消息时采用的Top-K采样大小,表示模型生成回复时考虑的候选项集合的大小")
|
@Schema(description = "在生成消息时采用的Top-K采样大小,表示模型生成回复时考虑的候选项集合的大小")
|
||||||
private Integer topK;
|
private Double topK;
|
||||||
|
|
||||||
@Schema(description = "ai模型(查看 AiClientNameEnum)")
|
@Schema(description = "ai模型(查看 AiClientNameEnum)")
|
||||||
@NotNull(message = "模型不能为空!")
|
@NotNull(message = "模型不能为空!")
|
||||||
@Size(max = 30, message = "模型字符最大30个字符!")
|
@Size(max = 30, message = "模型字符最大30个字符!")
|
||||||
private String modal;
|
private String modal;
|
||||||
|
|
||||||
|
@Schema(description = "对话类型(new、continue)")
|
||||||
|
@NotNull(message = "对话类型,不能为空!")
|
||||||
|
private String conversationType;
|
||||||
|
|
||||||
|
@Schema(description = "对话Id")
|
||||||
|
private Long conversationId;
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ public abstract class AbstractMessage implements Message {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected AbstractMessage(MessageType messageType, String textContent, List<MediaData> mediaData,
|
protected AbstractMessage(MessageType messageType, String textContent, List<MediaData> mediaData,
|
||||||
Map<String, Object> messageProperties) {
|
Map<String, Object> messageProperties) {
|
||||||
|
|
||||||
Assert.notNull(messageType, "Message type must not be null");
|
Assert.notNull(messageType, "Message type must not be null");
|
||||||
Assert.notNull(textContent, "Content must not be null");
|
Assert.notNull(textContent, "Content must not be null");
|
||||||
|
Loading…
Reference in New Issue
Block a user