diff --git a/script/idea/http-client.env.json b/script/idea/http-client.env.json index 17dd0d50d..4a4cb5221 100644 --- a/script/idea/http-client.env.json +++ b/script/idea/http-client.env.json @@ -1,7 +1,7 @@ { "local": { "baseUrl": "http://127.0.0.1:48080/admin-api", - "token": "Bearer 1c2ce60de96a4fb0bf5bea9604099a3d", + "token": "test1", "adminTenentId": "1", "appApi": "http://127.0.0.1:48080/app-api", diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java index ad3641421..19cbc8f8f 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java @@ -39,6 +39,7 @@ public enum AiChatRoleEnum implements IntArrayValuable { 除此之外不要任何解释性语句。 """); + // TODO @xin:这个 role 是不是删除掉好点哈。= = 目前主要是没做角色枚举。这里多了 role 反倒容易误解哈 /** * 角色 */ diff --git a/yudao-module-ai/yudao-module-ai-biz/pom.xml b/yudao-module-ai/yudao-module-ai-biz/pom.xml index a537b3db7..7c529f118 100644 --- a/yudao-module-ai/yudao-module-ai-biz/pom.xml +++ b/yudao-module-ai/yudao-module-ai-biz/pom.xml @@ -60,9 +60,5 @@ cn.iocoder.boot yudao-spring-boot-starter-test - - cn.iocoder.boot - yudao-spring-boot-starter-excel - \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java index 92222b590..0442a52d7 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java @@ -12,8 +12,7 @@ import lombok.Data; * * @author xiaoxin */ -// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈 -@TableName(value = "ai_mind_map", autoResultMap = true) +@TableName(value = "ai_mind_map") @Data public class AiMindMapDO extends BaseDO { @@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO { /** * 用户编号 - * + *

* 关联 AdminUserDO 的 userId 字段 */ private Long userId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java index 54fa7235a..ff25e89ff 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java @@ -5,7 +5,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import org.apache.ibatis.annotations.Mapper; /** - * AI 音乐 Mapper + * AI 思维导图 Mapper * * @author xiaoxin */ diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 6c8cdeaca..72fa06a79 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); - // 3.2 创建 chat 需要的 Prompt + // 3.2 构建 Prompt,并进行调用 Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); Flux streamResponse = chatModel.stream(prompt); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 7ea629e11..02c1ab334 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -31,6 +31,7 @@ import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.openai.OpenAiImageOptions; +import org.springframework.ai.qianfan.QianFanImageOptions; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; @@ -142,6 +143,11 @@ public class AiImageServiceImpl implements AiImageService { .withModel(draw.getModel()).withN(1) .withHeight(draw.getHeight()).withWidth(draw.getWidth()) .build(); + } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) { + return QianFanImageOptions.builder() + .withModel(draw.getModel()).withN(1) + .withHeight(draw.getHeight()).withWidth(draw.getWidth()) + .build(); } throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java index 7d96c70d2..72be20c54 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.ai.service.mindmap; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; @@ -31,13 +32,12 @@ import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; -import java.util.Objects; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; /** - * AI 写作 Service 实现类 + * AI 思维导图 Service 实现类 * * @author xiaoxin */ @@ -57,38 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService { @Override public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { - // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型 - AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); - AiChatModelDO model; - String systemMessage; - if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) { - model = chatModalService.getChatModel(mindMapRole.getModelId()); - systemMessage = mindMapRole.getSystemMessage(); - } else { - model = chatModalService.getRequiredDefaultChatModel(); - systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); - } - + // 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型 + AiChatRoleDO role = CollUtil.getFirst( + chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); + // 1.1 获取脑图执行模型 + AiChatModelDO model = getModel(role); + // 1.2 获取角色设定消息 + String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage()) + ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); + // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); - // 2 插入思维导图信息 - AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); + // 2. 插入思维导图信息 + AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, + mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); mindMapMapper.insert(mindMapDO); - ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - // 3.1 角色设定 - List chatMessages = new ArrayList<>(); - if (StrUtil.isNotBlank(systemMessage)) { - chatMessages.add(new SystemMessage(systemMessage)); - } - // 3.2 用户输入 - chatMessages.add(new UserMessage(generateReqVO.getPrompt())); - // 3.3 构建提示词 - Prompt prompt = new Prompt(chatMessages, chatOptions); - + // 3.1 构建 Prompt,并进行调用 + Prompt prompt = buildPrompt(generateReqVO, model, systemMessage); Flux streamResponse = chatModel.stream(prompt); - // 3.4 流式返回 + + // 3.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; @@ -109,4 +99,36 @@ public class AiMindMapServiceImpl implements AiMindMapService { } + private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { + // 1. 构建 message 列表 + List chatMessages = buildMessages(generateReqVO, systemMessage); + // 2. 构建 options 对象 + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); + return new Prompt(chatMessages, options); + } + + private static List buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) { + List chatMessages = new ArrayList<>(); + // 1. 角色设定 + if (StrUtil.isNotBlank(systemMessage)) { + chatMessages.add(new SystemMessage(systemMessage)); + } + // 2. 用户输入 + chatMessages.add(new UserMessage(generateReqVO.getPrompt())); + return chatMessages; + } + + private AiChatModelDO getModel(AiChatRoleDO role) { + AiChatModelDO model = null; + if (role != null && role.getModelId() != null) { + model = chatModalService.getChatModel(role.getModelId()); + } + if (model != null) { + model = chatModalService.getRequiredDefaultChatModel(); + } + Assert.notNull(model, "[AI] 获取不到模型"); + return model; + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index b03a90ab7..2fae31d59 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.ai.service.write; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; @@ -67,19 +68,15 @@ public class AiWriteServiceImpl implements AiWriteService { @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { - // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型 - AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); - // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。 - AiChatModelDO model; - String systemMessage; - if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { - model = chatModalService.getChatModel(writeRole.getModelId()); - systemMessage = writeRole.getSystemMessage(); - } else { - model = chatModalService.getRequiredDefaultChatModel(); - systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); - } - // 1.2 校验平台 + // 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型 + AiChatRoleDO writeRole = CollUtil.getFirst( + chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); + // 1.1 获取写作执行模型 + AiChatModelDO model = getModel(writeRole); + // 1.2 获取角色设定消息 + String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage()) + ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); + // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); @@ -88,21 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService { write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel())); writeMapper.insert(writeDO); - // 3. 调用大模型,写作生成 - ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - // 3.1 角色设定 - // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确 - List chatMessages = new ArrayList<>(); - if (StrUtil.isNotBlank(systemMessage)) { - chatMessages.add(new SystemMessage(systemMessage)); - } - // 3.2 用户输入 - chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO))); - // 3.3 构建提示词 - Prompt prompt = new Prompt(chatMessages, chatOptions); + // 3.1 构建 Prompt,并进行调用 + Prompt prompt = buildPrompt(generateReqVO, model, systemMessage); Flux streamResponse = chatModel.stream(prompt); - // 4. 流式返回 + // 3.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; @@ -122,7 +109,39 @@ public class AiWriteServiceImpl implements AiWriteService { }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); } - private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { + private AiChatModelDO getModel(AiChatRoleDO writeRole) { + AiChatModelDO model = null; + if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { + model = chatModalService.getChatModel(writeRole.getModelId()); + } + if (Objects.isNull(model)) { + model = chatModalService.getRequiredDefaultChatModel(); + } + Assert.notNull(model, "[AI] 获取不到模型"); + return model; + } + + private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { + // 1. 构建 message 列表 + List chatMessages = buildMessages(generateReqVO, systemMessage); + // 2. 构建 options 对象 + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); + return new Prompt(chatMessages, options); + } + + private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) { + List chatMessages = new ArrayList<>(); + if (StrUtil.isNotBlank(systemMessage)) { + // 1.1 角色设定 + chatMessages.add(new SystemMessage(systemMessage)); + } + // 1.2 用户输入 + chatMessages.add(new UserMessage(buildUserMessage(generateReqVO))); + return chatMessages; + } + + private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) { String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage()); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index fbf835707..66a32167c 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; +import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; import com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties; import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties; +import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; @@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory { public ImageModel getDefaultImageModel(AiPlatformEnum platform) { //noinspection EnhancedSwitchMigration switch (platform) { + case TONG_YI: + return SpringUtil.getBean(TongYiImagesModel.class); + case YI_YAN: + return SpringUtil.getBean(QianFanImageModel.class); case OPENAI: return SpringUtil.getBean(OpenAiImageModel.class); case STABLE_DIFFUSION: @@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory { public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) { //noinspection EnhancedSwitchMigration switch (platform) { + case TONG_YI: + return buildTongYiImagesModel(apiKey); + case YI_YAN: + return buildQianFanImageModel(apiKey); case OPENAI: return buildOpenAiImageModel(apiKey, url); case STABLE_DIFFUSION: return buildStabilityAiImageModel(apiKey, url); - case TONG_YI: - return SpringUtil.getBean(TongYiImagesModel.class); - case YI_YAN: - return buildQianFanImageModel(apiKey); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } @@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory { return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); } + private static TongYiImagesModel buildTongYiImagesModel(String key) { + ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class); + TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class); + TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties(); + connectionProperties.setApiKey(key); + return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties); + } + /** * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} */ @@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory { return new QianFanChatModel(qianFanApi); } + /** + * 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} + */ + private QianFanImageModel buildQianFanImageModel(String key) { + List keys = StrUtil.split(key, '|'); + Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式"); + String appKey = keys.get(0); + String secretKey = keys.get(1); + QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey); + return new QianFanImageModel(qianFanApi); + } + /** * 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)} */ @@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory { return new StabilityAiImageModel(stabilityAiApi); } - private QianFanImageModel buildQianFanImageModel(String key) { - List keys = StrUtil.split(key, '|'); - return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1))); - } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java index 740978e60..c9b07d9ff 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java @@ -21,7 +21,7 @@ public class OpenAiImageModelTests { "https://api.holdai.top", "sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf", RestClient.builder()); - private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi); + private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi); @Test @Disabled @@ -34,7 +34,7 @@ public class OpenAiImageModelTests { ImagePrompt prompt = new ImagePrompt("中国长城!", options); // 方法调用 - ImageResponse response = imageClient.call(prompt); + ImageResponse response = imageModel.call(prompt); // 打印结果 System.out.println(response); } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java index b8de6f486..22bf6614e 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java @@ -1,48 +1,42 @@ package cn.iocoder.yudao.framework.ai.image; -import cn.iocoder.yudao.framework.common.util.json.JsonUtils; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.qianfan.QianFanImageModel; import org.springframework.ai.qianfan.QianFanImageOptions; -import org.springframework.ai.qianfan.api.QianFanApi; import org.springframework.ai.qianfan.api.QianFanImageApi; +import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage; + /** - * 百度千帆 image + * {@link QianFanImageModel} 集成测试类 */ public class QianFanImageTests { - @Test - public void callTest() { - // todo @芋艿 千帆sdk有个错误,暂时没找到问题 - QianFanImageApi qianFanImageApi = new QianFanImageApi( - "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb"); - QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi); + private final QianFanImageApi imageApi = new QianFanImageApi( + "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e"); + private final QianFanImageModel imageModel = new QianFanImageModel(imageApi); + @Test + @Disabled + public void testCall() { + // 准备参数 + // 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576 QianFanImageOptions imageOptions = QianFanImageOptions.builder() - .withWidth(512) - .withHeight(512) + .withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue()) + .withWidth(1024).withHeight(1024) + .withN(1) .build(); - ImagePrompt imagePrompt = new ImagePrompt("薄涂炫酷少女头像,田野花朵盛开", imageOptions); - ImageResponse call = qianFanImageModel.call(imagePrompt); - System.err.println(JsonUtils.toJsonString(call)); - } + ImagePrompt prompt = new ImagePrompt("good", imageOptions); - @Test - public void call2Test() { - // 官方测试 test https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java - var options = ImageOptionsBuilder.builder().withHeight(1024).withWidth(1024).build(); - var instructions = "薄涂炫酷少女头像,田野花朵盛开"; - - ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - - QianFanImageApi qianFanImageApi = new QianFanImageApi( - "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb"); - QianFanImageModel imageModel = new QianFanImageModel(qianFanImageApi); - ImageResponse imageResponse = imageModel.call(imagePrompt); + // 方法调用 + ImageResponse response = imageModel.call(prompt); + // 打印结果 + String b64Json = response.getResult().getOutput().getB64Json(); + System.out.println(response); + viewImage(b64Json); } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java index cb7412821..7ee7e6044 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java @@ -24,7 +24,7 @@ public class StabilityAiImageModelTests { private final StabilityAiApi imageApi = new StabilityAiApi( "sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx"); - private final StabilityAiImageModel imageClient = new StabilityAiImageModel(imageApi); + private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi); @Test @Disabled @@ -37,7 +37,7 @@ public class StabilityAiImageModelTests { ImagePrompt prompt = new ImagePrompt("great wall", options); // 方法调用 - ImageResponse response = imageClient.call(prompt); + ImageResponse response = imageModel.call(prompt); // 打印结果 String b64Json = response.getResult().getOutput().getB64Json(); System.out.println(response); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java new file mode 100644 index 000000000..41d7859c4 --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java @@ -0,0 +1,43 @@ +package cn.iocoder.yudao.framework.ai.image; + +import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; +import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; +import com.alibaba.dashscope.utils.Constants; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.openai.OpenAiImageOptions; + +/** + * {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类 + * + * @author fansili + */ +public class TongYiImagesModelTest { + + private final ImageSynthesis imageApi = new ImageSynthesis(); + private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi); + + static { + Constants.apiKey = "sk-Zsd81gZYg7"; + } + + @Test + @Disabled + public void imageCallTest() { + // 准备参数 + ImageOptions options = OpenAiImageOptions.builder() + .withModel(ImageSynthesis.Models.WANX_V1) + .withHeight(256).withWidth(256) + .build(); + ImagePrompt prompt = new ImagePrompt("中国长城!", options); + + // 方法调用 + ImageResponse response = imageModel.call(prompt); + // 打印结果 + System.out.println(response); + } + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java deleted file mode 100644 index 7f44873b5..000000000 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java +++ /dev/null @@ -1,39 +0,0 @@ -package cn.iocoder.yudao.framework.ai.image; - -import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; -import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam; -import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult; -import com.alibaba.dashscope.exception.NoApiKeyException; -import com.alibaba.dashscope.utils.Constants; -import com.alibaba.fastjson.JSON; -import org.junit.jupiter.api.Test; - -import java.util.Map; - -// TODO @fan:改成 TongYiImagesModel 哈 -/** - * 通义万象 - */ -public class TongYiImagesModelTests { - - @Test - public void imageCallTest() throws NoApiKeyException { - // 设置 api key - Constants.apiKey = "sk-Zsd81gZYg7"; - ImageSynthesisParam param = - ImageSynthesisParam.builder() - .model(ImageSynthesis.Models.WANX_V1) - .n(4) - .size("1024*1024") - .prompt("雄鹰自由自在的在蓝天白云下飞翔") - .build(); - // 创建 ImageSynthesis - ImageSynthesis is = new ImageSynthesis(); - // 调用 call 生成 image - ImageSynthesisResult call = is.call(param); - System.err.println(JSON.toJSON(call)); - for (Map result : call.getOutput().getResults()) { - System.err.println("地址: " + result.get("url")); - } - } -}