From 29e421432d828bad49a147136bec8a7116dcf04d Mon Sep 17 00:00:00 2001 From: xiaoxin <718949661@qq.com> Date: Thu, 11 Jul 2024 10:14:59 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E3=80=90=E8=A7=A3=E5=86=B3todo=E3=80=91AI?= =?UTF-8?q?=20=E5=86=99=E4=BD=9C=E3=80=81=E8=84=91=E5=9B=BE=EF=BC=9Amodel?= =?UTF-8?q?=E3=80=81systemMessage=E8=8E=B7=E5=8F=96=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dal/dataobject/mindmap/AiMindMapDO.java | 5 +- .../service/mindmap/AiMindMapServiceImpl.java | 53 ++++++++++++------ .../ai/service/write/AiWriteServiceImpl.java | 56 ++++++++++++------- 3 files changed, 72 insertions(+), 42 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java index 92222b590..0442a52d7 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java @@ -12,8 +12,7 @@ import lombok.Data; * * @author xiaoxin */ -// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈 -@TableName(value = "ai_mind_map", autoResultMap = true) +@TableName(value = "ai_mind_map") @Data public class AiMindMapDO extends BaseDO { @@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO { /** * 用户编号 - * + *

* 关联 AdminUserDO 的 userId 字段 */ private Long userId; diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java index 7d96c70d2..7b49ee807 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.ai.service.mindmap; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; @@ -57,33 +58,25 @@ public class AiMindMapServiceImpl implements AiMindMapService { @Override public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { - // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型 + // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型 AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); - AiChatModelDO model; - String systemMessage; - if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) { - model = chatModalService.getChatModel(mindMapRole.getModelId()); - systemMessage = mindMapRole.getSystemMessage(); - } else { - model = chatModalService.getRequiredDefaultChatModel(); - systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); - } - + // 1.1 获取脑图执行模型 + AiChatModelDO model = getModel(mindMapRole); + // 1.2 获取角色设定消息 + String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage()) + ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); + // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); // 2 插入思维导图信息 - AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); + AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, + mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); mindMapMapper.insert(mindMapDO); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); // 3.1 角色设定 - List chatMessages = new ArrayList<>(); - if (StrUtil.isNotBlank(systemMessage)) { - chatMessages.add(new SystemMessage(systemMessage)); - } - // 3.2 用户输入 - chatMessages.add(new UserMessage(generateReqVO.getPrompt())); + List chatMessages = buildMessages(generateReqVO, systemMessage); // 3.3 构建提示词 Prompt prompt = new Prompt(chatMessages, chatOptions); @@ -109,4 +102,28 @@ public class AiMindMapServiceImpl implements AiMindMapService { } + private static List buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) { + List chatMessages = new ArrayList<>(); + if (StrUtil.isNotBlank(systemMessage)) { + // 1.1 角色设定 + chatMessages.add(new SystemMessage(systemMessage)); + } + // 1.2 用户输入 + chatMessages.add(new UserMessage(generateReqVO.getPrompt())); + return chatMessages; + } + + // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去 + private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) { + AiChatModelDO model = null; + if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) { + model = chatModalService.getChatModel(chatRoleDO.getModelId()); + } + if (Objects.isNull(model)) { + model = chatModalService.getRequiredDefaultChatModel(); + } + Assert.notNull(model, "[AI] 获取不到模型"); + return model; + } + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index b03a90ab7..4b583e3c1 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.ai.service.write; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.util.AiUtils; @@ -67,19 +68,14 @@ public class AiWriteServiceImpl implements AiWriteService { @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { - // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型 + // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型 AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); - // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。 - AiChatModelDO model; - String systemMessage; - if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { - model = chatModalService.getChatModel(writeRole.getModelId()); - systemMessage = writeRole.getSystemMessage(); - } else { - model = chatModalService.getRequiredDefaultChatModel(); - systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); - } - // 1.2 校验平台 + // 1.1 获取写作执行模型 + AiChatModelDO model = getModel(writeRole); + // 1.2 获取角色设定消息 + String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage()) + ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); + // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); @@ -90,16 +86,11 @@ public class AiWriteServiceImpl implements AiWriteService { // 3. 调用大模型,写作生成 ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - // 3.1 角色设定 - // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确 - List chatMessages = new ArrayList<>(); - if (StrUtil.isNotBlank(systemMessage)) { - chatMessages.add(new SystemMessage(systemMessage)); - } - // 3.2 用户输入 - chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO))); - // 3.3 构建提示词 + // 3.1 构建消息列表 + List chatMessages = buildMessages(generateReqVO, systemMessage); + // 3.2 构建提示词 Prompt prompt = new Prompt(chatMessages, chatOptions); + // 3.3 流式调用 Flux streamResponse = chatModel.stream(prompt); // 4. 流式返回 @@ -122,6 +113,29 @@ public class AiWriteServiceImpl implements AiWriteService { }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); } + private AiChatModelDO getModel(AiChatRoleDO writeRole) { + AiChatModelDO model = null; + if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { + model = chatModalService.getChatModel(writeRole.getModelId()); + } + if (Objects.isNull(model)) { + model = chatModalService.getRequiredDefaultChatModel(); + } + Assert.notNull(model, "[AI] 获取不到模型"); + return model; + } + + private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) { + List chatMessages = new ArrayList<>(); + if (StrUtil.isNotBlank(systemMessage)) { + // 1.1 角色设定 + chatMessages.add(new SystemMessage(systemMessage)); + } + // 1.2 用户输入 + chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO))); + return chatMessages; + } + private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); From c6c003707eec3fc8c7793515e3c14c46383c81ce Mon Sep 17 00:00:00 2001 From: YunaiV Date: Thu, 11 Jul 2024 21:37:45 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E3=80=91AI=EF=BC=9A=E9=80=9A=E4=B9=89=E5=8D=83?= =?UTF-8?q?=E9=97=AE=E7=9A=84=20tests=20=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- script/idea/http-client.env.json | 2 +- .../yudao/module/ai/enums/AiChatRoleEnum.java | 1 + yudao-module-ai/yudao-module-ai-biz/pom.xml | 4 -- .../ai/dal/mysql/mindmap/AiMindMapMapper.java | 2 +- .../ai/core/factory/AiModelFactoryImpl.java | 39 ++++++++++++++---- .../ai/image/OpenAiImageModelTests.java | 4 +- .../framework/ai/image/QianFanImageTests.java | 5 ++- .../ai/image/StabilityAiImageModelTests.java | 4 +- .../ai/image/TongYiImagesModelTest.java | 41 +++++++++++++++++++ .../ai/image/TongYiImagesModelTests.java | 39 ------------------ 10 files changed, 82 insertions(+), 59 deletions(-) create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java delete mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java diff --git a/script/idea/http-client.env.json b/script/idea/http-client.env.json index 17dd0d50d..4a4cb5221 100644 --- a/script/idea/http-client.env.json +++ b/script/idea/http-client.env.json @@ -1,7 +1,7 @@ { "local": { "baseUrl": "http://127.0.0.1:48080/admin-api", - "token": "Bearer 1c2ce60de96a4fb0bf5bea9604099a3d", + "token": "test1", "adminTenentId": "1", "appApi": "http://127.0.0.1:48080/app-api", diff --git a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java index ad3641421..19cbc8f8f 100644 --- a/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java +++ b/yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java @@ -39,6 +39,7 @@ public enum AiChatRoleEnum implements IntArrayValuable { 除此之外不要任何解释性语句。 """); + // TODO @xin:这个 role 是不是删除掉好点哈。= = 目前主要是没做角色枚举。这里多了 role 反倒容易误解哈 /** * 角色 */ diff --git a/yudao-module-ai/yudao-module-ai-biz/pom.xml b/yudao-module-ai/yudao-module-ai-biz/pom.xml index a537b3db7..7c529f118 100644 --- a/yudao-module-ai/yudao-module-ai-biz/pom.xml +++ b/yudao-module-ai/yudao-module-ai-biz/pom.xml @@ -60,9 +60,5 @@ cn.iocoder.boot yudao-spring-boot-starter-test - - cn.iocoder.boot - yudao-spring-boot-starter-excel - \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java index 54fa7235a..ff25e89ff 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java @@ -5,7 +5,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import org.apache.ibatis.annotations.Mapper; /** - * AI 音乐 Mapper + * AI 思维导图 Mapper * * @author xiaoxin */ diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index fbf835707..66a32167c 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; +import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; import com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties; import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties; +import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; @@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory { public ImageModel getDefaultImageModel(AiPlatformEnum platform) { //noinspection EnhancedSwitchMigration switch (platform) { + case TONG_YI: + return SpringUtil.getBean(TongYiImagesModel.class); + case YI_YAN: + return SpringUtil.getBean(QianFanImageModel.class); case OPENAI: return SpringUtil.getBean(OpenAiImageModel.class); case STABLE_DIFFUSION: @@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory { public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) { //noinspection EnhancedSwitchMigration switch (platform) { + case TONG_YI: + return buildTongYiImagesModel(apiKey); + case YI_YAN: + return buildQianFanImageModel(apiKey); case OPENAI: return buildOpenAiImageModel(apiKey, url); case STABLE_DIFFUSION: return buildStabilityAiImageModel(apiKey, url); - case TONG_YI: - return SpringUtil.getBean(TongYiImagesModel.class); - case YI_YAN: - return buildQianFanImageModel(apiKey); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } @@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory { return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); } + private static TongYiImagesModel buildTongYiImagesModel(String key) { + ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class); + TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class); + TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties(); + connectionProperties.setApiKey(key); + return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties); + } + /** * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} */ @@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory { return new QianFanChatModel(qianFanApi); } + /** + * 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} + */ + private QianFanImageModel buildQianFanImageModel(String key) { + List keys = StrUtil.split(key, '|'); + Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式"); + String appKey = keys.get(0); + String secretKey = keys.get(1); + QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey); + return new QianFanImageModel(qianFanApi); + } + /** * 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)} */ @@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory { return new StabilityAiImageModel(stabilityAiApi); } - private QianFanImageModel buildQianFanImageModel(String key) { - List keys = StrUtil.split(key, '|'); - return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1))); - } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java index 740978e60..c9b07d9ff 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java @@ -21,7 +21,7 @@ public class OpenAiImageModelTests { "https://api.holdai.top", "sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf", RestClient.builder()); - private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi); + private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi); @Test @Disabled @@ -34,7 +34,7 @@ public class OpenAiImageModelTests { ImagePrompt prompt = new ImagePrompt("中国长城!", options); // 方法调用 - ImageResponse response = imageClient.call(prompt); + ImageResponse response = imageModel.call(prompt); // 打印结果 System.out.println(response); } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java index b8de6f486..04312bcbd 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java @@ -7,7 +7,6 @@ import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.qianfan.QianFanImageModel; import org.springframework.ai.qianfan.QianFanImageOptions; -import org.springframework.ai.qianfan.api.QianFanApi; import org.springframework.ai.qianfan.api.QianFanImageApi; /** @@ -19,7 +18,7 @@ public class QianFanImageTests { public void callTest() { // todo @芋艿 千帆sdk有个错误,暂时没找到问题 QianFanImageApi qianFanImageApi = new QianFanImageApi( - "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb"); + "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e"); QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi); QianFanImageOptions imageOptions = QianFanImageOptions.builder() @@ -45,4 +44,6 @@ public class QianFanImageTests { ImageResponse imageResponse = imageModel.call(imagePrompt); } + + } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java index cb7412821..7ee7e6044 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java @@ -24,7 +24,7 @@ public class StabilityAiImageModelTests { private final StabilityAiApi imageApi = new StabilityAiApi( "sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx"); - private final StabilityAiImageModel imageClient = new StabilityAiImageModel(imageApi); + private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi); @Test @Disabled @@ -37,7 +37,7 @@ public class StabilityAiImageModelTests { ImagePrompt prompt = new ImagePrompt("great wall", options); // 方法调用 - ImageResponse response = imageClient.call(prompt); + ImageResponse response = imageModel.call(prompt); // 打印结果 String b64Json = response.getResult().getOutput().getB64Json(); System.out.println(response); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java new file mode 100644 index 000000000..0ed736cde --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java @@ -0,0 +1,41 @@ +package cn.iocoder.yudao.framework.ai.image; + +import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; +import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; +import com.alibaba.dashscope.utils.Constants; +import org.junit.jupiter.api.Test; +import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.openai.OpenAiImageOptions; + +/** + * {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类 + * + * @author fansili + */ +public class TongYiImagesModelTest { + + private final ImageSynthesis imageApi = new ImageSynthesis(); + private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi); + + static { + Constants.apiKey = "sk-Zsd81gZYg7"; + } + + @Test + public void imageCallTest() { + // 准备参数 + ImageOptions options = OpenAiImageOptions.builder() + .withModel(ImageSynthesis.Models.WANX_V1) + .withHeight(256).withWidth(256) + .build(); + ImagePrompt prompt = new ImagePrompt("中国长城!", options); + + // 方法调用 + ImageResponse response = imageModel.call(prompt); + // 打印结果 + System.out.println(response); + } + +} diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java deleted file mode 100644 index 7f44873b5..000000000 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java +++ /dev/null @@ -1,39 +0,0 @@ -package cn.iocoder.yudao.framework.ai.image; - -import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; -import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam; -import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult; -import com.alibaba.dashscope.exception.NoApiKeyException; -import com.alibaba.dashscope.utils.Constants; -import com.alibaba.fastjson.JSON; -import org.junit.jupiter.api.Test; - -import java.util.Map; - -// TODO @fan:改成 TongYiImagesModel 哈 -/** - * 通义万象 - */ -public class TongYiImagesModelTests { - - @Test - public void imageCallTest() throws NoApiKeyException { - // 设置 api key - Constants.apiKey = "sk-Zsd81gZYg7"; - ImageSynthesisParam param = - ImageSynthesisParam.builder() - .model(ImageSynthesis.Models.WANX_V1) - .n(4) - .size("1024*1024") - .prompt("雄鹰自由自在的在蓝天白云下飞翔") - .build(); - // 创建 ImageSynthesis - ImageSynthesis is = new ImageSynthesis(); - // 调用 call 生成 image - ImageSynthesisResult call = is.call(param); - System.err.println(JSON.toJSON(call)); - for (Map result : call.getOutput().getResults()) { - System.err.println("地址: " + result.get("url")); - } - } -} From 18aeb072a6187b09b9451f6a93d0bb342f38ebd3 Mon Sep 17 00:00:00 2001 From: cherishsince Date: Thu, 11 Jul 2024 21:46:09 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91buildIm?= =?UTF-8?q?ageOptions=20=E6=94=AF=E6=8C=81=E5=8D=83=E5=B8=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../yudao/module/ai/service/image/AiImageServiceImpl.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 7ea629e11..02c1ab334 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -31,6 +31,7 @@ import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.openai.OpenAiImageOptions; +import org.springframework.ai.qianfan.QianFanImageOptions; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; @@ -142,6 +143,11 @@ public class AiImageServiceImpl implements AiImageService { .withModel(draw.getModel()).withN(1) .withHeight(draw.getHeight()).withWidth(draw.getWidth()) .build(); + } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) { + return QianFanImageOptions.builder() + .withModel(draw.getModel()).withN(1) + .withHeight(draw.getHeight()).withWidth(draw.getWidth()) + .build(); } throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); } From 698b2b24aeee15dc27d7eea058cce7e968471608 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Thu, 11 Jul 2024 22:15:44 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E3=80=91AI=EF=BC=9A=E6=96=87=E5=BF=83=E4=B8=80?= =?UTF-8?q?=E8=A8=80=E7=9A=84=20tests=20=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../framework/ai/image/QianFanImageTests.java | 53 ++++++++----------- .../ai/image/TongYiImagesModelTest.java | 2 + 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java index 04312bcbd..22bf6614e 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java @@ -1,49 +1,42 @@ package cn.iocoder.yudao.framework.ai.image; -import cn.iocoder.yudao.framework.common.util.json.JsonUtils; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.qianfan.QianFanImageModel; import org.springframework.ai.qianfan.QianFanImageOptions; import org.springframework.ai.qianfan.api.QianFanImageApi; +import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage; + /** - * 百度千帆 image + * {@link QianFanImageModel} 集成测试类 */ public class QianFanImageTests { - @Test - public void callTest() { - // todo @芋艿 千帆sdk有个错误,暂时没找到问题 - QianFanImageApi qianFanImageApi = new QianFanImageApi( - "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e"); - QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi); + private final QianFanImageApi imageApi = new QianFanImageApi( + "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e"); + private final QianFanImageModel imageModel = new QianFanImageModel(imageApi); + @Test + @Disabled + public void testCall() { + // 准备参数 + // 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576 QianFanImageOptions imageOptions = QianFanImageOptions.builder() - .withWidth(512) - .withHeight(512) + .withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue()) + .withWidth(1024).withHeight(1024) + .withN(1) .build(); - ImagePrompt imagePrompt = new ImagePrompt("薄涂炫酷少女头像,田野花朵盛开", imageOptions); - ImageResponse call = qianFanImageModel.call(imagePrompt); - System.err.println(JsonUtils.toJsonString(call)); + ImagePrompt prompt = new ImagePrompt("good", imageOptions); + + // 方法调用 + ImageResponse response = imageModel.call(prompt); + // 打印结果 + String b64Json = response.getResult().getOutput().getB64Json(); + System.out.println(response); + viewImage(b64Json); } - @Test - public void call2Test() { - // 官方测试 test https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java - var options = ImageOptionsBuilder.builder().withHeight(1024).withWidth(1024).build(); - var instructions = "薄涂炫酷少女头像,田野花朵盛开"; - - ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - - QianFanImageApi qianFanImageApi = new QianFanImageApi( - "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb"); - QianFanImageModel imageModel = new QianFanImageModel(qianFanImageApi); - ImageResponse imageResponse = imageModel.call(imagePrompt); - } - - - } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java index 0ed736cde..41d7859c4 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java @@ -3,6 +3,7 @@ package cn.iocoder.yudao.framework.ai.image; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; import com.alibaba.dashscope.utils.Constants; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImagePrompt; @@ -24,6 +25,7 @@ public class TongYiImagesModelTest { } @Test + @Disabled public void imageCallTest() { // 准备参数 ImageOptions options = OpenAiImageOptions.builder() From 68ed8cd6f839be4448b7d3044e9d8f2a1d95f9b3 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Fri, 12 Jul 2024 09:26:32 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E3=80=90=E4=BB=A3=E7=A0=81=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E3=80=91AI=EF=BC=9A=E6=80=9D=E7=BB=B4=E5=AF=BC?= =?UTF-8?q?=E5=85=A5=E3=80=81=E5=86=99=E4=BD=9C=E7=9A=84=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat/AiChatMessageServiceImpl.java | 2 +- .../service/mindmap/AiMindMapServiceImpl.java | 49 ++++++++++--------- .../ai/service/write/AiWriteServiceImpl.java | 29 ++++++----- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 6c8cdeaca..72fa06a79 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); - // 3.2 创建 chat 需要的 Prompt + // 3.2 构建 Prompt,并进行调用 Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); Flux streamResponse = chatModel.stream(prompt); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java index 7b49ee807..72be20c54 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java @@ -32,13 +32,12 @@ import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; -import java.util.Objects; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; /** - * AI 写作 Service 实现类 + * AI 思维导图 Service 实现类 * * @author xiaoxin */ @@ -58,30 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService { @Override public Flux> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { - // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型 - AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); + // 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型 + AiChatRoleDO role = CollUtil.getFirst( + chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); // 1.1 获取脑图执行模型 - AiChatModelDO model = getModel(mindMapRole); + AiChatModelDO model = getModel(role); // 1.2 获取角色设定消息 - String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage()) - ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); + String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage()) + ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage(); // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); - // 2 插入思维导图信息 + // 2. 插入思维导图信息 AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); mindMapMapper.insert(mindMapDO); - ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - // 3.1 角色设定 - List chatMessages = buildMessages(generateReqVO, systemMessage); - // 3.3 构建提示词 - Prompt prompt = new Prompt(chatMessages, chatOptions); - + // 3.1 构建 Prompt,并进行调用 + Prompt prompt = buildPrompt(generateReqVO, model, systemMessage); Flux streamResponse = chatModel.stream(prompt); - // 3.4 流式返回 + + // 3.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; @@ -102,24 +99,32 @@ public class AiMindMapServiceImpl implements AiMindMapService { } + private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { + // 1. 构建 message 列表 + List chatMessages = buildMessages(generateReqVO, systemMessage); + // 2. 构建 options 对象 + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); + return new Prompt(chatMessages, options); + } + private static List buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) { List chatMessages = new ArrayList<>(); + // 1. 角色设定 if (StrUtil.isNotBlank(systemMessage)) { - // 1.1 角色设定 chatMessages.add(new SystemMessage(systemMessage)); } - // 1.2 用户输入 + // 2. 用户输入 chatMessages.add(new UserMessage(generateReqVO.getPrompt())); return chatMessages; } - // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去 - private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) { + private AiChatModelDO getModel(AiChatRoleDO role) { AiChatModelDO model = null; - if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) { - model = chatModalService.getChatModel(chatRoleDO.getModelId()); + if (role != null && role.getModelId() != null) { + model = chatModalService.getChatModel(role.getModelId()); } - if (Objects.isNull(model)) { + if (model != null) { model = chatModalService.getRequiredDefaultChatModel(); } Assert.notNull(model, "[AI] 获取不到模型"); diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java index 4b583e3c1..2fae31d59 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java @@ -68,8 +68,9 @@ public class AiWriteServiceImpl implements AiWriteService { @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { - // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型 - AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); + // 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型 + AiChatRoleDO writeRole = CollUtil.getFirst( + chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); // 1.1 获取写作执行模型 AiChatModelDO model = getModel(writeRole); // 1.2 获取角色设定消息 @@ -84,16 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService { write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel())); writeMapper.insert(writeDO); - // 3. 调用大模型,写作生成 - ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); - // 3.1 构建消息列表 - List chatMessages = buildMessages(generateReqVO, systemMessage); - // 3.2 构建提示词 - Prompt prompt = new Prompt(chatMessages, chatOptions); - // 3.3 流式调用 + // 3.1 构建 Prompt,并进行调用 + Prompt prompt = buildPrompt(generateReqVO, model, systemMessage); Flux streamResponse = chatModel.stream(prompt); - // 4. 流式返回 + // 3.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; @@ -125,6 +121,15 @@ public class AiWriteServiceImpl implements AiWriteService { return model; } + private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) { + // 1. 构建 message 列表 + List chatMessages = buildMessages(generateReqVO, systemMessage); + // 2. 构建 options 对象 + AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); + ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); + return new Prompt(chatMessages, options); + } + private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) { List chatMessages = new ArrayList<>(); if (StrUtil.isNotBlank(systemMessage)) { @@ -132,11 +137,11 @@ public class AiWriteServiceImpl implements AiWriteService { chatMessages.add(new SystemMessage(systemMessage)); } // 1.2 用户输入 - chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO))); + chatMessages.add(new UserMessage(buildUserMessage(generateReqVO))); return chatMessages; } - private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { + private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) { String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());