Merge remote-tracking branch 'origin/master-jdk21-ai' into master-jdk21-ai

This commit is contained in:
cherishsince 2024-07-12 15:03:41 +08:00
commit 2ae90b9edc
15 changed files with 208 additions and 144 deletions

View File

@ -1,7 +1,7 @@
{ {
"local": { "local": {
"baseUrl": "http://127.0.0.1:48080/admin-api", "baseUrl": "http://127.0.0.1:48080/admin-api",
"token": "Bearer 1c2ce60de96a4fb0bf5bea9604099a3d", "token": "test1",
"adminTenentId": "1", "adminTenentId": "1",
"appApi": "http://127.0.0.1:48080/app-api", "appApi": "http://127.0.0.1:48080/app-api",

View File

@ -39,6 +39,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
除此之外不要任何解释性语句 除此之外不要任何解释性语句
"""); """);
// TODO @xin这个 role 是不是删除掉好点哈= = 目前主要是没做角色枚举这里多了 role 反倒容易误解哈
/** /**
* 角色 * 角色
*/ */

View File

@ -60,9 +60,5 @@
<groupId>cn.iocoder.boot</groupId> <groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-test</artifactId> <artifactId>yudao-spring-boot-starter-test</artifactId>
</dependency> </dependency>
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-excel</artifactId>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -12,8 +12,7 @@ import lombok.Data;
* *
* @author xiaoxin * @author xiaoxin
*/ */
// TODO @xin如果没 typehandler 的需求autoResultMap 可以去掉哈 @TableName(value = "ai_mind_map")
@TableName(value = "ai_mind_map", autoResultMap = true)
@Data @Data
public class AiMindMapDO extends BaseDO { public class AiMindMapDO extends BaseDO {
@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO {
/** /**
* 用户编号 * 用户编号
* * <p>
* 关联 AdminUserDO userId 字段 * 关联 AdminUserDO userId 字段
*/ */
private Long userId; private Long userId;

View File

@ -5,7 +5,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
/** /**
* AI 音乐 Mapper * AI 思维导图 Mapper
* *
* @author xiaoxin * @author xiaoxin
*/ */

View File

@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 创建 chat 需要的 Prompt // 3.2 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);

View File

@ -31,6 +31,7 @@ import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -142,6 +143,11 @@ public class AiImageServiceImpl implements AiImageService {
.withModel(draw.getModel()).withN(1) .withModel(draw.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth()) .withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build(); .build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
return QianFanImageOptions.builder()
.withModel(draw.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build();
} }
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
} }

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.mindmap; package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@ -31,13 +32,12 @@ import reactor.core.publisher.Flux;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
/** /**
* AI 写作 Service 实现类 * AI 思维导图 Service 实现类
* *
* @author xiaoxin * @author xiaoxin
*/ */
@ -57,38 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
@Override @Override
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取脑图模型 尝试获取思维导图助手角色如果没有则使用默认模型 // 1. 获取脑图模型尝试获取思维导图助手角色如果没有则使用默认模型
AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); AiChatRoleDO role = CollUtil.getFirst(
AiChatModelDO model; chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
String systemMessage; // 1.1 获取脑图执行模型
if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) { AiChatModelDO model = getModel(role);
model = chatModalService.getChatModel(mindMapRole.getModelId()); // 1.2 获取角色设定消息
systemMessage = mindMapRole.getSystemMessage(); String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
} else { ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
model = chatModalService.getRequiredDefaultChatModel(); // 1.3 校验平台
systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
}
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 2 插入思维导图信息 // 2. 插入思维导图信息
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
mindMapMapper.insert(mindMapDO); mindMapMapper.insert(mindMapDO);
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); // 3.1 构建 Prompt并进行调用
// 3.1 角色设定 Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 3.2 用户输入
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
// 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.4 流式返回
// 3.2 流式返回
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -109,4 +99,36 @@ public class AiMindMapServiceImpl implements AiMindMapService {
} }
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
// 1. 构建 message 列表
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
return new Prompt(chatMessages, options);
}
private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
List<Message> chatMessages = new ArrayList<>();
// 1. 角色设定
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 2. 用户输入
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
return chatMessages;
}
private AiChatModelDO getModel(AiChatRoleDO role) {
AiChatModelDO model = null;
if (role != null && role.getModelId() != null) {
model = chatModalService.getChatModel(role.getModelId());
}
if (model != null) {
model = chatModalService.getRequiredDefaultChatModel();
}
Assert.notNull(model, "[AI] 获取不到模型");
return model;
}
} }

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.write; package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@ -67,19 +68,15 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override @Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型 // 1 获取写作模型尝试获取写作助手角色没有则使用默认模型
AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); AiChatRoleDO writeRole = CollUtil.getFirst(
// TODO @xin如果有 writeRole但是没 modeId是不是也可以用 systemMessage 建议的写法是先通过 modelId 获取 model如果 model == null chatModalService.getRequiredDefaultChatModel()如果还是 null则抛出异常然后systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理 chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
AiChatModelDO model; // 1.1 获取写作执行模型
String systemMessage; AiChatModelDO model = getModel(writeRole);
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { // 1.2 获取角色设定消息
model = chatModalService.getChatModel(writeRole.getModelId()); String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
systemMessage = writeRole.getSystemMessage(); ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
} else { // 1.3 校验平台
model = chatModalService.getRequiredDefaultChatModel();
systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
}
// 1.2 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
@ -88,21 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService {
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel())); write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
writeMapper.insert(writeDO); writeMapper.insert(writeDO);
// 3. 调用大模型写作生成 // 3.1 构建 Prompt并进行调用
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
// 3.1 角色设定
// TODO @xin要不把 90 97 这部分合并到一个方法里目的是让这个方法的主干更明确
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
chatMessages.add(new SystemMessage(systemMessage));
}
// 3.2 用户输入
chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
// 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 4. 流式返回 // 3.2 流式返回
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -122,7 +109,39 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
} }
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { private AiChatModelDO getModel(AiChatRoleDO writeRole) {
AiChatModelDO model = null;
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
model = chatModalService.getChatModel(writeRole.getModelId());
}
if (Objects.isNull(model)) {
model = chatModalService.getRequiredDefaultChatModel();
}
Assert.notNull(model, "[AI] 获取不到模型");
return model;
}
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
// 1. 构建 message 列表
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
return new Prompt(chatMessages, options);
}
private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
List<Message> chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
// 1.1 角色设定
chatMessages.add(new SystemMessage(systemMessage));
}
// 1.2 用户输入
chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
return chatMessages;
}
private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());

View File

@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation; import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties; import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties; import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ImageModel getDefaultImageModel(AiPlatformEnum platform) { public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case TONG_YI:
return SpringUtil.getBean(TongYiImagesModel.class);
case YI_YAN:
return SpringUtil.getBean(QianFanImageModel.class);
case OPENAI: case OPENAI:
return SpringUtil.getBean(OpenAiImageModel.class); return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION: case STABLE_DIFFUSION:
@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) { public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration //noinspection EnhancedSwitchMigration
switch (platform) { switch (platform) {
case TONG_YI:
return buildTongYiImagesModel(apiKey);
case YI_YAN:
return buildQianFanImageModel(apiKey);
case OPENAI: case OPENAI:
return buildOpenAiImageModel(apiKey, url); return buildOpenAiImageModel(apiKey, url);
case STABLE_DIFFUSION: case STABLE_DIFFUSION:
return buildStabilityAiImageModel(apiKey, url); return buildStabilityAiImageModel(apiKey, url);
case TONG_YI:
return SpringUtil.getBean(TongYiImagesModel.class);
case YI_YAN:
return buildQianFanImageModel(apiKey);
default: default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
} }
@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
} }
private static TongYiImagesModel buildTongYiImagesModel(String key) {
ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
connectionProperties.setApiKey(key);
return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
}
/** /**
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/ */
@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new QianFanChatModel(qianFanApi); return new QianFanChatModel(qianFanApi);
} }
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private QianFanImageModel buildQianFanImageModel(String key) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
return new QianFanImageModel(qianFanApi);
}
/** /**
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)} * 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
*/ */
@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new StabilityAiImageModel(stabilityAiApi); return new StabilityAiImageModel(stabilityAiApi);
} }
private QianFanImageModel buildQianFanImageModel(String key) {
List<String> keys = StrUtil.split(key, '|');
return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1)));
}
} }

View File

@ -21,7 +21,7 @@ public class OpenAiImageModelTests {
"https://api.holdai.top", "https://api.holdai.top",
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf", "sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
RestClient.builder()); RestClient.builder());
private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi); private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi);
@Test @Test
@Disabled @Disabled
@ -34,7 +34,7 @@ public class OpenAiImageModelTests {
ImagePrompt prompt = new ImagePrompt("中国长城!", options); ImagePrompt prompt = new ImagePrompt("中国长城!", options);
// 方法调用 // 方法调用
ImageResponse response = imageClient.call(prompt); ImageResponse response = imageModel.call(prompt);
// 打印结果 // 打印结果
System.out.println(response); System.out.println(response);
} }

View File

@ -1,48 +1,42 @@
package cn.iocoder.yudao.framework.ai.image; package cn.iocoder.yudao.framework.ai.image;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.qianfan.QianFanImageModel; import org.springframework.ai.qianfan.QianFanImageModel;
import org.springframework.ai.qianfan.QianFanImageOptions; import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.qianfan.api.QianFanImageApi; import org.springframework.ai.qianfan.api.QianFanImageApi;
import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage;
/** /**
* 百度千帆 image * {@link QianFanImageModel} 集成测试类
*/ */
public class QianFanImageTests { public class QianFanImageTests {
@Test private final QianFanImageApi imageApi = new QianFanImageApi(
public void callTest() { "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
// todo @芋艿 千帆sdk有个错误暂时没找到问题 private final QianFanImageModel imageModel = new QianFanImageModel(imageApi);
QianFanImageApi qianFanImageApi = new QianFanImageApi(
"ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi);
@Test
@Disabled
public void testCall() {
// 准备参数
// 只支持 1024x1024768x768768x10241024x768576x10241024x576
QianFanImageOptions imageOptions = QianFanImageOptions.builder() QianFanImageOptions imageOptions = QianFanImageOptions.builder()
.withWidth(512) .withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
.withHeight(512) .withWidth(1024).withHeight(1024)
.withN(1)
.build(); .build();
ImagePrompt imagePrompt = new ImagePrompt("薄涂炫酷少女头像,田野花朵盛开", imageOptions); ImagePrompt prompt = new ImagePrompt("good", imageOptions);
ImageResponse call = qianFanImageModel.call(imagePrompt);
System.err.println(JsonUtils.toJsonString(call));
}
@Test // 方法调用
public void call2Test() { ImageResponse response = imageModel.call(prompt);
// 官方测试 test https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java // 打印结果
var options = ImageOptionsBuilder.builder().withHeight(1024).withWidth(1024).build(); String b64Json = response.getResult().getOutput().getB64Json();
var instructions = "薄涂炫酷少女头像,田野花朵盛开"; System.out.println(response);
viewImage(b64Json);
ImagePrompt imagePrompt = new ImagePrompt(instructions, options);
QianFanImageApi qianFanImageApi = new QianFanImageApi(
"ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
QianFanImageModel imageModel = new QianFanImageModel(qianFanImageApi);
ImageResponse imageResponse = imageModel.call(imagePrompt);
} }
} }

View File

@ -24,7 +24,7 @@ public class StabilityAiImageModelTests {
private final StabilityAiApi imageApi = new StabilityAiApi( private final StabilityAiApi imageApi = new StabilityAiApi(
"sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx"); "sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx");
private final StabilityAiImageModel imageClient = new StabilityAiImageModel(imageApi); private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi);
@Test @Test
@Disabled @Disabled
@ -37,7 +37,7 @@ public class StabilityAiImageModelTests {
ImagePrompt prompt = new ImagePrompt("great wall", options); ImagePrompt prompt = new ImagePrompt("great wall", options);
// 方法调用 // 方法调用
ImageResponse response = imageClient.call(prompt); ImageResponse response = imageModel.call(prompt);
// 打印结果 // 打印结果
String b64Json = response.getResult().getOutput().getB64Json(); String b64Json = response.getResult().getOutput().getB64Json();
System.out.println(response); System.out.println(response);

View File

@ -0,0 +1,43 @@
package cn.iocoder.yudao.framework.ai.image;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.utils.Constants;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
/**
* {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
*
* @author fansili
*/
public class TongYiImagesModelTest {
private final ImageSynthesis imageApi = new ImageSynthesis();
private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
static {
Constants.apiKey = "sk-Zsd81gZYg7";
}
@Test
@Disabled
public void imageCallTest() {
// 准备参数
ImageOptions options = OpenAiImageOptions.builder()
.withModel(ImageSynthesis.Models.WANX_V1)
.withHeight(256).withWidth(256)
.build();
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
// 方法调用
ImageResponse response = imageModel.call(prompt);
// 打印结果
System.out.println(response);
}
}

View File

@ -1,39 +0,0 @@
package cn.iocoder.yudao.framework.ai.image;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.utils.Constants;
import com.alibaba.fastjson.JSON;
import org.junit.jupiter.api.Test;
import java.util.Map;
// TODO @fan改成 TongYiImagesModel
/**
* 通义万象
*/
public class TongYiImagesModelTests {
@Test
public void imageCallTest() throws NoApiKeyException {
// 设置 api key
Constants.apiKey = "sk-Zsd81gZYg7";
ImageSynthesisParam param =
ImageSynthesisParam.builder()
.model(ImageSynthesis.Models.WANX_V1)
.n(4)
.size("1024*1024")
.prompt("雄鹰自由自在的在蓝天白云下飞翔")
.build();
// 创建 ImageSynthesis
ImageSynthesis is = new ImageSynthesis();
// 调用 call 生成 image
ImageSynthesisResult call = is.call(param);
System.err.println(JSON.toJSON(call));
for (Map<String, String> result : call.getOutput().getResults()) {
System.err.println("地址: " + result.get("url"));
}
}
}