From 29e421432d828bad49a147136bec8a7116dcf04d Mon Sep 17 00:00:00 2001
From: xiaoxin <718949661@qq.com>
Date: Thu, 11 Jul 2024 10:14:59 +0800
Subject: [PATCH 1/5] =?UTF-8?q?=E3=80=90=E8=A7=A3=E5=86=B3todo=E3=80=91AI?=
=?UTF-8?q?=20=E5=86=99=E4=BD=9C=E3=80=81=E8=84=91=E5=9B=BE=EF=BC=9Amodel?=
=?UTF-8?q?=E3=80=81systemMessage=E8=8E=B7=E5=8F=96=E9=80=BB=E8=BE=91?=
=?UTF-8?q?=E8=B0=83=E6=95=B4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../dal/dataobject/mindmap/AiMindMapDO.java | 5 +-
.../service/mindmap/AiMindMapServiceImpl.java | 53 ++++++++++++------
.../ai/service/write/AiWriteServiceImpl.java | 56 ++++++++++++-------
3 files changed, 72 insertions(+), 42 deletions(-)
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
index 92222b590..0442a52d7 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
@@ -12,8 +12,7 @@ import lombok.Data;
*
* @author xiaoxin
*/
-// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈
-@TableName(value = "ai_mind_map", autoResultMap = true)
+@TableName(value = "ai_mind_map")
@Data
public class AiMindMapDO extends BaseDO {
@@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO {
/**
* 用户编号
- *
+ *
* 关联 AdminUserDO 的 userId 字段
*/
private Long userId;
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
index 7d96c70d2..7b49ee807 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
@@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@@ -57,33 +58,25 @@ public class AiMindMapServiceImpl implements AiMindMapService {
@Override
public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
- // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
+ // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
- AiChatModelDO model;
- String systemMessage;
- if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
- model = chatModalService.getChatModel(mindMapRole.getModelId());
- systemMessage = mindMapRole.getSystemMessage();
- } else {
- model = chatModalService.getRequiredDefaultChatModel();
- systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
- }
-
+ // 1.1 获取脑图执行模型
+ AiChatModelDO model = getModel(mindMapRole);
+ // 1.2 获取角色设定消息
+ String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage())
+ ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
+ // 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 2 插入思维导图信息
- AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
+ AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
+ mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
mindMapMapper.insert(mindMapDO);
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
// 3.1 角色设定
- List chatMessages = new ArrayList<>();
- if (StrUtil.isNotBlank(systemMessage)) {
- chatMessages.add(new SystemMessage(systemMessage));
- }
- // 3.2 用户输入
- chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
+ List chatMessages = buildMessages(generateReqVO, systemMessage);
// 3.3 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
@@ -109,4 +102,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
}
+ private static List buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
+ List chatMessages = new ArrayList<>();
+ if (StrUtil.isNotBlank(systemMessage)) {
+ // 1.1 角色设定
+ chatMessages.add(new SystemMessage(systemMessage));
+ }
+ // 1.2 用户输入
+ chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
+ return chatMessages;
+ }
+
+ // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去
+ private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) {
+ AiChatModelDO model = null;
+ if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) {
+ model = chatModalService.getChatModel(chatRoleDO.getModelId());
+ }
+ if (Objects.isNull(model)) {
+ model = chatModalService.getRequiredDefaultChatModel();
+ }
+ Assert.notNull(model, "[AI] 获取不到模型");
+ return model;
+ }
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
index b03a90ab7..4b583e3c1 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
@@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@@ -67,19 +68,14 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override
public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
- // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
+ // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型
AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
- // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。
- AiChatModelDO model;
- String systemMessage;
- if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
- model = chatModalService.getChatModel(writeRole.getModelId());
- systemMessage = writeRole.getSystemMessage();
- } else {
- model = chatModalService.getRequiredDefaultChatModel();
- systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
- }
- // 1.2 校验平台
+ // 1.1 获取写作执行模型
+ AiChatModelDO model = getModel(writeRole);
+ // 1.2 获取角色设定消息
+ String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
+ ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
+ // 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
@@ -90,16 +86,11 @@ public class AiWriteServiceImpl implements AiWriteService {
// 3. 调用大模型,写作生成
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
- // 3.1 角色设定
- // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确
- List chatMessages = new ArrayList<>();
- if (StrUtil.isNotBlank(systemMessage)) {
- chatMessages.add(new SystemMessage(systemMessage));
- }
- // 3.2 用户输入
- chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
- // 3.3 构建提示词
+ // 3.1 构建消息列表
+ List chatMessages = buildMessages(generateReqVO, systemMessage);
+ // 3.2 构建提示词
Prompt prompt = new Prompt(chatMessages, chatOptions);
+ // 3.3 流式调用
Flux streamResponse = chatModel.stream(prompt);
// 4. 流式返回
@@ -122,6 +113,29 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
}
+ private AiChatModelDO getModel(AiChatRoleDO writeRole) {
+ AiChatModelDO model = null;
+ if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
+ model = chatModalService.getChatModel(writeRole.getModelId());
+ }
+ if (Objects.isNull(model)) {
+ model = chatModalService.getRequiredDefaultChatModel();
+ }
+ Assert.notNull(model, "[AI] 获取不到模型");
+ return model;
+ }
+
+ private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
+ List chatMessages = new ArrayList<>();
+ if (StrUtil.isNotBlank(systemMessage)) {
+ // 1.1 角色设定
+ chatMessages.add(new SystemMessage(systemMessage));
+ }
+ // 1.2 用户输入
+ chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
+ return chatMessages;
+ }
+
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
From c6c003707eec3fc8c7793515e3c14c46383c81ce Mon Sep 17 00:00:00 2001
From: YunaiV
Date: Thu, 11 Jul 2024 21:37:45 +0800
Subject: [PATCH 2/5] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98?=
=?UTF-8?q?=E5=8C=96=E3=80=91AI=EF=BC=9A=E9=80=9A=E4=B9=89=E5=8D=83?=
=?UTF-8?q?=E9=97=AE=E7=9A=84=20tests=20=E7=B1=BB?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
script/idea/http-client.env.json | 2 +-
.../yudao/module/ai/enums/AiChatRoleEnum.java | 1 +
yudao-module-ai/yudao-module-ai-biz/pom.xml | 4 --
.../ai/dal/mysql/mindmap/AiMindMapMapper.java | 2 +-
.../ai/core/factory/AiModelFactoryImpl.java | 39 ++++++++++++++----
.../ai/image/OpenAiImageModelTests.java | 4 +-
.../framework/ai/image/QianFanImageTests.java | 5 ++-
.../ai/image/StabilityAiImageModelTests.java | 4 +-
.../ai/image/TongYiImagesModelTest.java | 41 +++++++++++++++++++
.../ai/image/TongYiImagesModelTests.java | 39 ------------------
10 files changed, 82 insertions(+), 59 deletions(-)
create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java
delete mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java
diff --git a/script/idea/http-client.env.json b/script/idea/http-client.env.json
index 17dd0d50d..4a4cb5221 100644
--- a/script/idea/http-client.env.json
+++ b/script/idea/http-client.env.json
@@ -1,7 +1,7 @@
{
"local": {
"baseUrl": "http://127.0.0.1:48080/admin-api",
- "token": "Bearer 1c2ce60de96a4fb0bf5bea9604099a3d",
+ "token": "test1",
"adminTenentId": "1",
"appApi": "http://127.0.0.1:48080/app-api",
diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java
index ad3641421..19cbc8f8f 100644
--- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java
+++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java
@@ -39,6 +39,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
除此之外不要任何解释性语句。
""");
+ // TODO @xin:这个 role 是不是删除掉好点哈。= = 目前主要是没做角色枚举。这里多了 role 反倒容易误解哈
/**
* 角色
*/
diff --git a/yudao-module-ai/yudao-module-ai-biz/pom.xml b/yudao-module-ai/yudao-module-ai-biz/pom.xml
index a537b3db7..7c529f118 100644
--- a/yudao-module-ai/yudao-module-ai-biz/pom.xml
+++ b/yudao-module-ai/yudao-module-ai-biz/pom.xml
@@ -60,9 +60,5 @@
cn.iocoder.boot
yudao-spring-boot-starter-test
-
- cn.iocoder.boot
- yudao-spring-boot-starter-excel
-
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java
index 54fa7235a..ff25e89ff 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java
@@ -5,7 +5,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import org.apache.ibatis.annotations.Mapper;
/**
- * AI 音乐 Mapper
+ * AI 思维导图 Mapper
*
* @author xiaoxin
*/
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
index fbf835707..66a32167c 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
@@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
+import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
+import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
+import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
@@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
+ case TONG_YI:
+ return SpringUtil.getBean(TongYiImagesModel.class);
+ case YI_YAN:
+ return SpringUtil.getBean(QianFanImageModel.class);
case OPENAI:
return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION:
@@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
switch (platform) {
+ case TONG_YI:
+ return buildTongYiImagesModel(apiKey);
+ case YI_YAN:
+ return buildQianFanImageModel(apiKey);
case OPENAI:
return buildOpenAiImageModel(apiKey, url);
case STABLE_DIFFUSION:
return buildStabilityAiImageModel(apiKey, url);
- case TONG_YI:
- return SpringUtil.getBean(TongYiImagesModel.class);
- case YI_YAN:
- return buildQianFanImageModel(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
}
+ private static TongYiImagesModel buildTongYiImagesModel(String key) {
+ ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
+ TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
+ TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
+ connectionProperties.setApiKey(key);
+ return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
+ }
+
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
@@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new QianFanChatModel(qianFanApi);
}
+ /**
+ * 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
+ */
+ private QianFanImageModel buildQianFanImageModel(String key) {
+ List keys = StrUtil.split(key, '|');
+ Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
+ String appKey = keys.get(0);
+ String secretKey = keys.get(1);
+ QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
+ return new QianFanImageModel(qianFanApi);
+ }
+
/**
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
*/
@@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new StabilityAiImageModel(stabilityAiApi);
}
- private QianFanImageModel buildQianFanImageModel(String key) {
- List keys = StrUtil.split(key, '|');
- return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1)));
- }
}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java
index 740978e60..c9b07d9ff 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java
@@ -21,7 +21,7 @@ public class OpenAiImageModelTests {
"https://api.holdai.top",
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
RestClient.builder());
- private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi);
+ private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi);
@Test
@Disabled
@@ -34,7 +34,7 @@ public class OpenAiImageModelTests {
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
// 方法调用
- ImageResponse response = imageClient.call(prompt);
+ ImageResponse response = imageModel.call(prompt);
// 打印结果
System.out.println(response);
}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java
index b8de6f486..04312bcbd 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java
@@ -7,7 +7,6 @@ import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.qianfan.QianFanImageModel;
import org.springframework.ai.qianfan.QianFanImageOptions;
-import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.qianfan.api.QianFanImageApi;
/**
@@ -19,7 +18,7 @@ public class QianFanImageTests {
public void callTest() {
// todo @芋艿 千帆sdk有个错误,暂时没找到问题
QianFanImageApi qianFanImageApi = new QianFanImageApi(
- "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
+ "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi);
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
@@ -45,4 +44,6 @@ public class QianFanImageTests {
ImageResponse imageResponse = imageModel.call(imagePrompt);
}
+
+
}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java
index cb7412821..7ee7e6044 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java
@@ -24,7 +24,7 @@ public class StabilityAiImageModelTests {
private final StabilityAiApi imageApi = new StabilityAiApi(
"sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx");
- private final StabilityAiImageModel imageClient = new StabilityAiImageModel(imageApi);
+ private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi);
@Test
@Disabled
@@ -37,7 +37,7 @@ public class StabilityAiImageModelTests {
ImagePrompt prompt = new ImagePrompt("great wall", options);
// 方法调用
- ImageResponse response = imageClient.call(prompt);
+ ImageResponse response = imageModel.call(prompt);
// 打印结果
String b64Json = response.getResult().getOutput().getB64Json();
System.out.println(response);
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java
new file mode 100644
index 000000000..0ed736cde
--- /dev/null
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java
@@ -0,0 +1,41 @@
+package cn.iocoder.yudao.framework.ai.image;
+
+import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
+import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
+import com.alibaba.dashscope.utils.Constants;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.image.ImageOptions;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.ai.openai.OpenAiImageOptions;
+
+/**
+ * {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
+ *
+ * @author fansili
+ */
+public class TongYiImagesModelTest {
+
+ private final ImageSynthesis imageApi = new ImageSynthesis();
+ private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
+
+ static {
+ Constants.apiKey = "sk-Zsd81gZYg7";
+ }
+
+ @Test
+ public void imageCallTest() {
+ // 准备参数
+ ImageOptions options = OpenAiImageOptions.builder()
+ .withModel(ImageSynthesis.Models.WANX_V1)
+ .withHeight(256).withWidth(256)
+ .build();
+ ImagePrompt prompt = new ImagePrompt("中国长城!", options);
+
+ // 方法调用
+ ImageResponse response = imageModel.call(prompt);
+ // 打印结果
+ System.out.println(response);
+ }
+
+}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java
deleted file mode 100644
index 7f44873b5..000000000
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java
+++ /dev/null
@@ -1,39 +0,0 @@
-package cn.iocoder.yudao.framework.ai.image;
-
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
-import com.alibaba.dashscope.exception.NoApiKeyException;
-import com.alibaba.dashscope.utils.Constants;
-import com.alibaba.fastjson.JSON;
-import org.junit.jupiter.api.Test;
-
-import java.util.Map;
-
-// TODO @fan:改成 TongYiImagesModel 哈
-/**
- * 通义万象
- */
-public class TongYiImagesModelTests {
-
- @Test
- public void imageCallTest() throws NoApiKeyException {
- // 设置 api key
- Constants.apiKey = "sk-Zsd81gZYg7";
- ImageSynthesisParam param =
- ImageSynthesisParam.builder()
- .model(ImageSynthesis.Models.WANX_V1)
- .n(4)
- .size("1024*1024")
- .prompt("雄鹰自由自在的在蓝天白云下飞翔")
- .build();
- // 创建 ImageSynthesis
- ImageSynthesis is = new ImageSynthesis();
- // 调用 call 生成 image
- ImageSynthesisResult call = is.call(param);
- System.err.println(JSON.toJSON(call));
- for (Map result : call.getOutput().getResults()) {
- System.err.println("地址: " + result.get("url"));
- }
- }
-}
From 18aeb072a6187b09b9451f6a93d0bb342f38ebd3 Mon Sep 17 00:00:00 2001
From: cherishsince
Date: Thu, 11 Jul 2024 21:46:09 +0800
Subject: [PATCH 3/5] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91buildIm?=
=?UTF-8?q?ageOptions=20=E6=94=AF=E6=8C=81=E5=8D=83=E5=B8=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../yudao/module/ai/service/image/AiImageServiceImpl.java | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java
index 7ea629e11..02c1ab334 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java
@@ -31,6 +31,7 @@ import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
+import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
@@ -142,6 +143,11 @@ public class AiImageServiceImpl implements AiImageService {
.withModel(draw.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build();
+ } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
+ return QianFanImageOptions.builder()
+ .withModel(draw.getModel()).withN(1)
+ .withHeight(draw.getHeight()).withWidth(draw.getWidth())
+ .build();
}
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
}
From 698b2b24aeee15dc27d7eea058cce7e968471608 Mon Sep 17 00:00:00 2001
From: YunaiV
Date: Thu, 11 Jul 2024 22:15:44 +0800
Subject: [PATCH 4/5] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98?=
=?UTF-8?q?=E5=8C=96=E3=80=91AI=EF=BC=9A=E6=96=87=E5=BF=83=E4=B8=80?=
=?UTF-8?q?=E8=A8=80=E7=9A=84=20tests=20=E7=B1=BB?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../framework/ai/image/QianFanImageTests.java | 53 ++++++++-----------
.../ai/image/TongYiImagesModelTest.java | 2 +
2 files changed, 25 insertions(+), 30 deletions(-)
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java
index 04312bcbd..22bf6614e 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java
@@ -1,49 +1,42 @@
package cn.iocoder.yudao.framework.ai.image;
-import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
-import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.qianfan.QianFanImageModel;
import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.qianfan.api.QianFanImageApi;
+import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage;
+
/**
- * 百度千帆 image
+ * {@link QianFanImageModel} 集成测试类
*/
public class QianFanImageTests {
- @Test
- public void callTest() {
- // todo @芋艿 千帆sdk有个错误,暂时没找到问题
- QianFanImageApi qianFanImageApi = new QianFanImageApi(
- "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
- QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi);
+ private final QianFanImageApi imageApi = new QianFanImageApi(
+ "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
+ private final QianFanImageModel imageModel = new QianFanImageModel(imageApi);
+ @Test
+ @Disabled
+ public void testCall() {
+ // 准备参数
+ // 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
- .withWidth(512)
- .withHeight(512)
+ .withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
+ .withWidth(1024).withHeight(1024)
+ .withN(1)
.build();
- ImagePrompt imagePrompt = new ImagePrompt("薄涂炫酷少女头像,田野花朵盛开", imageOptions);
- ImageResponse call = qianFanImageModel.call(imagePrompt);
- System.err.println(JsonUtils.toJsonString(call));
+ ImagePrompt prompt = new ImagePrompt("good", imageOptions);
+
+ // 方法调用
+ ImageResponse response = imageModel.call(prompt);
+ // 打印结果
+ String b64Json = response.getResult().getOutput().getB64Json();
+ System.out.println(response);
+ viewImage(b64Json);
}
- @Test
- public void call2Test() {
- // 官方测试 test https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java
- var options = ImageOptionsBuilder.builder().withHeight(1024).withWidth(1024).build();
- var instructions = "薄涂炫酷少女头像,田野花朵盛开";
-
- ImagePrompt imagePrompt = new ImagePrompt(instructions, options);
-
- QianFanImageApi qianFanImageApi = new QianFanImageApi(
- "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
- QianFanImageModel imageModel = new QianFanImageModel(qianFanImageApi);
- ImageResponse imageResponse = imageModel.call(imagePrompt);
- }
-
-
-
}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java
index 0ed736cde..41d7859c4 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java
@@ -3,6 +3,7 @@ package cn.iocoder.yudao.framework.ai.image;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.utils.Constants;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
@@ -24,6 +25,7 @@ public class TongYiImagesModelTest {
}
@Test
+ @Disabled
public void imageCallTest() {
// 准备参数
ImageOptions options = OpenAiImageOptions.builder()
From 68ed8cd6f839be4448b7d3044e9d8f2a1d95f9b3 Mon Sep 17 00:00:00 2001
From: YunaiV
Date: Fri, 12 Jul 2024 09:26:32 +0800
Subject: [PATCH 5/5] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98?=
=?UTF-8?q?=E5=8C=96=E3=80=91AI=EF=BC=9A=E6=80=9D=E7=BB=B4=E5=AF=BC?=
=?UTF-8?q?=E5=85=A5=E3=80=81=E5=86=99=E4=BD=9C=E7=9A=84=E7=94=9F=E6=88=90?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../chat/AiChatMessageServiceImpl.java | 2 +-
.../service/mindmap/AiMindMapServiceImpl.java | 49 ++++++++++---------
.../ai/service/write/AiWriteServiceImpl.java | 29 ++++++-----
3 files changed, 45 insertions(+), 35 deletions(-)
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
index 6c8cdeaca..72fa06a79 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
@@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
- // 3.2 创建 chat 需要的 Prompt
+ // 3.2 构建 Prompt,并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
Flux streamResponse = chatModel.stream(prompt);
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
index 7b49ee807..72be20c54 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
@@ -32,13 +32,12 @@ import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
-import java.util.Objects;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
/**
- * AI 写作 Service 实现类
+ * AI 思维导图 Service 实现类
*
* @author xiaoxin
*/
@@ -58,30 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
@Override
public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
- // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
- AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
+ // 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型
+ AiChatRoleDO role = CollUtil.getFirst(
+ chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
// 1.1 获取脑图执行模型
- AiChatModelDO model = getModel(mindMapRole);
+ AiChatModelDO model = getModel(role);
// 1.2 获取角色设定消息
- String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage())
- ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
+ String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
+ ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
// 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
- // 2 插入思维导图信息
+ // 2. 插入思维导图信息
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
mindMapMapper.insert(mindMapDO);
- ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
- // 3.1 角色设定
- List chatMessages = buildMessages(generateReqVO, systemMessage);
- // 3.3 构建提示词
- Prompt prompt = new Prompt(chatMessages, chatOptions);
-
+ // 3.1 构建 Prompt,并进行调用
+ Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
Flux streamResponse = chatModel.stream(prompt);
- // 3.4 流式返回
+
+ // 3.2 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -102,24 +99,32 @@ public class AiMindMapServiceImpl implements AiMindMapService {
}
+ private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+ // 1. 构建 message 列表
+ List chatMessages = buildMessages(generateReqVO, systemMessage);
+ // 2. 构建 options 对象
+ AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
+ ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
+ return new Prompt(chatMessages, options);
+ }
+
private static List buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
List chatMessages = new ArrayList<>();
+ // 1. 角色设定
if (StrUtil.isNotBlank(systemMessage)) {
- // 1.1 角色设定
chatMessages.add(new SystemMessage(systemMessage));
}
- // 1.2 用户输入
+ // 2. 用户输入
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
return chatMessages;
}
- // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去
- private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) {
+ private AiChatModelDO getModel(AiChatRoleDO role) {
AiChatModelDO model = null;
- if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) {
- model = chatModalService.getChatModel(chatRoleDO.getModelId());
+ if (role != null && role.getModelId() != null) {
+ model = chatModalService.getChatModel(role.getModelId());
}
- if (Objects.isNull(model)) {
+ if (model != null) {
model = chatModalService.getRequiredDefaultChatModel();
}
Assert.notNull(model, "[AI] 获取不到模型");
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
index 4b583e3c1..2fae31d59 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
@@ -68,8 +68,9 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override
public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
- // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型
- AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
+ // 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型
+ AiChatRoleDO writeRole = CollUtil.getFirst(
+ chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
// 1.1 获取写作执行模型
AiChatModelDO model = getModel(writeRole);
// 1.2 获取角色设定消息
@@ -84,16 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService {
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
writeMapper.insert(writeDO);
- // 3. 调用大模型,写作生成
- ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
- // 3.1 构建消息列表
- List chatMessages = buildMessages(generateReqVO, systemMessage);
- // 3.2 构建提示词
- Prompt prompt = new Prompt(chatMessages, chatOptions);
- // 3.3 流式调用
+ // 3.1 构建 Prompt,并进行调用
+ Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
Flux streamResponse = chatModel.stream(prompt);
- // 4. 流式返回
+ // 3.2 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -125,6 +121,15 @@ public class AiWriteServiceImpl implements AiWriteService {
return model;
}
+ private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+ // 1. 构建 message 列表
+ List chatMessages = buildMessages(generateReqVO, systemMessage);
+ // 2. 构建 options 对象
+ AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
+ ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
+ return new Prompt(chatMessages, options);
+ }
+
private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
List chatMessages = new ArrayList<>();
if (StrUtil.isNotBlank(systemMessage)) {
@@ -132,11 +137,11 @@ public class AiWriteServiceImpl implements AiWriteService {
chatMessages.add(new SystemMessage(systemMessage));
}
// 1.2 用户输入
- chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
+ chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
return chatMessages;
}
- private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
+ private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());