diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 02c1ab334..a7e68656b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -9,6 +9,7 @@ import cn.hutool.core.util.StrUtil; import cn.hutool.extra.spring.SpringUtil; import cn.hutool.http.HttpUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageResult; @@ -148,6 +149,10 @@ public class AiImageServiceImpl implements AiImageService { .withModel(draw.getModel()).withN(1) .withHeight(draw.getHeight()).withWidth(draw.getWidth()) .build(); + } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.CHATGLM.getPlatform())) { + return ChatGlmImageOptions.builder() + .withModel(draw.getModel()) + .build(); } throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java index 596118168..2e302501c 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java @@ -28,6 +28,7 @@ public enum AiPlatformEnum { STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney SUNO("Suno", "Suno"), // Suno AI + CHATGLM("ChatGlm", "ChatGlm"), // Suno AI ; diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index 66a32167c..2dae410cb 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -9,6 +9,7 @@ import cn.hutool.extra.spring.SpringUtil; import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration; import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; +import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; @@ -139,6 +140,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return buildOpenAiImageModel(apiKey, url); case STABLE_DIFFUSION: return buildStabilityAiImageModel(apiKey, url); + case CHATGLM: + return buildChatGlmModel(apiKey); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } @@ -273,4 +276,7 @@ public class AiModelFactoryImpl implements AiModelFactory { return new StabilityAiImageModel(stabilityAiApi); } + private ChatGlmImageModel buildChatGlmModel(String apiKey) { + return new ChatGlmImageModel(apiKey); + } }