From 330fd52b3e2ffc2f6045b736b05ddd2f5762864a Mon Sep 17 00:00:00 2001 From: cherishsince Date: Mon, 29 Apr 2024 14:46:57 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E6=B7=BB=E5=8A=A0=E3=80=91midjourney?= =?UTF-8?q?=20=E5=A2=9E=E5=8A=A0=20imagine=20=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/controller/AiImageController.java | 17 +++-- .../module/ai/service/AiImageService.java | 10 +++ .../ai/service/impl/AiImageServiceImpl.java | 62 +++++++++++++++++-- 3 files changed, 79 insertions(+), 10 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java index f4a20e5fa..aefdf14f4 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java @@ -1,14 +1,17 @@ package cn.iocoder.yudao.module.ai.controller; +import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.module.ai.service.AiImageService; import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; +import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq; +import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.validation.annotation.Validated; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -30,10 +33,16 @@ public class AiImageController { private final AiImageService aiImageService; @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") - @GetMapping("/dallDrawing") - public SseEmitter dallDrawing(@Validated @ModelAttribute AiImageDallDrawingReq req) { + @PostMapping("/dallDrawing") + public SseEmitter dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) { Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); aiImageService.dallDrawing(req, sseEmitter); return sseEmitter; } + + @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") + @PostMapping("/midjourney") + public CommonResult midjourney(@Validated @RequestBody AiImageMidjourneyReq req) { + return CommonResult.success(aiImageService.midjourney(req)); + } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java index 05e512d24..f4395e398 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java @@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.service; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; +import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq; +import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes; /** * ai 作图 @@ -19,4 +21,12 @@ public interface AiImageService { * @param sseEmitter */ void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter); + + /** + * midjourney 图片生成 + * + * @param req + * @return + */ + AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java index cdb95c78c..d87e2d9f3 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java @@ -8,17 +8,26 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageStyleEnum; +import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi; +import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter; +import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify; +import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; +import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO; import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum; import cn.iocoder.yudao.module.ai.mapper.AiImageMapper; import cn.iocoder.yudao.module.ai.service.AiImageService; import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; +import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq; +import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes; +import jakarta.annotation.PostConstruct; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.io.IOException; @@ -36,6 +45,24 @@ public class AiImageServiceImpl implements AiImageService { private final AiImageMapper aiImageMapper; private final OpenAiImageClient openAiImageClient; + private final MidjourneyWebSocketStarter midjourneyWebSocketStarter; + private final MidjourneyInteractionsApi midjourneyInteractionsApi; + + @PostConstruct + public void startMidjourney() { + log.info("midjourney web socket starter..."); + midjourneyWebSocketStarter.start(new WssNotify() { + @Override + public void notify(int code, String message) { + log.info("code: {}, message: {}", code, message); + if (message.contains("Authentication failed")) { + // TODO 芋艿,这里看怎么处理,token无效的时候会认证失败! + // 认证失败 + log.error("midjourney socket 认证失败,检查token是否失效!"); + } + } + }); + } @Override public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) { @@ -55,15 +82,33 @@ public class AiImageServiceImpl implements AiImageService { // 发送信息 sendSseEmitter(sseEmitter, imageGeneration); // 保存数据库 - doSave(req, imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null); + doSave(req.getPrompt(), req.getSize(), req.getModal(), + imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null); } catch (AiException aiException) { // 保存数据库 - doSave(req, null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage()); + doSave(req.getPrompt(), req.getSize(), req.getModal(), + null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage()); // 发送错误信息 sendSseEmitter(sseEmitter, aiException.getMessage()); } } + @Override + @Transactional(rollbackFor = Exception.class) + public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) { + // 保存数据库 + doSave(req.getPrompt(), null, "midjoureny", + null, AiChatDrawingStatusEnum.SUBMIT, null); + // 提交 midjourney 任务 + Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt()); + if (!imagine) { + throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL); + } + // + + return null; + } + private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { try { sseEmitter.send(object, MediaType.APPLICATION_JSON); @@ -75,14 +120,19 @@ public class AiImageServiceImpl implements AiImageService { } } - private void doSave(AiImageDallDrawingReq req, String imageUrl, AiChatDrawingStatusEnum drawingStatusEnum, String drawingError) { + private void doSave(String prompt, + String size, + String model, + String imageUrl, + AiChatDrawingStatusEnum drawingStatusEnum, + String drawingError) { // 保存数据库 Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); AiImageDO aiImageDO = new AiImageDO(); aiImageDO.setId(null); - aiImageDO.setPrompt(req.getPrompt()); - aiImageDO.setSize(req.getSize()); - aiImageDO.setModal(req.getModal()); + aiImageDO.setPrompt(prompt); + aiImageDO.setSize(size); + aiImageDO.setModal(model); aiImageDO.setUserId(loginUserId); aiImageDO.setDrawingImageUrl(imageUrl); aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());