Merge remote-tracking branch 'origin/master-jdk21-ai' into master-jdk21-ai

This commit is contained in:
cherishsince 2024-07-12 15:03:41 +08:00
commit 2ae90b9edc
15 changed files with 208 additions and 144 deletions

View File

@ -1,7 +1,7 @@
{
"local": {
"baseUrl": "http://127.0.0.1:48080/admin-api",
"token": "Bearer 1c2ce60de96a4fb0bf5bea9604099a3d",
"token": "test1",
"adminTenentId": "1",
"appApi": "http://127.0.0.1:48080/app-api",

View File

@ -39,6 +39,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
除此之外不要任何解释性语句
""");
// TODO @xin这个 role 是不是删除掉好点哈= = 目前主要是没做角色枚举这里多了 role 反倒容易误解哈
/**
* 角色
*/

View File

@ -60,9 +60,5 @@
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-test</artifactId>
</dependency>
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-excel</artifactId>
</dependency>
</dependencies>
</project>

View File

@ -12,8 +12,7 @@ import lombok.Data;
*
* @author xiaoxin
*/
// TODO @xin如果没 typehandler 的需求autoResultMap 可以去掉哈
@TableName(value = "ai_mind_map", autoResultMap = true)
@TableName(value = "ai_mind_map")
@Data
public class AiMindMapDO extends BaseDO {
@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO {
/**
* 用户编号
*
* <p>
* 关联 AdminUserDO userId 字段
*/
private Long userId;

View File

@ -5,7 +5,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import org.apache.ibatis.annotations.Mapper;
/**
* AI 音乐 Mapper
* AI 思维导图 Mapper
*
* @author xiaoxin
*/

View File

@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 创建 chat 需要的 Prompt
// 3.2 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);

View File

@ -31,6 +31,7 @@ import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
@ -142,6 +143,11 @@ public class AiImageServiceImpl implements AiImageService {
.withModel(draw.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
return QianFanImageOptions.builder()
.withModel(draw.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build();
}
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
}

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@ -31,13 +32,12 @@ import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
/**
* AI 写作 Service 实现类
* AI 思维导图 Service 实现类
*
* @author xiaoxin
*/
@ -57,38 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
@Override
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取脑图模型 尝试获取思维导图助手角色如果没有则使用默认模型
AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
AiChatModelDO model;
String systemMessage;
if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
model = chatModalService.getChatModel(mindMapRole.getModelId());
systemMessage = mindMapRole.getSystemMessage();
} else {
model = chatModalService.getRequiredDefaultChatModel();
systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
}
// 1. 获取脑图模型尝试获取思维导图助手角色如果没有则使用默认模型
AiChatRoleDO role = CollUtil.getFirst(
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
// 1.1 获取脑图执行模型
AiChatModelDO model = getModel(role);
// 1.2 获取角色设定消息
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
// 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 2 插入思维导图信息
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
// 2. 插入思维导图信息
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
mindMapMapper.insert(mindMapDO);
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 3.1 角色设定
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 3.2 用户输入
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
// 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
// 3.1 构建 Prompt并进行调用
Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.4 流式返回
// 3.2 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -109,4 +99,36 @@ public class AiMindMapServiceImpl implements AiMindMapService {
}
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
// 1. 构建 message 列表
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
return new Prompt(chatMessages, options);
}
private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
List<Message> chatMessages = new ArrayList<>();
// 1. 角色设定
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 2. 用户输入
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
return chatMessages;
}
private AiChatModelDO getModel(AiChatRoleDO role) {
AiChatModelDO model = null;
if (role != null && role.getModelId() != null) {
model = chatModalService.getChatModel(role.getModelId());
}
if (model != null) {
model = chatModalService.getRequiredDefaultChatModel();
}
Assert.notNull(model, "[AI] 获取不到模型");
return model;
}
}

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@ -67,19 +68,15 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型
AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
// TODO @xin如果有 writeRole但是没 modeId是不是也可以用 systemMessage 建议的写法是先通过 modelId 获取 model如果 model == null chatModalService.getRequiredDefaultChatModel()如果还是 null则抛出异常然后systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理
AiChatModelDO model;
String systemMessage;
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
model = chatModalService.getChatModel(writeRole.getModelId());
systemMessage = writeRole.getSystemMessage();
} else {
model = chatModalService.getRequiredDefaultChatModel();
systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
}
// 1.2 校验平台
// 1 获取写作模型尝试获取写作助手角色没有则使用默认模型
AiChatRoleDO writeRole = CollUtil.getFirst(
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
// 1.1 获取写作执行模型
AiChatModelDO model = getModel(writeRole);
// 1.2 获取角色设定消息
String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
// 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
@ -88,21 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService {
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
writeMapper.insert(writeDO);
// 3. 调用大模型写作生成
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 3.1 角色设定
// TODO @xin要不把 90 97 这部分合并到一个方法里目的是让这个方法的主干更明确
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 3.2 用户输入
chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
// 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
// 3.1 构建 Prompt并进行调用
Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 4. 流式返回
// 3.2 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -122,7 +109,39 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
}
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
private AiChatModelDO getModel(AiChatRoleDO writeRole) {
AiChatModelDO model = null;
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
model = chatModalService.getChatModel(writeRole.getModelId());
}
if (Objects.isNull(model)) {
model = chatModalService.getRequiredDefaultChatModel();
}
Assert.notNull(model, "[AI] 获取不到模型");
return model;
}
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
// 1. 构建 message 列表
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
return new Prompt(chatMessages, options);
}
private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
// 1.1 角色设定
chatMessages.add(new SystemMessage(systemMessage));
}
// 1.2 用户输入
chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
return chatMessages;
}
private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());

View File

@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return SpringUtil.getBean(TongYiImagesModel.class);
case YI_YAN:
return SpringUtil.getBean(QianFanImageModel.class);
case OPENAI:
return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION:
@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return buildTongYiImagesModel(apiKey);
case YI_YAN:
return buildQianFanImageModel(apiKey);
case OPENAI:
return buildOpenAiImageModel(apiKey, url);
case STABLE_DIFFUSION:
return buildStabilityAiImageModel(apiKey, url);
case TONG_YI:
return SpringUtil.getBean(TongYiImagesModel.class);
case YI_YAN:
return buildQianFanImageModel(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
}
private static TongYiImagesModel buildTongYiImagesModel(String key) {
ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
connectionProperties.setApiKey(key);
return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
}
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new QianFanChatModel(qianFanApi);
}
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private QianFanImageModel buildQianFanImageModel(String key) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
return new QianFanImageModel(qianFanApi);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
*/
@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new StabilityAiImageModel(stabilityAiApi);
}
private QianFanImageModel buildQianFanImageModel(String key) {
List<String> keys = StrUtil.split(key, '|');
return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1)));
}
}

View File

@ -21,7 +21,7 @@ public class OpenAiImageModelTests {
"https://api.holdai.top",
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
RestClient.builder());
private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi);
private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi);
@Test
@Disabled
@ -34,7 +34,7 @@ public class OpenAiImageModelTests {
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
// 方法调用
ImageResponse response = imageClient.call(prompt);
ImageResponse response = imageModel.call(prompt);
// 打印结果
System.out.println(response);
}

View File

@ -1,48 +1,42 @@
package cn.iocoder.yudao.framework.ai.image;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.qianfan.QianFanImageModel;
import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.qianfan.api.QianFanImageApi;
import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage;
/**
* 百度千帆 image
* {@link QianFanImageModel} 集成测试类
*/
public class QianFanImageTests {
@Test
public void callTest() {
// todo @芋艿 千帆sdk有个错误暂时没找到问题
QianFanImageApi qianFanImageApi = new QianFanImageApi(
"ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi);
private final QianFanImageApi imageApi = new QianFanImageApi(
"qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
private final QianFanImageModel imageModel = new QianFanImageModel(imageApi);
@Test
@Disabled
public void testCall() {
// 准备参数
// 只支持 1024x1024768x768768x10241024x768576x10241024x576
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
.withWidth(512)
.withHeight(512)
.withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
.withWidth(1024).withHeight(1024)
.withN(1)
.build();
ImagePrompt imagePrompt = new ImagePrompt("薄涂炫酷少女头像,田野花朵盛开", imageOptions);
ImageResponse call = qianFanImageModel.call(imagePrompt);
System.err.println(JsonUtils.toJsonString(call));
}
ImagePrompt prompt = new ImagePrompt("good", imageOptions);
@Test
public void call2Test() {
// 官方测试 test https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java
var options = ImageOptionsBuilder.builder().withHeight(1024).withWidth(1024).build();
var instructions = "薄涂炫酷少女头像,田野花朵盛开";
ImagePrompt imagePrompt = new ImagePrompt(instructions, options);
QianFanImageApi qianFanImageApi = new QianFanImageApi(
"ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
QianFanImageModel imageModel = new QianFanImageModel(qianFanImageApi);
ImageResponse imageResponse = imageModel.call(imagePrompt);
// 方法调用
ImageResponse response = imageModel.call(prompt);
// 打印结果
String b64Json = response.getResult().getOutput().getB64Json();
System.out.println(response);
viewImage(b64Json);
}
}

View File

@ -24,7 +24,7 @@ public class StabilityAiImageModelTests {
private final StabilityAiApi imageApi = new StabilityAiApi(
"sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx");
private final StabilityAiImageModel imageClient = new StabilityAiImageModel(imageApi);
private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi);
@Test
@Disabled
@ -37,7 +37,7 @@ public class StabilityAiImageModelTests {
ImagePrompt prompt = new ImagePrompt("great wall", options);
// 方法调用
ImageResponse response = imageClient.call(prompt);
ImageResponse response = imageModel.call(prompt);
// 打印结果
String b64Json = response.getResult().getOutput().getB64Json();
System.out.println(response);

View File

@ -0,0 +1,43 @@
package cn.iocoder.yudao.framework.ai.image;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.utils.Constants;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
/**
* {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
*
* @author fansili
*/
public class TongYiImagesModelTest {
private final ImageSynthesis imageApi = new ImageSynthesis();
private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
static {
Constants.apiKey = "sk-Zsd81gZYg7";
}
@Test
@Disabled
public void imageCallTest() {
// 准备参数
ImageOptions options = OpenAiImageOptions.builder()
.withModel(ImageSynthesis.Models.WANX_V1)
.withHeight(256).withWidth(256)
.build();
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
// 方法调用
ImageResponse response = imageModel.call(prompt);
// 打印结果
System.out.println(response);
}
}

View File

@ -1,39 +0,0 @@
package cn.iocoder.yudao.framework.ai.image;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.utils.Constants;
import com.alibaba.fastjson.JSON;
import org.junit.jupiter.api.Test;
import java.util.Map;
// TODO @fan改成 TongYiImagesModel
/**
* 通义万象
*/
public class TongYiImagesModelTests {
@Test
public void imageCallTest() throws NoApiKeyException {
// 设置 api key
Constants.apiKey = "sk-Zsd81gZYg7";
ImageSynthesisParam param =
ImageSynthesisParam.builder()
.model(ImageSynthesis.Models.WANX_V1)
.n(4)
.size("1024*1024")
.prompt("雄鹰自由自在的在蓝天白云下飞翔")
.build();
// 创建 ImageSynthesis
ImageSynthesis is = new ImageSynthesis();
// 调用 call 生成 image
ImageSynthesisResult call = is.call(param);
System.err.println(JSON.toJSON(call));
for (Map<String, String> result : call.getOutput().getResults()) {
System.err.println("地址: " + result.get("url"));
}
}
}