diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ChatTypeEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ChatTypeEnum.java new file mode 100644 index 000000000..cdf7c3ba0 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ChatTypeEnum.java @@ -0,0 +1,34 @@ +package cn.iocoder.yudao.module.ai.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * 聊天类型 + * + * @author fansili + * @time 2024/4/14 17:58 + * @since 1.0 + */ +@AllArgsConstructor +@Getter +public enum ChatTypeEnum { + + ROLE_CHAT("roleChat", "角色模板聊天"), + USER_CHAT("userChat", "用户普通聊天"), + + ; + + private String type; + + private String name; + + public static ChatTypeEnum valueOfType(String type) { + for (ChatTypeEnum itemEnum : ChatTypeEnum.values()) { + if (itemEnum.getType().equals(type)) { + return itemEnum; + } + } + throw new IllegalArgumentException("Invalid MessageType value: " + type); + } +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dataobject/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dataobject/AiChatMessageDO.java index 416b3f0f9..9047180ec 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dataobject/AiChatMessageDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dataobject/AiChatMessageDO.java @@ -23,12 +23,12 @@ public class AiChatMessageDO { /** * 聊天ID,关联到特定的会话或对话 */ - private Long chatId; + private Long chatConversationId; /** * 角色ID,用于标识发送消息的用户或系统的身份 */ - private String userId; + private Long userId; /** * 消息具体内容,存储用户的发言或者系统响应的文字信息 @@ -38,7 +38,7 @@ public class AiChatMessageDO { /** * 消息类型,枚举值可能包括'system'(系统消息)、'user'(用户消息)和'assistant'(助手消息) */ - private Double messageType; + private String messageType; /** * 在生成消息时采用的Top-K采样大小, diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java index 8531594db..6fe226181 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java @@ -1,14 +1,28 @@ package cn.iocoder.yudao.module.ai.service.impl; +import cn.hutool.core.exceptions.ExceptionUtil; import cn.iocoder.yudao.framework.ai.chat.ChatResponse; +import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.config.AiClient; +import cn.iocoder.yudao.framework.common.exception.ServerException; +import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; +import cn.iocoder.yudao.module.ai.ErrorCodeConstants; +import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO; +import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO; +import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO; import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum; +import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum; +import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum; +import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper; +import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper; +import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper; import cn.iocoder.yudao.module.ai.service.ChatService; import cn.iocoder.yudao.module.ai.vo.ChatReq; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; /** @@ -24,6 +38,10 @@ import reactor.core.publisher.Flux; public class ChatServiceImpl implements ChatService { private final AiClient aiClient; + private final AiChatRoleMapper aiChatRoleMapper; + private final AiChatMessageMapper aiChatMessageMapper; + private final AiChatConversationMapper aiChatConversationMapper; + /** * chat @@ -31,16 +49,84 @@ public class ChatServiceImpl implements ChatService { * @param req * @return */ + @Transactional(rollbackFor = Exception.class) public String chat(ChatReq req) { + // 获取 client 类型 AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); - // 创建 chat 需要的 Prompt - Prompt prompt = new Prompt(req.getPrompt()); - req.setTopK(req.getTopK()); - req.setTopP(req.getTopP()); - req.setTemperature(req.getTemperature()); - // 发送 call 调用 - ChatResponse call = aiClient.call(prompt, clientNameEnum.getName()); - return call.getResult().getOutput().getContent(); + // 获取 对话类型(新建还是继续) + ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); + + AiChatConversationDO aiChatConversationDO; + Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); + if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) { + // 创建一个新的对话 + aiChatConversationDO = createNewChatConversation(req, loginUserId); + } else { + // 继续对话 + if (req.getConversationId() == null) { + throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL); + } + aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId()); + } + + String content; + try { + // 创建 chat 需要的 Prompt + Prompt prompt = new Prompt(req.getPrompt()); + req.setTopK(req.getTopK()); + req.setTopP(req.getTopP()); + req.setTemperature(req.getTemperature()); + // 发送 call 调用 + ChatResponse call = aiClient.call(prompt, clientNameEnum.getName()); + content = call.getResult().getOutput().getContent(); + } catch (Exception e) { + content = ExceptionUtil.getMessage(e); + } + + // 增加 chat message 记录 + aiChatMessageMapper.insert( + new AiChatMessageDO() + .setId(null) + .setChatConversationId(aiChatConversationDO.getId()) + .setUserId(loginUserId) + .setMessage(req.getPrompt()) + .setMessageType(MessageType.USER.getValue()) + .setTopK(req.getTopK()) + .setTopP(req.getTopP()) + .setTemperature(req.getTemperature()) + ); + + // chat count 先+1 + aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); + return content; + } + + private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) { + // 获取 chat 角色 + String chatRoleName = null; + ChatTypeEnum chatTypeEnum = null; + Long chatRoleId = req.getChatRoleId(); + if (req.getChatRoleId() != null) { + AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId); + if (aiChatRoleDO == null) { + throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT); + } + chatTypeEnum = ChatTypeEnum.ROLE_CHAT; + chatRoleName = aiChatRoleDO.getRoleName(); + } else { + chatTypeEnum = ChatTypeEnum.USER_CHAT; + } + // + AiChatConversationDO insertChatConversation = new AiChatConversationDO() + .setId(null) + .setUserId(loginUserId) + .setChatRoleId(req.getChatRoleId()) + .setChatRoleName(chatRoleName) + .setChatType(chatTypeEnum.getType()) + .setChatCount(1) + .setChatTitle(req.getPrompt().substring(0, 20) + "..."); + aiChatConversationMapper.insert(insertChatConversation); + return insertChatConversation; } /** diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatReq.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatReq.java index 754cb0572..321d429c5 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatReq.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatReq.java @@ -24,19 +24,29 @@ public class ChatReq { @Schema(description = "填入固定值,1 issues, 2 pr") private String prompt; + @Schema(description = "chat角色模板") + private Long chatRoleId; + @Schema(description = "用于控制随机性和多样性的温度参数") - private Float temperature; + private Double temperature; @Schema(description = "生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,\n" + " * 作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。\n" + " * 默认值为0.8。注意,取值不要大于等于1\n") - private Float topP; + private Double topP; @Schema(description = "在生成消息时采用的Top-K采样大小,表示模型生成回复时考虑的候选项集合的大小") - private Integer topK; + private Double topK; @Schema(description = "ai模型(查看 AiClientNameEnum)") @NotNull(message = "模型不能为空!") @Size(max = 30, message = "模型字符最大30个字符!") private String modal; + + @Schema(description = "对话类型(new、continue)") + @NotNull(message = "对话类型,不能为空!") + private String conversationType; + + @Schema(description = "对话Id") + private Long conversationId; } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chat/messages/AbstractMessage.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chat/messages/AbstractMessage.java index d622cf9f8..8c8e60d5a 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chat/messages/AbstractMessage.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chat/messages/AbstractMessage.java @@ -59,7 +59,7 @@ public abstract class AbstractMessage implements Message { } protected AbstractMessage(MessageType messageType, String textContent, List mediaData, - Map messageProperties) { + Map messageProperties) { Assert.notNull(messageType, "Message type must not be null"); Assert.notNull(textContent, "Content must not be null");