mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2025-01-18 19:20:05 +08:00
Merge remote-tracking branch 'origin/master-jdk21-ai' into master-jdk21-ai
This commit is contained in:
commit
2ae90b9edc
@ -1,7 +1,7 @@
|
||||
{
|
||||
"local": {
|
||||
"baseUrl": "http://127.0.0.1:48080/admin-api",
|
||||
"token": "Bearer 1c2ce60de96a4fb0bf5bea9604099a3d",
|
||||
"token": "test1",
|
||||
"adminTenentId": "1",
|
||||
|
||||
"appApi": "http://127.0.0.1:48080/app-api",
|
||||
|
@ -39,6 +39,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
|
||||
除此之外不要任何解释性语句。
|
||||
""");
|
||||
|
||||
// TODO @xin:这个 role 是不是删除掉好点哈。= = 目前主要是没做角色枚举。这里多了 role 反倒容易误解哈
|
||||
/**
|
||||
* 角色
|
||||
*/
|
||||
|
@ -60,9 +60,5 @@
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
<artifactId>yudao-spring-boot-starter-test</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
<artifactId>yudao-spring-boot-starter-excel</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
@ -12,8 +12,7 @@ import lombok.Data;
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈
|
||||
@TableName(value = "ai_mind_map", autoResultMap = true)
|
||||
@TableName(value = "ai_mind_map")
|
||||
@Data
|
||||
public class AiMindMapDO extends BaseDO {
|
||||
|
||||
@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO {
|
||||
|
||||
/**
|
||||
* 用户编号
|
||||
*
|
||||
* <p>
|
||||
* 关联 AdminUserDO 的 userId 字段
|
||||
*/
|
||||
private Long userId;
|
||||
|
@ -5,7 +5,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
/**
|
||||
* AI 音乐 Mapper
|
||||
* AI 思维导图 Mapper
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
|
@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
||||
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
||||
|
||||
// 3.2 创建 chat 需要的 Prompt
|
||||
// 3.2 构建 Prompt,并进行调用
|
||||
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||
|
||||
|
@ -31,6 +31,7 @@ import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
import org.springframework.ai.qianfan.QianFanImageOptions;
|
||||
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
@ -142,6 +143,11 @@ public class AiImageServiceImpl implements AiImageService {
|
||||
.withModel(draw.getModel()).withN(1)
|
||||
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
||||
.build();
|
||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
|
||||
return QianFanImageOptions.builder()
|
||||
.withModel(draw.getModel()).withN(1)
|
||||
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
||||
.build();
|
||||
}
|
||||
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.service.mindmap;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
@ -31,13 +32,12 @@ import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
|
||||
/**
|
||||
* AI 写作 Service 实现类
|
||||
* AI 思维导图 Service 实现类
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
@ -57,38 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
||||
|
||||
@Override
|
||||
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
|
||||
// 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
|
||||
AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
|
||||
AiChatModelDO model;
|
||||
String systemMessage;
|
||||
if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
|
||||
model = chatModalService.getChatModel(mindMapRole.getModelId());
|
||||
systemMessage = mindMapRole.getSystemMessage();
|
||||
} else {
|
||||
model = chatModalService.getRequiredDefaultChatModel();
|
||||
systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
|
||||
}
|
||||
|
||||
// 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型
|
||||
AiChatRoleDO role = CollUtil.getFirst(
|
||||
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
|
||||
// 1.1 获取脑图执行模型
|
||||
AiChatModelDO model = getModel(role);
|
||||
// 1.2 获取角色设定消息
|
||||
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
|
||||
? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
|
||||
// 1.3 校验平台
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
||||
|
||||
// 2 插入思维导图信息
|
||||
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||
// 2. 插入思维导图信息
|
||||
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
|
||||
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||
mindMapMapper.insert(mindMapDO);
|
||||
|
||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||
// 3.1 角色设定
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
if (StrUtil.isNotBlank(systemMessage)) {
|
||||
chatMessages.add(new SystemMessage(systemMessage));
|
||||
}
|
||||
// 3.2 用户输入
|
||||
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
|
||||
// 3.3 构建提示词
|
||||
Prompt prompt = new Prompt(chatMessages, chatOptions);
|
||||
|
||||
// 3.1 构建 Prompt,并进行调用
|
||||
Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
|
||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||
// 3.4 流式返回
|
||||
|
||||
// 3.2 流式返回
|
||||
StringBuffer contentBuffer = new StringBuffer();
|
||||
return streamResponse.map(chunk -> {
|
||||
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
||||
@ -109,4 +99,36 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
||||
|
||||
}
|
||||
|
||||
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
|
||||
// 1. 构建 message 列表
|
||||
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
||||
// 2. 构建 options 对象
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||
return new Prompt(chatMessages, options);
|
||||
}
|
||||
|
||||
private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
// 1. 角色设定
|
||||
if (StrUtil.isNotBlank(systemMessage)) {
|
||||
chatMessages.add(new SystemMessage(systemMessage));
|
||||
}
|
||||
// 2. 用户输入
|
||||
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
|
||||
return chatMessages;
|
||||
}
|
||||
|
||||
private AiChatModelDO getModel(AiChatRoleDO role) {
|
||||
AiChatModelDO model = null;
|
||||
if (role != null && role.getModelId() != null) {
|
||||
model = chatModalService.getChatModel(role.getModelId());
|
||||
}
|
||||
if (model != null) {
|
||||
model = chatModalService.getRequiredDefaultChatModel();
|
||||
}
|
||||
Assert.notNull(model, "[AI] 获取不到模型");
|
||||
return model;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cn.iocoder.yudao.module.ai.service.write;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
@ -67,19 +68,15 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||
|
||||
@Override
|
||||
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
||||
// 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
|
||||
AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
|
||||
// TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。
|
||||
AiChatModelDO model;
|
||||
String systemMessage;
|
||||
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
|
||||
model = chatModalService.getChatModel(writeRole.getModelId());
|
||||
systemMessage = writeRole.getSystemMessage();
|
||||
} else {
|
||||
model = chatModalService.getRequiredDefaultChatModel();
|
||||
systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
|
||||
}
|
||||
// 1.2 校验平台
|
||||
// 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型
|
||||
AiChatRoleDO writeRole = CollUtil.getFirst(
|
||||
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
|
||||
// 1.1 获取写作执行模型
|
||||
AiChatModelDO model = getModel(writeRole);
|
||||
// 1.2 获取角色设定消息
|
||||
String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
|
||||
? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
|
||||
// 1.3 校验平台
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
||||
|
||||
@ -88,21 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
|
||||
writeMapper.insert(writeDO);
|
||||
|
||||
// 3. 调用大模型,写作生成
|
||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||
// 3.1 角色设定
|
||||
// TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
if (StrUtil.isNotBlank(systemMessage)) {
|
||||
chatMessages.add(new SystemMessage(systemMessage));
|
||||
}
|
||||
// 3.2 用户输入
|
||||
chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
|
||||
// 3.3 构建提示词
|
||||
Prompt prompt = new Prompt(chatMessages, chatOptions);
|
||||
// 3.1 构建 Prompt,并进行调用
|
||||
Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
|
||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||
|
||||
// 4. 流式返回
|
||||
// 3.2 流式返回
|
||||
StringBuffer contentBuffer = new StringBuffer();
|
||||
return streamResponse.map(chunk -> {
|
||||
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
||||
@ -122,7 +109,39 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
||||
}
|
||||
|
||||
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
||||
private AiChatModelDO getModel(AiChatRoleDO writeRole) {
|
||||
AiChatModelDO model = null;
|
||||
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
|
||||
model = chatModalService.getChatModel(writeRole.getModelId());
|
||||
}
|
||||
if (Objects.isNull(model)) {
|
||||
model = chatModalService.getRequiredDefaultChatModel();
|
||||
}
|
||||
Assert.notNull(model, "[AI] 获取不到模型");
|
||||
return model;
|
||||
}
|
||||
|
||||
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
|
||||
// 1. 构建 message 列表
|
||||
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
||||
// 2. 构建 options 对象
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||
return new Prompt(chatMessages, options);
|
||||
}
|
||||
|
||||
private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
if (StrUtil.isNotBlank(systemMessage)) {
|
||||
// 1.1 角色设定
|
||||
chatMessages.add(new SystemMessage(systemMessage));
|
||||
}
|
||||
// 1.2 用户输入
|
||||
chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
|
||||
return chatMessages;
|
||||
}
|
||||
|
||||
private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
|
||||
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
|
||||
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
|
||||
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
|
||||
|
@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
|
||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
|
||||
@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return SpringUtil.getBean(TongYiImagesModel.class);
|
||||
case YI_YAN:
|
||||
return SpringUtil.getBean(QianFanImageModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiImageModel.class);
|
||||
case STABLE_DIFFUSION:
|
||||
@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return buildTongYiImagesModel(apiKey);
|
||||
case YI_YAN:
|
||||
return buildQianFanImageModel(apiKey);
|
||||
case OPENAI:
|
||||
return buildOpenAiImageModel(apiKey, url);
|
||||
case STABLE_DIFFUSION:
|
||||
return buildStabilityAiImageModel(apiKey, url);
|
||||
case TONG_YI:
|
||||
return SpringUtil.getBean(TongYiImagesModel.class);
|
||||
case YI_YAN:
|
||||
return buildQianFanImageModel(apiKey);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
|
||||
}
|
||||
|
||||
private static TongYiImagesModel buildTongYiImagesModel(String key) {
|
||||
ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
|
||||
TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
|
||||
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
|
||||
connectionProperties.setApiKey(key);
|
||||
return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
||||
*/
|
||||
@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return new QianFanChatModel(qianFanApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
||||
*/
|
||||
private QianFanImageModel buildQianFanImageModel(String key) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
|
||||
String appKey = keys.get(0);
|
||||
String secretKey = keys.get(1);
|
||||
QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
|
||||
return new QianFanImageModel(qianFanApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
|
||||
*/
|
||||
@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return new StabilityAiImageModel(stabilityAiApi);
|
||||
}
|
||||
|
||||
private QianFanImageModel buildQianFanImageModel(String key) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1)));
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ public class OpenAiImageModelTests {
|
||||
"https://api.holdai.top",
|
||||
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
|
||||
RestClient.builder());
|
||||
private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi);
|
||||
private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
@ -34,7 +34,7 @@ public class OpenAiImageModelTests {
|
||||
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageClient.call(prompt);
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
@ -1,48 +1,42 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptionsBuilder;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.qianfan.QianFanImageModel;
|
||||
import org.springframework.ai.qianfan.QianFanImageOptions;
|
||||
import org.springframework.ai.qianfan.api.QianFanApi;
|
||||
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
||||
|
||||
import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage;
|
||||
|
||||
/**
|
||||
* 百度千帆 image
|
||||
* {@link QianFanImageModel} 集成测试类
|
||||
*/
|
||||
public class QianFanImageTests {
|
||||
|
||||
@Test
|
||||
public void callTest() {
|
||||
// todo @芋艿 千帆sdk有个错误,暂时没找到问题
|
||||
QianFanImageApi qianFanImageApi = new QianFanImageApi(
|
||||
"ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
|
||||
QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi);
|
||||
private final QianFanImageApi imageApi = new QianFanImageApi(
|
||||
"qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
|
||||
private final QianFanImageModel imageModel = new QianFanImageModel(imageApi);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
// 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576
|
||||
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
|
||||
.withWidth(512)
|
||||
.withHeight(512)
|
||||
.withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
|
||||
.withWidth(1024).withHeight(1024)
|
||||
.withN(1)
|
||||
.build();
|
||||
ImagePrompt imagePrompt = new ImagePrompt("薄涂炫酷少女头像,田野花朵盛开", imageOptions);
|
||||
ImageResponse call = qianFanImageModel.call(imagePrompt);
|
||||
System.err.println(JsonUtils.toJsonString(call));
|
||||
}
|
||||
ImagePrompt prompt = new ImagePrompt("good", imageOptions);
|
||||
|
||||
@Test
|
||||
public void call2Test() {
|
||||
// 官方测试 test https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java
|
||||
var options = ImageOptionsBuilder.builder().withHeight(1024).withWidth(1024).build();
|
||||
var instructions = "薄涂炫酷少女头像,田野花朵盛开";
|
||||
|
||||
ImagePrompt imagePrompt = new ImagePrompt(instructions, options);
|
||||
|
||||
QianFanImageApi qianFanImageApi = new QianFanImageApi(
|
||||
"ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
|
||||
QianFanImageModel imageModel = new QianFanImageModel(qianFanImageApi);
|
||||
ImageResponse imageResponse = imageModel.call(imagePrompt);
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
String b64Json = response.getResult().getOutput().getB64Json();
|
||||
System.out.println(response);
|
||||
viewImage(b64Json);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ public class StabilityAiImageModelTests {
|
||||
|
||||
private final StabilityAiApi imageApi = new StabilityAiApi(
|
||||
"sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx");
|
||||
private final StabilityAiImageModel imageClient = new StabilityAiImageModel(imageApi);
|
||||
private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
@ -37,7 +37,7 @@ public class StabilityAiImageModelTests {
|
||||
ImagePrompt prompt = new ImagePrompt("great wall", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageClient.call(prompt);
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
String b64Json = response.getResult().getOutput().getB64Json();
|
||||
System.out.println(response);
|
||||
|
@ -0,0 +1,43 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import com.alibaba.dashscope.utils.Constants;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
|
||||
/**
|
||||
* {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class TongYiImagesModelTest {
|
||||
|
||||
private final ImageSynthesis imageApi = new ImageSynthesis();
|
||||
private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
|
||||
|
||||
static {
|
||||
Constants.apiKey = "sk-Zsd81gZYg7";
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void imageCallTest() {
|
||||
// 准备参数
|
||||
ImageOptions options = OpenAiImageOptions.builder()
|
||||
.withModel(ImageSynthesis.Models.WANX_V1)
|
||||
.withHeight(256).withWidth(256)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
@ -1,39 +0,0 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import com.alibaba.dashscope.utils.Constants;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
// TODO @fan:改成 TongYiImagesModel 哈
|
||||
/**
|
||||
* 通义万象
|
||||
*/
|
||||
public class TongYiImagesModelTests {
|
||||
|
||||
@Test
|
||||
public void imageCallTest() throws NoApiKeyException {
|
||||
// 设置 api key
|
||||
Constants.apiKey = "sk-Zsd81gZYg7";
|
||||
ImageSynthesisParam param =
|
||||
ImageSynthesisParam.builder()
|
||||
.model(ImageSynthesis.Models.WANX_V1)
|
||||
.n(4)
|
||||
.size("1024*1024")
|
||||
.prompt("雄鹰自由自在的在蓝天白云下飞翔")
|
||||
.build();
|
||||
// 创建 ImageSynthesis
|
||||
ImageSynthesis is = new ImageSynthesis();
|
||||
// 调用 call 生成 image
|
||||
ImageSynthesisResult call = is.call(param);
|
||||
System.err.println(JSON.toJSON(call));
|
||||
for (Map<String, String> result : call.getOutput().getResults()) {
|
||||
System.err.println("地址: " + result.get("url"));
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user