【添加】midjourney 增加 imagine 接口

This commit is contained in:
cherishsince 2024-04-29 14:46:57 +08:00
parent 24479dacfd
commit 330fd52b3e
3 changed files with 79 additions and 10 deletions

View File

@ -1,14 +1,17 @@
package cn.iocoder.yudao.module.ai.controller; package cn.iocoder.yudao.module.ai.controller;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.service.AiImageService; import cn.iocoder.yudao.module.ai.service.AiImageService;
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@ -30,10 +33,16 @@ public class AiImageController {
private final AiImageService aiImageService; private final AiImageService aiImageService;
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
@GetMapping("/dallDrawing") @PostMapping("/dallDrawing")
public SseEmitter dallDrawing(@Validated @ModelAttribute AiImageDallDrawingReq req) { public SseEmitter dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) {
Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
aiImageService.dallDrawing(req, sseEmitter); aiImageService.dallDrawing(req, sseEmitter);
return sseEmitter; return sseEmitter;
} }
@Operation(summary = "midjourney", description = "midjourney图片绘画流程1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
@PostMapping("/midjourney")
public CommonResult<AiImageMidjourneyRes> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
return CommonResult.success(aiImageService.midjourney(req));
}
} }

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
/** /**
* ai 作图 * ai 作图
@ -19,4 +21,12 @@ public interface AiImageService {
* @param sseEmitter * @param sseEmitter
*/ */
void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter); void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter);
/**
* midjourney 图片生成
*
* @param req
* @return
*/
AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req);
} }

View File

@ -8,17 +8,26 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageModelEnum;
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageStyleEnum; import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageStyleEnum;
import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum; import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.mapper.AiImageMapper; import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
import cn.iocoder.yudao.module.ai.service.AiImageService; import cn.iocoder.yudao.module.ai.service.AiImageService;
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
import jakarta.annotation.PostConstruct;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.io.IOException; import java.io.IOException;
@ -36,6 +45,24 @@ public class AiImageServiceImpl implements AiImageService {
private final AiImageMapper aiImageMapper; private final AiImageMapper aiImageMapper;
private final OpenAiImageClient openAiImageClient; private final OpenAiImageClient openAiImageClient;
private final MidjourneyWebSocketStarter midjourneyWebSocketStarter;
private final MidjourneyInteractionsApi midjourneyInteractionsApi;
@PostConstruct
public void startMidjourney() {
log.info("midjourney web socket starter...");
midjourneyWebSocketStarter.start(new WssNotify() {
@Override
public void notify(int code, String message) {
log.info("code: {}, message: {}", code, message);
if (message.contains("Authentication failed")) {
// TODO 芋艿这里看怎么处理token无效的时候会认证失败
// 认证失败
log.error("midjourney socket 认证失败检查token是否失效!");
}
}
});
}
@Override @Override
public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) { public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) {
@ -55,15 +82,33 @@ public class AiImageServiceImpl implements AiImageService {
// 发送信息 // 发送信息
sendSseEmitter(sseEmitter, imageGeneration); sendSseEmitter(sseEmitter, imageGeneration);
// 保存数据库 // 保存数据库
doSave(req, imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null); doSave(req.getPrompt(), req.getSize(), req.getModal(),
imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
} catch (AiException aiException) { } catch (AiException aiException) {
// 保存数据库 // 保存数据库
doSave(req, null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage()); doSave(req.getPrompt(), req.getSize(), req.getModal(),
null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
// 发送错误信息 // 发送错误信息
sendSseEmitter(sseEmitter, aiException.getMessage()); sendSseEmitter(sseEmitter, aiException.getMessage());
} }
} }
@Override
@Transactional(rollbackFor = Exception.class)
public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) {
// 保存数据库
doSave(req.getPrompt(), null, "midjoureny",
null, AiChatDrawingStatusEnum.SUBMIT, null);
// 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt());
if (!imagine) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
}
//
return null;
}
private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
try { try {
sseEmitter.send(object, MediaType.APPLICATION_JSON); sseEmitter.send(object, MediaType.APPLICATION_JSON);
@ -75,14 +120,19 @@ public class AiImageServiceImpl implements AiImageService {
} }
} }
private void doSave(AiImageDallDrawingReq req, String imageUrl, AiChatDrawingStatusEnum drawingStatusEnum, String drawingError) { private void doSave(String prompt,
String size,
String model,
String imageUrl,
AiChatDrawingStatusEnum drawingStatusEnum,
String drawingError) {
// 保存数据库 // 保存数据库
Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
AiImageDO aiImageDO = new AiImageDO(); AiImageDO aiImageDO = new AiImageDO();
aiImageDO.setId(null); aiImageDO.setId(null);
aiImageDO.setPrompt(req.getPrompt()); aiImageDO.setPrompt(prompt);
aiImageDO.setSize(req.getSize()); aiImageDO.setSize(size);
aiImageDO.setModal(req.getModal()); aiImageDO.setModal(model);
aiImageDO.setUserId(loginUserId); aiImageDO.setUserId(loginUserId);
aiImageDO.setDrawingImageUrl(imageUrl); aiImageDO.setDrawingImageUrl(imageUrl);
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus()); aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());