【代码优化】AI:ChatGlm 替换成 ZhiPuAiImage 实现

This commit is contained in:
YunaiV 2024-07-13 00:06:48 +08:00
parent 4311fe4517
commit 73502d565f
9 changed files with 60 additions and 273 deletions

View File

@ -9,7 +9,6 @@ import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil; import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
@ -34,6 +33,7 @@ import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.qianfan.QianFanImageOptions; import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@ -105,7 +105,9 @@ public class AiImageServiceImpl implements AiImageService {
ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request)); ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
// 2. 上传到文件服务 // 2. 上传到文件服务
byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json()); String b64Json = response.getResult().getOutput().getB64Json();
byte[] fileContent = StrUtil.isNotEmpty(b64Json) ? Base64.decode(b64Json)
: HttpUtil.downloadBytes(response.getResult().getOutput().getUrl());
String filePath = fileApi.createFile(fileContent); String filePath = fileApi.createFile(fileContent);
// 3. 更新数据库 // 3. 更新数据库
@ -149,8 +151,8 @@ public class AiImageServiceImpl implements AiImageService {
.withModel(draw.getModel()).withN(1) .withModel(draw.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth()) .withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build(); .build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.CHATGLM.getPlatform())) { } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
return ChatGlmImageOptions.builder() return ZhiPuAiImageOptions.builder()
.withModel(draw.getModel()) .withModel(draw.getModel())
.build(); .build();
} }

View File

@ -60,13 +60,6 @@
<version>2.14.0</version> <version>2.14.0</version>
</dependency> </dependency>
<!-- bigmodel -->
<dependency>
<groupId>cn.bigmodel.openapi</groupId>
<artifactId>oapi-java-sdk</artifactId>
<version>release-V4-2.0.2</version>
</dependency>
<!-- Test 测试相关 --> <!-- Test 测试相关 -->
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>

View File

@ -28,7 +28,6 @@ public enum AiPlatformEnum {
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney
SUNO("Suno", "Suno"), // Suno AI SUNO("Suno", "Suno"), // Suno AI
CHATGLM("ChatGlm", "ChatGlm"), // Suno AI
; ;

View File

@ -9,7 +9,6 @@ import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration; import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties; import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
@ -31,6 +30,7 @@ import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackContext;
@ -48,7 +48,9 @@ import org.springframework.ai.qianfan.api.QianFanImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageModel; import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi; import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.zhipuai.ZhiPuAiChatModel; import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi; import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
import org.springframework.retry.support.RetryTemplate; import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient; import org.springframework.web.client.RestClient;
@ -119,6 +121,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return SpringUtil.getBean(TongYiImagesModel.class); return SpringUtil.getBean(TongYiImagesModel.class);
case YI_YAN: case YI_YAN:
return SpringUtil.getBean(QianFanImageModel.class); return SpringUtil.getBean(QianFanImageModel.class);
case ZHI_PU:
return SpringUtil.getBean(ZhiPuAiImageModel.class);
case OPENAI: case OPENAI:
return SpringUtil.getBean(OpenAiImageModel.class); return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION: case STABLE_DIFFUSION:
@ -136,12 +140,12 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildTongYiImagesModel(apiKey); return buildTongYiImagesModel(apiKey);
case YI_YAN: case YI_YAN:
return buildQianFanImageModel(apiKey); return buildQianFanImageModel(apiKey);
case ZHI_PU:
return buildZhiPuAiImageModel(apiKey, url);
case OPENAI: case OPENAI:
return buildOpenAiImageModel(apiKey, url); return buildOpenAiImageModel(apiKey, url);
case STABLE_DIFFUSION: case STABLE_DIFFUSION:
return buildStabilityAiImageModel(apiKey, url); return buildStabilityAiImageModel(apiKey, url);
case CHATGLM:
return buildChatGlmModel(apiKey);
default: default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
} }
@ -225,7 +229,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
} }
/** /**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
* ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
*/ */
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) { private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@ -233,6 +238,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new ZhiPuAiChatModel(zhiPuAiApi); return new ZhiPuAiChatModel(zhiPuAiApi);
} }
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
* ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
ZhiPuAiImageApi zhiPuAiApi = new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
return new ZhiPuAiImageModel(zhiPuAiApi);
}
/** /**
* 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)} * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
*/ */
@ -276,7 +291,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new StabilityAiImageModel(stabilityAiApi); return new StabilityAiImageModel(stabilityAiApi);
} }
private ChatGlmImageModel buildChatGlmModel(String apiKey) {
return new ChatGlmImageModel(apiKey);
}
} }

View File

@ -1,75 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.chatglm;
import cn.iocoder.yudao.framework.ai.core.model.chatglm.api.ChatGlmResponseMetadata;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import org.springframework.ai.image.*;
import java.io.ByteArrayOutputStream;
import java.net.URL;
import java.util.Base64;
import java.util.stream.Collectors;
public class ChatGlmImageModel implements ImageModel {
private ClientV4 client;
public ChatGlmImageModel(String apiSecretKey) {
client = new ClientV4.Builder(apiSecretKey).build();
}
@Override
public ImageResponse call(ImagePrompt request) {
CreateImageRequest imageRequest = CreateImageRequest.builder()
.model(request.getOptions().getModel())
.prompt(request.getInstructions().get(0).getText())
.build();
return convert(client.createImage(imageRequest));
}
private ImageResponse convert(ImageApiResponse result) {
return new ImageResponse(
result.getData().getData().stream().map(item -> {
try {
String url = item.getUrl();
String base64Image = convertImageToBase64(url);
Image image = new Image(url, base64Image);
return new ImageGeneration(image);
} catch (Exception e) {
throw new RuntimeException(e);
}
}).collect(Collectors.toList()),
new ChatGlmResponseMetadata(result)
);
}
/**
* Convert image to base64.
* @param imageUrl the image url.
* @return the base64 image.
* @throws Exception the exception.
*/
public String convertImageToBase64(String imageUrl) throws Exception {
var url = new URL(imageUrl);
var inputStream = url.openStream();
var outputStream = new ByteArrayOutputStream();
var buffer = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, bytesRead);
}
var imageBytes = outputStream.toByteArray();
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
inputStream.close();
outputStream.close();
return base64Image;
}
}

View File

@ -1,115 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.chatglm;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Setter;
import org.springframework.ai.image.ImageOptions;
/**
* chatglm
* api地址https://open.bigmodel.cn/dev/api#cogview
*/
@Setter
public class ChatGlmImageOptions implements ImageOptions {
@JsonProperty("n")
private Integer n;
@JsonProperty("model")
private String model = "cogview-3";
@JsonProperty("size_width")
private Integer width;
@JsonProperty("size_height")
private Integer height;
@JsonProperty("size")
private String size;
@JsonProperty("style")
private String style;
@JsonProperty("user_id")
private String user;
@JsonProperty("responseFormat")
private String responseFormat;
// ==== build
public static ChatGlmImageOptions.Builder builder() {
return new ChatGlmImageOptions.Builder();
}
public static class Builder {
private final ChatGlmImageOptions options;
private Builder() {
this.options = new ChatGlmImageOptions();
}
public ChatGlmImageOptions.Builder withN(Integer n) {
options.setN(n);
return this;
}
public ChatGlmImageOptions.Builder withModel(String model) {
options.setModel(model);
return this;
}
public ChatGlmImageOptions.Builder withWidth(Integer width) {
options.setWidth(width);
return this;
}
public ChatGlmImageOptions.Builder withHeight(Integer height) {
options.setHeight(height);
return this;
}
public ChatGlmImageOptions.Builder withStyle(String style) {
options.setStyle(style);
return this;
}
public ChatGlmImageOptions.Builder withUser(String user) {
options.setUser(user);
return this;
}
public ChatGlmImageOptions build() {
return options;
}
}
// ==== get
@Override
public Integer getN() {
return n;
}
@Override
public String getModel() {
return model;
}
@Override
public Integer getWidth() {
return width;
}
@Override
public Integer getHeight() {
return height;
}
@Override
public String getResponseFormat() {
return responseFormat;
}
}

View File

@ -1,24 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.chatglm.api;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import org.springframework.ai.image.ImageResponseMetadata;
import java.util.HashMap;
public class ChatGlmResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
private Long created;
public ChatGlmResponseMetadata(ImageApiResponse result) {
created = result.getData().getCreated();
}
@Override
public Long getCreated() {
return created;
}
public void setCreated(Long created) {
this.created = created;
}
}

View File

@ -1,40 +0,0 @@
package cn.iocoder.yudao.framework.ai.image;
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import com.alibaba.fastjson.JSON;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.core.httpclient.ApacheHttpClientTransport;
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.qianfan.QianFanImageModel;
import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.qianfan.api.QianFanImageApi;
/**
* 百度千帆 image
*/
public class ChatGlmImageModelTests {
@Test
public void callTest() {
ChatGlmImageModel model = new ChatGlmImageModel("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
ImageResponse call = model.call(new ImagePrompt("万里长城", ChatGlmImageOptions.builder().build()));
System.err.println(call.getResult().getOutput().getUrl());
}
@Test
public void createImageTest() {
ClientV4 client = new ClientV4.Builder("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy").build();
CreateImageRequest createImageRequest = new CreateImageRequest();
createImageRequest.setModel("cogview-3");
createImageRequest.setPrompt("长城!");
ImageApiResponse image = client.createImage(createImageRequest);
System.err.println(JSON.toJSONString(image));
}
}

View File

@ -0,0 +1,35 @@
package cn.iocoder.yudao.framework.ai.image;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
/**
* {@link ZhiPuAiImageModel} 集成测试
*/
public class ZhiPuAiImageModelTests {
private final ZhiPuAiImageApi imageApi = new ZhiPuAiImageApi(
"78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(imageApi);
@Test
@Disabled
public void testCall() {
// 准备参数
ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
.withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
.build();
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
// 方法调用
ImageResponse response = imageModel.call(prompt);
// 打印结果
System.out.println(response);
}
}