diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/Utf8SseEmitter.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/Utf8SseEmitter.java deleted file mode 100644 index d23f4e9b1..000000000 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/Utf8SseEmitter.java +++ /dev/null @@ -1,26 +0,0 @@ -package cn.iocoder.yudao.module.ai.controller; - -import org.springframework.http.HttpHeaders; -import org.springframework.http.MediaType; -import org.springframework.http.server.ServerHttpResponse; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; - -import java.nio.charset.StandardCharsets; - -/** - * 解决中文乱码 - * - * @author fansili - * @time 2024/4/14 15:13 - * @since 1.0 - */ -public class Utf8SseEmitter extends SseEmitter { - - @Override - protected void extendResponse(ServerHttpResponse outputMessage) { - super.extendResponse(outputMessage); - - HttpHeaders headers = outputMessage.getHeaders(); - headers.setContentType(new MediaType(MediaType.TEXT_EVENT_STREAM, StandardCharsets.UTF_8)); - } -} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java index a36f871a6..82392ed27 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java @@ -1,10 +1,9 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat; import cn.iocoder.yudao.framework.common.pojo.CommonResult; -import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; -import cn.iocoder.yudao.module.ai.service.AiChatService; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; +import cn.iocoder.yudao.module.ai.service.AiChatService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.tags.Tag; @@ -13,7 +12,7 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.*; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import reactor.core.publisher.Flux; import java.util.List; @@ -39,10 +38,8 @@ public class AiChatMessageController { // TODO @fan:要不要使用 Flux 来返回;可以使用 Flux @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public SseEmitter sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { - Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); - chatService.chatStream(sendReqVO, sseEmitter); - return sseEmitter; + public Flux sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { + return chatService.chatStream(sendReqVO); } @Operation(summary = "获得指定会话的消息列表") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java index 43e05dc83..274391777 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java @@ -1,10 +1,9 @@ package cn.iocoder.yudao.module.ai.controller.admin.image; import cn.iocoder.yudao.framework.common.pojo.CommonResult; -import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; -import cn.iocoder.yudao.module.ai.service.AiImageService; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; +import cn.iocoder.yudao.module.ai.service.AiImageService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import lombok.AllArgsConstructor; @@ -14,7 +13,6 @@ import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; // TODO @芋艿:整理接口定义 /** @@ -35,10 +33,11 @@ public class AiImageController { @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") @PostMapping("/dallDrawing") - public SseEmitter dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) { - Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); - aiImageService.dallDrawing(req, sseEmitter); - return sseEmitter; + public void dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) { +// Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); +// aiImageService.dallDrawing(req, sseEmitter); +// return sseEmitter; + } @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java index 27ece6b14..a5e97ce5f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java @@ -1,8 +1,8 @@ package cn.iocoder.yudao.module.ai.service; -import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; +import reactor.core.publisher.Flux; import java.util.List; @@ -26,11 +26,10 @@ public interface AiChatService { /** * chat stream * - * @param req - * @param sseEmitter + * @param sendReqVO * @return */ - void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter); + Flux chatStream(AiChatMessageSendReqVO sendReqVO); /** * 获取 - 获取对话 message list diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java index 00b5ded44..cf95483d2 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java @@ -1,6 +1,5 @@ package cn.iocoder.yudao.module.ai.service; -import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; @@ -17,9 +16,8 @@ public interface AiImageService { * ai绘画 - dall2/dall3 绘画 * * @param req - * @param sseEmitter */ - void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter); + void dallDrawing(AiImageDallDrawingReq req); /** * midjourney 图片生成 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java index ba1b4b679..90649f417 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java @@ -9,7 +9,6 @@ import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; -import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; @@ -25,13 +24,12 @@ import cn.iocoder.yudao.module.ai.service.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.AiChatService; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.http.MediaType; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; -import java.io.IOException; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; /** @@ -76,6 +74,7 @@ public class AiChatServiceImpl implements AiChatService { chatModal.getModel(), chatModal.getId(), req.getContent(), null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); String content = null; + int tokens = 0; try { // 创建 chat 需要的 Prompt Prompt prompt = new Prompt(req.getContent()); @@ -87,6 +86,7 @@ public class AiChatServiceImpl implements AiChatService { ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum); ChatResponse call = chatClient.call(prompt); content = call.getResult().getOutput().getContent(); + tokens = call.getResults().size(); // 更新 conversation } catch (Exception e) { content = ExceptionUtil.getMessage(e); @@ -94,7 +94,7 @@ public class AiChatServiceImpl implements AiChatService { // 保存 chat message insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), chatModal.getModel(), chatModal.getId(), content, - null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); + tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); } return new AiChatMessageRespVO().setContent(content); } @@ -123,8 +123,7 @@ public class AiChatServiceImpl implements AiChatService { return insertChatMessageDO; } - @Override - public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) { + public Flux chatStream(AiChatMessageSendReqVO req) { Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); // 查询对话 AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId()); @@ -144,47 +143,43 @@ public class AiChatServiceImpl implements AiChatService { // req.setTopK(req.getTopK()); // req.setTopP(req.getTopP()); // req.setTemperature(req.getTemperature()); - // 保存 chat message // 保存 chat message insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), chatModal.getModel(), chatModal.getId(), req.getContent(), null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); - // 获取 client 类型 AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform()); StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); Flux streamResponse = streamingChatClient.stream(prompt); - + // 转换 flex AiChatMessageRespVO StringBuffer contentBuffer = new StringBuffer(); - streamResponse.subscribe( - new Consumer() { - @Override - public void accept(ChatResponse chatResponse) { - String content = chatResponse.getResults().get(0).getOutput().getContent(); - try { - contentBuffer.append(content); - sseEmitter.send(new AiChatMessageRespVO().setContent(content), MediaType.APPLICATION_JSON); - } catch (IOException e) { - log.error("发送异常{}", ExceptionUtil.getMessage(e)); - // 如果不是因为关闭而抛出异常,则重新连接 - sseEmitter.completeWithError(e); - } - } - }, - error -> { - // - log.error("subscribe错误 {}", ExceptionUtil.getMessage(error)); - }, - () -> { - log.info("发送完成!"); - sseEmitter.complete(); - // 保存 chat message - insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), - chatModal.getModel(), chatModal.getId(), contentBuffer.toString(), - null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); - + AtomicInteger tokens = new AtomicInteger(0); + return streamResponse.map(res -> { + AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO(); + aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent()); + contentBuffer.append(res.getResult().getOutput().getContent()); + tokens.incrementAndGet(); + return aiChatMessageRespVO; } - ); + ).doOnComplete(new Runnable() { + @Override + public void run() { + log.info("发送完成!"); + // 保存 chat message + insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), + chatModal.getModel(), chatModal.getId(), contentBuffer.toString(), + tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); + } + }).doOnError(new Consumer() { + @Override + public void accept(Throwable throwable) { + log.error("发送错误 {}!", throwable.getMessage()); + // 保存 chat message + insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), + chatModal.getModel(), chatModal.getId(), throwable.getMessage(), + tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); + } + }); } @Override @@ -194,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService { // 获取对话所有 message List aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId); // 转换 AiChatMessageRespVO - return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); + return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); } @Override diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java index 16a414681..5bacf9da5 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java @@ -5,8 +5,8 @@ import cn.iocoder.yudao.framework.ai.image.ImageGeneration; import cn.iocoder.yudao.framework.ai.image.ImagePrompt; import cn.iocoder.yudao.framework.ai.image.ImageResponse; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient; -import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; +import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum; import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter; @@ -14,22 +14,18 @@ import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify; import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.module.ai.ErrorCodeConstants; -import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; -import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; -import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum; -import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper; -import cn.iocoder.yudao.module.ai.service.AiImageService; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; +import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; +import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper; +import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum; +import cn.iocoder.yudao.module.ai.service.AiImageService; import jakarta.annotation.PostConstruct; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.http.MediaType; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import java.io.IOException; - /** * ai 作图 * @@ -64,7 +60,7 @@ public class AiImageServiceImpl implements AiImageService { } @Override - public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) { + public void dallDrawing(AiImageDallDrawingReq req) { // 获取 model OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal()); OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle()); @@ -79,7 +75,7 @@ public class AiImageServiceImpl implements AiImageService { // 发送 ImageGeneration imageGeneration = imageResponse.getResult(); // 发送信息 - sendSseEmitter(sseEmitter, imageGeneration); +// sendSseEmitter(sseEmitter, imageGeneration); // 保存数据库 doSave(req.getPrompt(), req.getSize(), req.getModal(), imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null); @@ -88,7 +84,7 @@ public class AiImageServiceImpl implements AiImageService { doSave(req.getPrompt(), req.getSize(), req.getModal(), null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage()); // 发送错误信息 - sendSseEmitter(sseEmitter, aiException.getMessage()); +// sendSseEmitter(sseEmitter, aiException.getMessage()); } } @@ -105,16 +101,16 @@ public class AiImageServiceImpl implements AiImageService { } } - private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { - try { - sseEmitter.send(object, MediaType.APPLICATION_JSON); - } catch (IOException e) { - throw new RuntimeException(e); - } finally { - // 发送 complete - sseEmitter.complete(); - } - } +// private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { +// try { +// sseEmitter.send(object, MediaType.APPLICATION_JSON); +// } catch (IOException e) { +// throw new RuntimeException(e); +// } finally { +// // 发送 complete +// sseEmitter.complete(); +// } +// } private AiImageDO doSave(String prompt, String size, diff --git a/yudao-server/src/main/resources/application-local.yaml b/yudao-server/src/main/resources/application-local.yaml index dd191ce9a..135b3df81 100644 --- a/yudao-server/src/main/resources/application-local.yaml +++ b/yudao-server/src/main/resources/application-local.yaml @@ -2,7 +2,6 @@ server: port: 48080 --- #################### 数据库相关配置 #################### - spring: # 数据源配置项 autoconfigure: @@ -79,7 +78,12 @@ spring: port: 6379 # 端口 database: 0 # 数据库索引 # password: dev # 密码,建议生产环境开启 - +server: + servlet: + encoding: + enabled: true + charset: UTF-8 + force: true --- #################### 定时任务相关配置 #################### # Quartz 配置项,对应 QuartzProperties 配置类