【优化】AI:简化 AiChatMessageDO 消息表,去除 tokens、temperature、maxTokens、maxContexts 字段,因为 spring-ai 没有返回 tokens 相关的字段

This commit is contained in:
YunaiV 2024-05-17 22:25:50 +08:00
parent 9de9e938bf
commit 275d1fb627
5 changed files with 9 additions and 98 deletions

View File

@ -1,20 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
/**
 * Request VO carrying the data needed to send an AI chat message from the admin console.
 *
 * <p>Lombok's {@code @Data} generates the getters, setters, {@code equals},
 * {@code hashCode} and {@code toString} for the fields declared below.
 */
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
@Data
public class AiChatMessageAddReqVO {
/** ID of the chat conversation the message is sent to; required, bean validation rejects {@code null}. */
@Schema(description = "聊天对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
@NotNull(message = "聊天对话编号不能为空")
private Long conversationId;
/**
 * Text content of the chat message; required, bean validation rejects {@code null} and the empty string.
 * NOTE(review): {@code @NotEmpty} still accepts whitespace-only text; confirm whether {@code @NotBlank} is intended.
 */
@Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "帮我写个 Java 算法")
@NotEmpty(message = "聊天内容不能为空")
private String content;
}

View File

@ -1,17 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
/**
 * Response VO returned after adding an AI chat message in the admin console.
 *
 * <p>Bundles two {@link AiChatMessageRespVO} entries. Lombok's {@code @Data} generates
 * the getters, setters, {@code equals}, {@code hashCode} and {@code toString}.
 */
@Schema(description = "管理后台 - AI 聊天消息 Add Response VO")
@Data
public class AiChatMessageAddRespVO {
/** Message described by the schema as "user information"; presumably the message the user sent (confirm against caller). */
@Schema(description = "用户信息")
private AiChatMessageRespVO userMessage;
/**
 * Message described by the schema as "system information".
 * NOTE(review): presumably the reply paired with {@link #userMessage}; confirm against the code that populates it.
 */
@Schema(description = "系统信息")
private AiChatMessageRespVO systemMessage;
}

View File

@ -21,9 +21,6 @@ public class AiChatMessageRespVO {
@Schema(description = "用户编号", example = "4096") @Schema(description = "用户编号", example = "4096")
private Long userId; // 仅当 user 发送时非空 private Long userId; // 仅当 user 发送时非空
@Schema(description = "用户头像", example = "http://xxx")
private Long avatarUrl; // 仅当 user 发送时非空
@Schema(description = "角色编号", example = "888") @Schema(description = "角色编号", example = "888")
private Long roleId; // 仅当 assistant 回复时非空 private Long roleId; // 仅当 assistant 回复时非空
@ -33,15 +30,9 @@ public class AiChatMessageRespVO {
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "123") @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "123")
private Long modelId; private Long modelId;
@Schema(description = "模型图片", requiredMode = Schema.RequiredMode.REQUIRED, example = "123")
private String modelImage;
@Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊") @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊")
private String content; private String content;
@Schema(description = "消耗 Token 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "80")
private Integer tokens;
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
private LocalDateTime createTime; private LocalDateTime createTime;

View File

@ -78,32 +78,7 @@ public class AiChatMessageDO extends BaseDO {
* 聊天内容 * 聊天内容
*/ */
private String content; private String content;
/**
* 消耗 Token 数量
*/
private Integer tokens;
// TODO 芋艿是否作为上下文语料use_context待定 // TODO 芋艿是否作为上下文语料use_context待定
// ========== 会话配置 ==========
/**
* 温度参数
*
* 冗余 {@link AiChatConversationDO#getTemperature()}
*/
private Double temperature;
/**
* 单条回复的最大 Token 数量
*
* 冗余 {@link AiChatConversationDO#getMaxTokens()}
*/
private Integer maxTokens;
/**
* 上下文的最大 Message 数量
*
* 冗余 {@link AiChatConversationDO#getMaxContexts()}
*/
private Integer maxContexts;
} }

View File

@ -9,7 +9,6 @@ import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
@ -72,8 +71,7 @@ public class AiChatServiceImpl implements AiChatService {
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
// 保存 chat message // 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), req.getContent(), chatModel.getModel(), chatModel.getId(), req.getContent());
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
String content = null; String content = null;
int tokens = 0; int tokens = 0;
try { try {
@ -94,28 +92,21 @@ public class AiChatServiceImpl implements AiChatService {
} finally { } finally {
// 保存 chat message // 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), content, chatModel.getModel(), chatModel.getId(), content);
tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
} }
return new AiChatMessageRespVO().setContent(content); return new AiChatMessageRespVO().setContent(content);
} }
private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId, private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
String model, Long modelId, String content, Integer tokens, Double temperature, String model, Long modelId, String content) {
Integer maxTokens, Integer maxContexts) {
AiChatMessageDO insertChatMessageDO = new AiChatMessageDO() AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
.setId(null)
.setConversationId(conversationId) .setConversationId(conversationId)
.setType(messageType.getValue()) .setType(messageType.getValue())
.setUserId(loginUserId) .setUserId(loginUserId)
.setRoleId(roleId) .setRoleId(roleId)
.setModel(model) .setModel(model)
.setModelId(modelId) .setModelId(modelId)
.setContent(content) .setContent(content);
.setTokens(tokens)
.setTemperature(temperature)
.setMaxTokens(maxTokens)
.setMaxContexts(maxContexts);
insertChatMessageDO.setCreateTime(LocalDateTime.now()); insertChatMessageDO.setCreateTime(LocalDateTime.now());
// 增加 chat message 记录 // 增加 chat message 记录
aiChatMessageMapper.insert(insertChatMessageDO); aiChatMessageMapper.insert(insertChatMessageDO);
@ -134,15 +125,13 @@ public class AiChatServiceImpl implements AiChatService {
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform); StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
// 2. 插入 user 发送消息 TODO tokens 计算 // 2. 插入 user 发送消息
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(), AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(),
conversation.getModel(), conversation.getId(), sendReqVO.getContent(), conversation.getModel(), conversation.getId(), sendReqVO.getContent());
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// 3.1 插入 system 接收消息 // 3.1 插入 system 接收消息
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(), AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(),
conversation.getModel(), conversation.getId(), conversation.getSystemMessage(), conversation.getModel(), conversation.getId(), conversation.getSystemMessage());
0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// 3.2 创建 chat 需要的 Prompt // 3.2 创建 chat 需要的 Prompt
// TODO 消息上下文 // TODO 消息上下文
Prompt prompt = new Prompt(sendReqVO.getContent()); Prompt prompt = new Prompt(sendReqVO.getContent());
@ -150,11 +139,8 @@ public class AiChatServiceImpl implements AiChatService {
Flux<ChatResponse> streamResponse = chatClient.stream(prompt); Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 3.3 转换 flex AiChatMessageRespVO // 3.3 转换 flex AiChatMessageRespVO
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
AtomicInteger tokens = new AtomicInteger(0); // TODO token 计算不对
return streamResponse.map(res -> { return streamResponse.map(res -> {
contentBuffer.append(res.getResult().getOutput().getContent()); contentBuffer.append(res.getResult().getOutput().getContent());
tokens.incrementAndGet();
AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId()) AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId())
.setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime()) .setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime())
.setContent(sendReqVO.getContent()); .setContent(sendReqVO.getContent());
@ -167,17 +153,13 @@ public class AiChatServiceImpl implements AiChatService {
// 保存 chat message // 保存 chat message
aiChatMessageMapper.updateById(new AiChatMessageDO() aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId()) .setId(systemMessage.getId())
.setContent(contentBuffer.toString()) .setContent(contentBuffer.toString()));
.setTokens(tokens.get())
);
}).doOnError(throwable -> { }).doOnError(throwable -> {
log.error("发送错误 {}!", throwable.getMessage()); log.error("发送错误 {}!", throwable.getMessage());
// 更新错误信息 TODO 貌似不应该更新异常 // 更新错误信息 TODO 貌似不应该更新异常
aiChatMessageMapper.updateById(new AiChatMessageDO() aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId()) .setId(systemMessage.getId())
.setContent(throwable.getMessage()) .setContent(throwable.getMessage()));
.setTokens(tokens.get())
);
}); });
} }