mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-12-02 04:01:52 +08:00
【新增】AI:流式发送消息的微调,统一成单接口
This commit is contained in:
parent
20657ccaf3
commit
b31e919d52
@ -10,13 +10,13 @@ Authorization: {{token}}
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
### chat call
|
### 发送消息(流式)
|
||||||
POST {{baseUrl}}/admin-api/ai/chat/message/send-stream
|
POST {{baseUrl}}/ai/chat/message/send-stream
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
Authorization: {{token}}
|
Authorization: {{token}}
|
||||||
|
|
||||||
{
|
{
|
||||||
"conversationId": "1781604279872581649",
|
"conversationId": "1781604279872581651",
|
||||||
"content": "苹果是什么颜色?"
|
"content": "苹果是什么颜色?"
|
||||||
}
|
}
|
||||||
|
|
@ -17,6 +17,7 @@ import reactor.core.publisher.Flux;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||||
|
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
|
||||||
|
|
||||||
@Tag(name = "管理后台 - 聊天消息")
|
@Tag(name = "管理后台 - 聊天消息")
|
||||||
@RestController
|
@RestController
|
||||||
@ -36,14 +37,8 @@ public class AiChatMessageController {
|
|||||||
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
|
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
|
||||||
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
@PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
|
@PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
|
||||||
public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) {
|
public Flux<AiChatMessageSendRespVO> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
||||||
return chatService.chatStream(sendReqVO);
|
return chatService.sendChatMessageStream(sendReqVO, getLoginUserId());
|
||||||
}
|
|
||||||
|
|
||||||
@Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染")
|
|
||||||
@PostMapping(value = "/add")
|
|
||||||
public CommonResult<AiChatMessageRespVO> add(@Validated @RequestBody AiChatMessageAddReqVO req) {
|
|
||||||
return success(chatService.add(req));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Operation(summary = "获得指定会话的消息列表")
|
@Operation(summary = "获得指定会话的消息列表")
|
||||||
|
@ -44,4 +44,5 @@ public class AiChatMessageRespVO {
|
|||||||
|
|
||||||
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
|
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
|
||||||
private LocalDateTime createTime;
|
private LocalDateTime createTime;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
|
||||||
|
@Schema(description = "管理后台 - AI 聊天消息发送 Response VO")
|
||||||
|
@Data
|
||||||
|
public class AiChatMessageSendRespVO {
|
||||||
|
|
||||||
|
@Schema(description = "发送消息", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
|
private Message send;
|
||||||
|
|
||||||
|
@Schema(description = "接收消息", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
|
private Message receive;
|
||||||
|
|
||||||
|
@Schema(description = "消息")
|
||||||
|
@Data
|
||||||
|
public static class Message {
|
||||||
|
|
||||||
|
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
|
||||||
|
private Long id;
|
||||||
|
|
||||||
|
@Schema(description = "消息类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "role")
|
||||||
|
private String type; // 参见 MessageType 枚举类
|
||||||
|
|
||||||
|
@Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊")
|
||||||
|
private String content;
|
||||||
|
|
||||||
|
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
|
||||||
|
private LocalDateTime createTime;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,16 +0,0 @@
|
|||||||
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
|
|
||||||
|
|
||||||
import io.swagger.v3.oas.annotations.media.Schema;
|
|
||||||
import jakarta.validation.constraints.NotEmpty;
|
|
||||||
import jakarta.validation.constraints.NotNull;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
|
|
||||||
@Data
|
|
||||||
public class AiChatMessageSendStreamReqVO {
|
|
||||||
|
|
||||||
@Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
|
|
||||||
@NotNull(message = "提问的 messageId 不能为空")
|
|
||||||
private Long id;
|
|
||||||
|
|
||||||
}
|
|
@ -27,12 +27,4 @@ public interface AiChatMessageConvert {
|
|||||||
*/
|
*/
|
||||||
List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
|
List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
|
||||||
|
|
||||||
/**
|
|
||||||
* 转换 - aiChatMessageDO
|
|
||||||
*
|
|
||||||
* @param aiChatMessageDO
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -22,22 +22,6 @@ public interface AiChatService {
|
|||||||
*/
|
*/
|
||||||
AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
|
AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
|
||||||
|
|
||||||
/**
|
|
||||||
* chat stream
|
|
||||||
*
|
|
||||||
* @param sendReqVO
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO sendReqVO);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 添加 - message
|
|
||||||
*
|
|
||||||
* @param sendReqVO
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取 - 获取对话 message list
|
* 获取 - 获取对话 message list
|
||||||
*
|
*
|
||||||
@ -54,4 +38,13 @@ public interface AiChatService {
|
|||||||
*/
|
*/
|
||||||
Boolean deleteMessage(Long id);
|
Boolean deleteMessage(Long id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 发送消息
|
||||||
|
*
|
||||||
|
* @param sendReqVO
|
||||||
|
* @param userId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,27 +1,24 @@
|
|||||||
package cn.iocoder.yudao.module.ai.service.impl;
|
package cn.iocoder.yudao.module.ai.service.impl;
|
||||||
|
|
||||||
import cn.hutool.core.exceptions.ExceptionUtil;
|
import cn.hutool.core.exceptions.ExceptionUtil;
|
||||||
|
import cn.hutool.core.util.ObjUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
||||||
import org.springframework.ai.chat.ChatClient;
|
import org.springframework.ai.chat.ChatClient;
|
||||||
import org.springframework.ai.chat.ChatResponse;
|
import org.springframework.ai.chat.ChatResponse;
|
||||||
import org.springframework.ai.chat.StreamingChatClient;
|
import org.springframework.ai.chat.StreamingChatClient;
|
||||||
import org.springframework.ai.chat.messages.MessageType;
|
import org.springframework.ai.chat.messages.MessageType;
|
||||||
|
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
|
||||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
|
||||||
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO;
|
|
||||||
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
|
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
|
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
||||||
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
|
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
||||||
@ -33,13 +30,16 @@ import org.springframework.stereotype.Service;
|
|||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.function.Consumer;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||||
|
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 聊天 service
|
* 聊天 service
|
||||||
*
|
*
|
||||||
@ -52,11 +52,11 @@ import java.util.stream.Collectors;
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class AiChatServiceImpl implements AiChatService {
|
public class AiChatServiceImpl implements AiChatService {
|
||||||
|
|
||||||
private final AiChatClientFactory aiChatClientFactory;
|
private final AiChatClientFactory chatClientFactory;
|
||||||
|
|
||||||
private final AiChatMessageMapper aiChatMessageMapper;
|
private final AiChatMessageMapper aiChatMessageMapper;
|
||||||
private final AiChatConversationService chatConversationService;
|
private final AiChatConversationService chatConversationService;
|
||||||
private final AiChatModelService aiChatModalService;
|
private final AiChatModelService chatModalService;
|
||||||
private final AiChatRoleService chatRoleService;
|
private final AiChatRoleService chatRoleService;
|
||||||
|
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
@ -65,7 +65,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
// 查询对话
|
// 查询对话
|
||||||
AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
|
AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
|
||||||
// 获取对话模型
|
// 获取对话模型
|
||||||
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
|
||||||
// 获取角色信息
|
// 获取角色信息
|
||||||
AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
|
AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
|
||||||
// 获取 client 类型
|
// 获取 client 类型
|
||||||
@ -84,7 +84,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
// req.setTopP(req.getTopP());
|
// req.setTopP(req.getTopP());
|
||||||
// req.setTemperature(req.getTemperature());
|
// req.setTemperature(req.getTemperature());
|
||||||
// 发送 call 调用
|
// 发送 call 调用
|
||||||
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
|
ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
|
||||||
ChatResponse call = chatClient.call(prompt);
|
ChatResponse call = chatClient.call(prompt);
|
||||||
content = call.getResult().getOutput().getContent();
|
content = call.getResult().getOutput().getContent();
|
||||||
tokens = call.getResults().size();
|
tokens = call.getResults().size();
|
||||||
@ -113,88 +113,72 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
.setModelId(modelId)
|
.setModelId(modelId)
|
||||||
.setContent(content)
|
.setContent(content)
|
||||||
.setTokens(tokens)
|
.setTokens(tokens)
|
||||||
|
|
||||||
.setTemperature(temperature)
|
.setTemperature(temperature)
|
||||||
.setMaxTokens(maxTokens)
|
.setMaxTokens(maxTokens)
|
||||||
.setMaxContexts(maxContexts);
|
.setMaxContexts(maxContexts);
|
||||||
|
insertChatMessageDO.setCreateTime(LocalDateTime.now());
|
||||||
// 增加 chat message 记录
|
// 增加 chat message 记录
|
||||||
aiChatMessageMapper.insert(insertChatMessageDO);
|
aiChatMessageMapper.insert(insertChatMessageDO);
|
||||||
return insertChatMessageDO;
|
return insertChatMessageDO;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO req) {
|
|
||||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
|
||||||
// 查询提问的 message
|
|
||||||
AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId());
|
|
||||||
if (aiChatMessageDO == null) {
|
|
||||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST);
|
|
||||||
}
|
|
||||||
// 查询对话
|
|
||||||
AiChatConversationDO conversation = chatConversationService.validateExists(aiChatMessageDO.getConversationId());
|
|
||||||
// 获取对话模型
|
|
||||||
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
||||||
// 获取角色信息
|
|
||||||
AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
|
|
||||||
// 创建 chat 需要的 Prompt
|
|
||||||
Prompt prompt = new Prompt(aiChatMessageDO.getContent());
|
|
||||||
// 提前创建一个 system message
|
|
||||||
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
||||||
chatModel.getModel(), chatModel.getId(), "",
|
|
||||||
0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
||||||
// req.setTopK(req.getTopK());
|
|
||||||
// req.setTopP(req.getTopP());
|
|
||||||
// req.setTemperature(req.getTemperature());
|
|
||||||
// 获取 client 类型
|
|
||||||
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
|
||||||
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
|
||||||
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
|
||||||
// 转换 flex AiChatMessageRespVO
|
|
||||||
StringBuffer contentBuffer = new StringBuffer();
|
|
||||||
AtomicInteger tokens = new AtomicInteger(0);
|
|
||||||
return streamResponse.map(res -> {
|
|
||||||
AiChatMessageRespVO aiChatMessageRespVO =
|
|
||||||
AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage);
|
|
||||||
aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
|
|
||||||
contentBuffer.append(res.getResult().getOutput().getContent());
|
|
||||||
tokens.incrementAndGet();
|
|
||||||
return aiChatMessageRespVO;
|
|
||||||
}
|
|
||||||
).doOnComplete(new Runnable() {
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
log.info("发送完成!");
|
|
||||||
// 保存 chat message
|
|
||||||
aiChatMessageMapper.updateById(new AiChatMessageDO()
|
|
||||||
.setId(systemMessage.getId())
|
|
||||||
.setContent(contentBuffer.toString())
|
|
||||||
.setTokens(tokens.get())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}).doOnError(new Consumer<Throwable>() {
|
|
||||||
@Override
|
|
||||||
public void accept(Throwable throwable) {
|
|
||||||
log.error("发送错误 {}!", throwable.getMessage());
|
|
||||||
// 更新错误信息
|
|
||||||
aiChatMessageMapper.updateById(new AiChatMessageDO()
|
|
||||||
.setId(systemMessage.getId())
|
|
||||||
.setContent(throwable.getMessage())
|
|
||||||
.setTokens(tokens.get())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AiChatMessageRespVO add(AiChatMessageAddReqVO req) {
|
public Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
||||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
// 1.1 校验对话存在
|
||||||
// 查询对话
|
AiChatConversationDO conversation = chatConversationService.validateExists(sendReqVO.getConversationId());
|
||||||
AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
|
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
|
||||||
// 获取对话模型
|
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
|
||||||
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
}
|
||||||
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
// 1.2 校验模型
|
||||||
chatModel.getModel(), chatModel.getId(), req.getContent(),
|
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
|
||||||
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
|
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
|
||||||
|
|
||||||
|
// 2. 插入 user 发送消息 TODO tokens 计算
|
||||||
|
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(),
|
||||||
|
conversation.getModel(), conversation.getId(), sendReqVO.getContent(),
|
||||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||||
return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage);
|
|
||||||
|
// 3.1 插入 system 接收消息
|
||||||
|
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(),
|
||||||
|
conversation.getModel(), conversation.getId(), conversation.getSystemMessage(),
|
||||||
|
0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||||
|
// 3.2 创建 chat 需要的 Prompt
|
||||||
|
// TODO 消息上下文
|
||||||
|
Prompt prompt = new Prompt(sendReqVO.getContent());
|
||||||
|
// ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build()
|
||||||
|
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
||||||
|
// 3.3 转换 flex AiChatMessageRespVO
|
||||||
|
StringBuffer contentBuffer = new StringBuffer();
|
||||||
|
AtomicInteger tokens = new AtomicInteger(0); // TODO token 计算不对;
|
||||||
|
return streamResponse.map(res -> {
|
||||||
|
contentBuffer.append(res.getResult().getOutput().getContent());
|
||||||
|
tokens.incrementAndGet();
|
||||||
|
|
||||||
|
AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId())
|
||||||
|
.setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime())
|
||||||
|
.setContent(sendReqVO.getContent());
|
||||||
|
AiChatMessageSendRespVO.Message receive = new AiChatMessageSendRespVO.Message().setId(systemMessage.getId())
|
||||||
|
.setType(MessageType.SYSTEM.getValue()).setCreateTime(systemMessage.getCreateTime())
|
||||||
|
.setContent(res.getResult().getOutput().getContent());
|
||||||
|
return new AiChatMessageSendRespVO().setSend(send).setReceive(receive);
|
||||||
|
}).doOnComplete(() -> {
|
||||||
|
log.info("发送完成!");
|
||||||
|
// 保存 chat message
|
||||||
|
aiChatMessageMapper.updateById(new AiChatMessageDO()
|
||||||
|
.setId(systemMessage.getId())
|
||||||
|
.setContent(contentBuffer.toString())
|
||||||
|
.setTokens(tokens.get())
|
||||||
|
);
|
||||||
|
}).doOnError(throwable -> {
|
||||||
|
log.error("发送错误 {}!", throwable.getMessage());
|
||||||
|
// 更新错误信息 TODO 貌似不应该更新异常
|
||||||
|
aiChatMessageMapper.updateById(new AiChatMessageDO()
|
||||||
|
.setId(systemMessage.getId())
|
||||||
|
.setContent(throwable.getMessage())
|
||||||
|
.setTokens(tokens.get())
|
||||||
|
);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -205,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|||||||
List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId);
|
List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId);
|
||||||
// 获取模型信息
|
// 获取模型信息
|
||||||
Set<Long> modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet());
|
Set<Long> modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet());
|
||||||
List<AiChatModelDO> modalList = aiChatModalService.getModalByIds(modalIds);
|
List<AiChatModelDO> modalList = chatModalService.getModalByIds(modalIds);
|
||||||
Map<Long, AiChatModelDO> modalIdMap = modalList.stream().collect(Collectors.toMap(AiChatModelDO::getId, o -> o));
|
Map<Long, AiChatModelDO> modalIdMap = modalList.stream().collect(Collectors.toMap(AiChatModelDO::getId, o -> o));
|
||||||
// 转换 AiChatMessageRespVO
|
// 转换 AiChatMessageRespVO
|
||||||
List<AiChatMessageRespVO> aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);
|
List<AiChatMessageRespVO> aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);
|
||||||
|
@ -94,7 +94,10 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
|
|||||||
String a = ";";
|
String a = ";";
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
return response.map(res -> new ChatResponse(List.of(new Generation(res.getResult()))));
|
return response.map(res -> {
|
||||||
|
// TODO @fan:这里缺少了 usage 的封装
|
||||||
|
return new ChatResponse(List.of(new Generation(res.getResult())));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
|
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
|
||||||
|
Loading…
Reference in New Issue
Block a user