【优化】AI:简化 AiChatMessageDO 消息表,去除 tokens、temperature、maxTokens、maxContexts 字段,因为 spring-ai 没有返回 tokens 相关的字段

This commit is contained in:
YunaiV 2024-05-17 22:25:50 +08:00
parent 9de9e938bf
commit 275d1fb627
5 changed files with 9 additions and 98 deletions

View File

@ -1,20 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
@Data
public class AiChatMessageAddReqVO {
@Schema(description = "聊天对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
@NotNull(message = "聊天对话编号不能为空")
private Long conversationId;
@Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "帮我写个 Java 算法")
@NotEmpty(message = "聊天内容不能为空")
private String content;
}

View File

@ -1,17 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
@Schema(description = "管理后台 - AI 聊天消息 Add Response VO")
@Data
public class AiChatMessageAddRespVO {
@Schema(description = "用户信息")
private AiChatMessageRespVO userMessage;
@Schema(description = "系统信息")
private AiChatMessageRespVO systemMessage;
}

View File

@ -21,9 +21,6 @@ public class AiChatMessageRespVO {
@Schema(description = "用户编号", example = "4096")
private Long userId; // 仅当 user 发送时非空
@Schema(description = "用户头像", example = "http://xxx")
private Long avatarUrl; // 仅当 user 发送时非空
@Schema(description = "角色编号", example = "888")
private Long roleId; // 仅当 assistant 回复时非空
@ -33,15 +30,9 @@ public class AiChatMessageRespVO {
@Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "123")
private Long modelId;
@Schema(description = "模型图片", requiredMode = Schema.RequiredMode.REQUIRED, example = "123")
private String modelImage;
@Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊")
private String content;
@Schema(description = "消耗 Token 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "80")
private Integer tokens;
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
private LocalDateTime createTime;

View File

@ -78,32 +78,7 @@ public class AiChatMessageDO extends BaseDO {
* 聊天内容
*/
private String content;
/**
* 消耗 Token 数量
*/
private Integer tokens;
// TODO 芋艿是否作为上下文语料use_context待定
// ========== 会话配置 ==========
/**
* 温度参数
*
* 冗余 {@link AiChatConversationDO#getTemperature()}
*/
private Double temperature;
/**
* 单条回复的最大 Token 数量
*
* 冗余 {@link AiChatConversationDO#getMaxTokens()}
*/
private Integer maxTokens;
/**
* 上下文的最大 Message 数量
*
* 冗余 {@link AiChatConversationDO#getMaxContexts()}
*/
private Integer maxContexts;
}

View File

@ -9,7 +9,6 @@ import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
@ -72,8 +71,7 @@ public class AiChatServiceImpl implements AiChatService {
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
chatModel.getModel(), chatModel.getId(), req.getContent());
String content = null;
int tokens = 0;
try {
@ -94,28 +92,21 @@ public class AiChatServiceImpl implements AiChatService {
} finally {
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModel.getModel(), chatModel.getId(), content,
tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
chatModel.getModel(), chatModel.getId(), content);
}
return new AiChatMessageRespVO().setContent(content);
}
private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
String model, Long modelId, String content, Integer tokens, Double temperature,
Integer maxTokens, Integer maxContexts) {
String model, Long modelId, String content) {
AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
.setId(null)
.setConversationId(conversationId)
.setType(messageType.getValue())
.setUserId(loginUserId)
.setRoleId(roleId)
.setModel(model)
.setModelId(modelId)
.setContent(content)
.setTokens(tokens)
.setTemperature(temperature)
.setMaxTokens(maxTokens)
.setMaxContexts(maxContexts);
.setContent(content);
insertChatMessageDO.setCreateTime(LocalDateTime.now());
// 增加 chat message 记录
aiChatMessageMapper.insert(insertChatMessageDO);
@ -134,15 +125,13 @@ public class AiChatServiceImpl implements AiChatService {
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
// 2. 插入 user 发送消息 TODO tokens 计算
// 2. 插入 user 发送消息
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(),
conversation.getModel(), conversation.getId(), sendReqVO.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
conversation.getModel(), conversation.getId(), sendReqVO.getContent());
// 3.1 插入 system 接收消息
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(),
conversation.getModel(), conversation.getId(), conversation.getSystemMessage(),
0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
conversation.getModel(), conversation.getId(), conversation.getSystemMessage());
// 3.2 创建 chat 需要的 Prompt
// TODO 消息上下文
Prompt prompt = new Prompt(sendReqVO.getContent());
@ -150,11 +139,8 @@ public class AiChatServiceImpl implements AiChatService {
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 3.3 转换 flex AiChatMessageRespVO
StringBuffer contentBuffer = new StringBuffer();
AtomicInteger tokens = new AtomicInteger(0); // TODO token 计算不对
return streamResponse.map(res -> {
contentBuffer.append(res.getResult().getOutput().getContent());
tokens.incrementAndGet();
AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId())
.setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime())
.setContent(sendReqVO.getContent());
@ -167,17 +153,13 @@ public class AiChatServiceImpl implements AiChatService {
// 保存 chat message
aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId())
.setContent(contentBuffer.toString())
.setTokens(tokens.get())
);
.setContent(contentBuffer.toString()));
}).doOnError(throwable -> {
log.error("发送错误 {}!", throwable.getMessage());
// 更新错误信息 TODO 貌似不应该更新异常
aiChatMessageMapper.updateById(new AiChatMessageDO()
.setId(systemMessage.getId())
.setContent(throwable.getMessage())
.setTokens(tokens.get())
);
.setContent(throwable.getMessage()));
});
}