【优化】dall 绘画,改为异步。

This commit is contained in:
cherishsince 2024-05-28 11:39:29 +08:00
parent 63a8cc244d
commit e97408b3ac

View File

@ -35,6 +35,9 @@ import org.springframework.transaction.annotation.Transactional;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/** /**
* ai 作图 * ai 作图
@ -53,6 +56,8 @@ public class AiImageServiceImpl implements AiImageService {
private final OpenAiImageClient openAiImageClient; private final OpenAiImageClient openAiImageClient;
private final MidjourneyWebSocketStarter midjourneyWebSocketStarter; private final MidjourneyWebSocketStarter midjourneyWebSocketStarter;
private final MidjourneyInteractionsApi midjourneyInteractionsApi; private final MidjourneyInteractionsApi midjourneyInteractionsApi;
private static ThreadPoolExecutor EXECUTOR = new ThreadPoolExecutor(
3, 5, 1, TimeUnit.HOURS, new LinkedBlockingQueue<>(32));
@PostConstruct @PostConstruct
public void startMidjourney() { public void startMidjourney() {
@ -89,34 +94,48 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
public AiImageDallRespVO dallDrawing(AiImageDallReqVO req) { public AiImageDallRespVO dallDrawing(AiImageDallReqVO req) {
// 获取 model // 保存数据库
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel()); AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle()); null, null, AiImageStatusEnum.IN_PROGRESS, null,
try { null, null, null);
// 转换openai 参数 // 异步执行
OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions(); EXECUTOR.execute(() -> {
openAiImageOptions.setModel(openAiImageModelEnum.getModel()); try {
openAiImageOptions.setStyle(openAiImageStyleEnum.getStyle());
openAiImageOptions.setSize(req.getSize()); // 获取 model
ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions)); OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
// 发送 OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
ImageGeneration imageGeneration = imageResponse.getResult();
// 图片保存到服务器 // 转换openai 参数
String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl())); OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
// 保存数据库 openAiImageOptions.setModel(openAiImageModelEnum.getModel());
AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(), openAiImageOptions.setStyle(openAiImageStyleEnum.getStyle());
filePath, imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null, openAiImageOptions.setSize(req.getSize());
null, null, null); ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
// 转换 AiImageDallDrawingRespVO // 发送
return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO); ImageGeneration imageGeneration = imageResponse.getResult();
} catch (AiException aiException) { // 图片保存到服务器
// 保存数据库 String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl()));
AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(), // 更新数据库
null, null, AiImageStatusEnum.FAIL, aiException.getMessage(), aiImageMapper.updateById(
null, null, null); new AiImageDO()
// 发送错误信息 .setId(aiImageDO.getId())
return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO); .setStatus(AiImageStatusEnum.COMPLETE.getStatus())
} .setPicUrl(filePath)
.setOriginalPicUrl(imageGeneration.getOutput().getUrl())
);
} catch (AiException aiException) {
// 更新错误信息
aiImageMapper.updateById(
new AiImageDO()
.setId(aiImageDO.getId())
.setStatus(AiImageStatusEnum.FAIL.getStatus())
.setErrorMessage(aiException.getMessage())
);
}
});
// 转换 AiImageDallDrawingRespVO
return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
} }
@Override @Override