mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2025-01-31 17:40:05 +08:00
【添加】midjourney 增加 imagine 接口
This commit is contained in:
parent
24479dacfd
commit
330fd52b3e
@ -1,14 +1,17 @@
|
|||||||
package cn.iocoder.yudao.module.ai.controller;
|
package cn.iocoder.yudao.module.ai.controller;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
||||||
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
|
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
|
||||||
|
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
|
||||||
|
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
import org.springframework.web.bind.annotation.GetMapping;
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
import org.springframework.web.bind.annotation.ModelAttribute;
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
@ -30,10 +33,16 @@ public class AiImageController {
|
|||||||
private final AiImageService aiImageService;
|
private final AiImageService aiImageService;
|
||||||
|
|
||||||
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
|
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
|
||||||
@GetMapping("/dallDrawing")
|
@PostMapping("/dallDrawing")
|
||||||
public SseEmitter dallDrawing(@Validated @ModelAttribute AiImageDallDrawingReq req) {
|
public SseEmitter dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) {
|
||||||
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
|
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
|
||||||
aiImageService.dallDrawing(req, sseEmitter);
|
aiImageService.dallDrawing(req, sseEmitter);
|
||||||
return sseEmitter;
|
return sseEmitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
|
||||||
|
@PostMapping("/midjourney")
|
||||||
|
public CommonResult<AiImageMidjourneyRes> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
|
||||||
|
return CommonResult.success(aiImageService.midjourney(req));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.service;
|
|||||||
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
||||||
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
|
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
|
||||||
|
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
|
||||||
|
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ai 作图
|
* ai 作图
|
||||||
@ -19,4 +21,12 @@ public interface AiImageService {
|
|||||||
* @param sseEmitter
|
* @param sseEmitter
|
||||||
*/
|
*/
|
||||||
void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter);
|
void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* midjourney 图片生成
|
||||||
|
*
|
||||||
|
* @param req
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req);
|
||||||
}
|
}
|
||||||
|
@ -8,17 +8,26 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
|
|||||||
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageModelEnum;
|
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageModelEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
|
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
|
||||||
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageStyleEnum;
|
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageStyleEnum;
|
||||||
|
import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
|
||||||
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
|
||||||
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
|
||||||
|
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
||||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||||
|
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
|
||||||
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
|
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
|
||||||
import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
|
import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
|
||||||
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
||||||
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
|
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
|
||||||
|
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
|
||||||
|
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
|
||||||
|
import jakarta.annotation.PostConstruct;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
@ -36,6 +45,24 @@ public class AiImageServiceImpl implements AiImageService {
|
|||||||
|
|
||||||
private final AiImageMapper aiImageMapper;
|
private final AiImageMapper aiImageMapper;
|
||||||
private final OpenAiImageClient openAiImageClient;
|
private final OpenAiImageClient openAiImageClient;
|
||||||
|
private final MidjourneyWebSocketStarter midjourneyWebSocketStarter;
|
||||||
|
private final MidjourneyInteractionsApi midjourneyInteractionsApi;
|
||||||
|
|
||||||
|
@PostConstruct
|
||||||
|
public void startMidjourney() {
|
||||||
|
log.info("midjourney web socket starter...");
|
||||||
|
midjourneyWebSocketStarter.start(new WssNotify() {
|
||||||
|
@Override
|
||||||
|
public void notify(int code, String message) {
|
||||||
|
log.info("code: {}, message: {}", code, message);
|
||||||
|
if (message.contains("Authentication failed")) {
|
||||||
|
// TODO 芋艿,这里看怎么处理,token无效的时候会认证失败!
|
||||||
|
// 认证失败
|
||||||
|
log.error("midjourney socket 认证失败,检查token是否失效!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) {
|
public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) {
|
||||||
@ -55,15 +82,33 @@ public class AiImageServiceImpl implements AiImageService {
|
|||||||
// 发送信息
|
// 发送信息
|
||||||
sendSseEmitter(sseEmitter, imageGeneration);
|
sendSseEmitter(sseEmitter, imageGeneration);
|
||||||
// 保存数据库
|
// 保存数据库
|
||||||
doSave(req, imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
|
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
||||||
|
imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
|
||||||
} catch (AiException aiException) {
|
} catch (AiException aiException) {
|
||||||
// 保存数据库
|
// 保存数据库
|
||||||
doSave(req, null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
|
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
||||||
|
null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
|
||||||
// 发送错误信息
|
// 发送错误信息
|
||||||
sendSseEmitter(sseEmitter, aiException.getMessage());
|
sendSseEmitter(sseEmitter, aiException.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@Transactional(rollbackFor = Exception.class)
|
||||||
|
public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) {
|
||||||
|
// 保存数据库
|
||||||
|
doSave(req.getPrompt(), null, "midjoureny",
|
||||||
|
null, AiChatDrawingStatusEnum.SUBMIT, null);
|
||||||
|
// 提交 midjourney 任务
|
||||||
|
Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt());
|
||||||
|
if (!imagine) {
|
||||||
|
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
|
||||||
|
}
|
||||||
|
//
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
|
private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
|
||||||
try {
|
try {
|
||||||
sseEmitter.send(object, MediaType.APPLICATION_JSON);
|
sseEmitter.send(object, MediaType.APPLICATION_JSON);
|
||||||
@ -75,14 +120,19 @@ public class AiImageServiceImpl implements AiImageService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void doSave(AiImageDallDrawingReq req, String imageUrl, AiChatDrawingStatusEnum drawingStatusEnum, String drawingError) {
|
private void doSave(String prompt,
|
||||||
|
String size,
|
||||||
|
String model,
|
||||||
|
String imageUrl,
|
||||||
|
AiChatDrawingStatusEnum drawingStatusEnum,
|
||||||
|
String drawingError) {
|
||||||
// 保存数据库
|
// 保存数据库
|
||||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||||
AiImageDO aiImageDO = new AiImageDO();
|
AiImageDO aiImageDO = new AiImageDO();
|
||||||
aiImageDO.setId(null);
|
aiImageDO.setId(null);
|
||||||
aiImageDO.setPrompt(req.getPrompt());
|
aiImageDO.setPrompt(prompt);
|
||||||
aiImageDO.setSize(req.getSize());
|
aiImageDO.setSize(size);
|
||||||
aiImageDO.setModal(req.getModal());
|
aiImageDO.setModal(model);
|
||||||
aiImageDO.setUserId(loginUserId);
|
aiImageDO.setUserId(loginUserId);
|
||||||
aiImageDO.setDrawingImageUrl(imageUrl);
|
aiImageDO.setDrawingImageUrl(imageUrl);
|
||||||
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
|
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
|
||||||
|
Loading…
Reference in New Issue
Block a user