Merge branch 'refs/heads/up-master-jdk21-ai' into master-jdk21-ai

# Conflicts:
#	yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java
#	yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocServiceImpl.java
This commit is contained in:
xiaoxin 2024-08-15 16:00:56 +08:00
commit d3a4c3c718
18 changed files with 273 additions and 17 deletions

View File

@ -18,7 +18,7 @@
<name>${project.artifactId}</name> <name>${project.artifactId}</name>
<description> <description>
ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维图等功能。 ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维图等功能。
目前已接入各种模型,不限于: 目前已接入各种模型,不限于:
国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek 国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek
国外OpenAI、Ollama、Midjourney、StableDiffusion、Suno 国外OpenAI、Ollama、Midjourney、StableDiffusion、Suno

View File

@ -22,7 +22,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
除此之外不需要除了正文内容外的其他回复如标题开头任何解释性语句或道歉 除此之外不需要除了正文内容外的其他回复如标题开头任何解释性语句或道歉
"""), """),
AI_MIND_MAP_ROLE(2, "图助手", """ AI_MIND_MAP_ROLE(2, "图助手", """
你是一位非常优秀的思维导图助手你会把用户的所有提问都总结成思维导图然后以 Markdown 格式输出markdown 只需要输出一级标题二级标题三级标题四级标题最多输出四级除此之外不要输出任何其他 markdown 标记下面是一个合格的例子 你是一位非常优秀的思维导图助手你会把用户的所有提问都总结成思维导图然后以 Markdown 格式输出markdown 只需要输出一级标题二级标题三级标题四级标题最多输出四级除此之外不要输出任何其他 markdown 标记下面是一个合格的例子
# Geek-AI 助手 # Geek-AI 助手
## 完整的开源系统 ## 完整的开源系统

View File

@ -45,11 +45,13 @@ public interface ErrorCodeConstants {
// ========== API 音乐 1-040-006-000 ========== // ========== API 音乐 1-040-006-000 ==========
ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!"); ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");
// ========== API 写作 1-022-007-000 ========== // ========== API 写作 1-022-007-000 ==========
ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!"); ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!");
ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "写作生成异常!"); ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "写作生成异常!");
// ========== API 思维导图 1-040-008-000 ==========
ErrorCode MIND_MAP_NOT_EXISTS = new ErrorCode(1_040_008_000, "思维导图不存在!");
// ========== API 知识库 1-022-008-000 ========== // ========== API 知识库 1-022-008-000 ==========
ErrorCode KNOWLEDGE_NOT_EXISTS = new ErrorCode(1_022_008_000, "知识库不存在!"); ErrorCode KNOWLEDGE_NOT_EXISTS = new ErrorCode(1_022_008_000, "知识库不存在!");

View File

@ -12,7 +12,7 @@
<name>${project.artifactId}</name> <name>${project.artifactId}</name>
<description> <description>
ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维图等功能。 ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维图等功能。
目前已接入各种模型,不限于: 目前已接入各种模型,不限于:
国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek 国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek
国外OpenAI、Ollama、Midjourney、StableDiffusion、Suno 国外OpenAI、Ollama、Midjourney、StableDiffusion、Suno

View File

@ -1,20 +1,25 @@
package cn.iocoder.yudao.module.ai.controller.admin.mindmap; package cn.iocoder.yudao.module.ai.controller.admin.mindmap;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import cn.iocoder.yudao.module.ai.service.mindmap.AiMindMapService; import cn.iocoder.yudao.module.ai.service.mindmap.AiMindMapService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import jakarta.annotation.security.PermitAll; import jakarta.annotation.security.PermitAll;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.*;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
@Tag(name = "管理后台 - AI 思维导图") @Tag(name = "管理后台 - AI 思维导图")
@ -26,10 +31,29 @@ public class AiMindMapController {
private AiMindMapService mindMapService; private AiMindMapService mindMapService;
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@Operation(summary = "图生成(流式)", description = "流式返回,响应较快") @Operation(summary = "图生成(流式)", description = "流式返回,响应较快")
@PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题 @PermitAll // 解决 SSE 最终响应的时候会被 Access Denied 拦截的问题
public Flux<CommonResult<String>> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) { public Flux<CommonResult<String>> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) {
return mindMapService.generateMindMap(generateReqVO, getLoginUserId()); return mindMapService.generateMindMap(generateReqVO, getLoginUserId());
} }
// ================ 导图管理 ================
@DeleteMapping("/delete")
@Operation(summary = "删除思维导图")
@Parameter(name = "id", description = "编号", required = true)
@PreAuthorize("@ss.hasPermission('ai:mind-map:delete')")
public CommonResult<Boolean> deleteMindMap(@RequestParam("id") Long id) {
mindMapService.deleteMindMap(id);
return success(true);
}
@GetMapping("/page")
@Operation(summary = "获得思维导图分页")
@PreAuthorize("@ss.hasPermission('ai:mind-map:query')")
public CommonResult<PageResult<AiMindMapRespVO>> getMindMapPage(@Valid AiMindMapPageReqVO pageReqVO) {
PageResult<AiMindMapDO> pageResult = mindMapService.getMindMapPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiMindMapRespVO.class));
}
} }

View File

@ -0,0 +1,30 @@
package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import org.springframework.format.annotation.DateTimeFormat;
import java.time.LocalDateTime;
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND;
@Schema(description = "管理后台 - AI 思维导图分页 Request VO")
@Data
@EqualsAndHashCode(callSuper = true)
@ToString(callSuper = true)
public class AiMindMapPageReqVO extends PageParam {
@Schema(description = "用户编号", example = "4325")
private Long userId;
@Schema(description = "生成内容提示", example = "Java 学习路线")
private String prompt;
@Schema(description = "创建时间")
@DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND)
private LocalDateTime[] createTime;
}

View File

@ -0,0 +1,36 @@
package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
@Schema(description = "管理后台 - AI 思维导图 Response VO")
@Data
public class AiMindMapRespVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "3373")
private Long id;
@Schema(description = "用户编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "4325")
private Long userId;
@Schema(description = "生成内容提示", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 学习路线")
private String prompt;
@Schema(description = "生成的思维导图内容")
private String generatedContent;
@Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
private String platform;
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo-0125")
private String model;
@Schema(description = "错误信息")
private String errorMessage;
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime createTime;
}

View File

@ -1,6 +1,9 @@
package cn.iocoder.yudao.module.ai.dal.mysql.mindmap; package cn.iocoder.yudao.module.ai.dal.mysql.mindmap;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@ -11,4 +14,13 @@ import org.apache.ibatis.annotations.Mapper;
*/ */
@Mapper @Mapper
public interface AiMindMapMapper extends BaseMapperX<AiMindMapDO> { public interface AiMindMapMapper extends BaseMapperX<AiMindMapDO> {
default PageResult<AiMindMapDO> selectPage(AiMindMapPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiMindMapDO>()
.eqIfPresent(AiMindMapDO::getUserId, reqVO.getUserId())
.eqIfPresent(AiMindMapDO::getPrompt, reqVO.getPrompt())
.betweenIfPresent(AiMindMapDO::getCreateTime, reqVO.getCreateTime())
.orderByDesc(AiMindMapDO::getId));
}
} }

View File

@ -1,7 +1,10 @@
package cn.iocoder.yudao.module.ai.service.mindmap; package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
/** /**
@ -20,4 +23,19 @@ public interface AiMindMapService {
*/ */
Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId); Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId);
/**
* 删除思维导图
*
* @param id 编号
*/
void deleteMindMap(Long id);
/**
* 获得思维导图分页
*
* @param pageReqVO 分页查询
* @return 思维导图分页
*/
PageResult<AiMindMapDO> getMindMapPage(AiMindMapPageReqVO pageReqVO);
} }

View File

@ -6,9 +6,11 @@ import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils; import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
@ -33,8 +35,10 @@ import reactor.core.publisher.Flux;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_EXISTS;
/** /**
* AI 思维导图 Service 实现类 * AI 思维导图 Service 实现类
@ -57,10 +61,10 @@ public class AiMindMapServiceImpl implements AiMindMapService {
@Override @Override
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
// 1. 获取图模型尝试获取思维导图助手角色如果没有则使用默认模型 // 1. 获取图模型尝试获取思维导图助手角色如果没有则使用默认模型
AiChatRoleDO role = CollUtil.getFirst( AiChatRoleDO role = CollUtil.getFirst(
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
// 1.1 获取图执行模型 // 1.1 获取图执行模型
AiChatModelDO model = getModel(role); AiChatModelDO model = getModel(role);
// 1.2 获取角色设定消息 // 1.2 获取角色设定消息
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage()) String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
@ -131,4 +135,23 @@ public class AiMindMapServiceImpl implements AiMindMapService {
return model; return model;
} }
@Override
public void deleteMindMap(Long id) {
// 校验存在
validateMindMapExists(id);
// 删除
mindMapMapper.deleteById(id);
}
private void validateMindMapExists(Long id) {
if (mindMapMapper.selectById(id) == null) {
throw exception(MIND_MAP_NOT_EXISTS);
}
}
@Override
public PageResult<AiMindMapDO> getMindMapPage(AiMindMapPageReqVO pageReqVO) {
return mindMapMapper.selectPage(pageReqVO);
}
} }

View File

@ -23,12 +23,16 @@
<artifactId>spring-ai-zhipuai-spring-boot-starter</artifactId> <artifactId>spring-ai-zhipuai-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version> <version>${spring-ai.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId> <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version> <version>${spring-ai.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-azure-openai-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId> <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>

View File

@ -12,6 +12,7 @@ import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties; import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
import org.springframework.ai.document.MetadataMode; import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.RedisVectorStore;
@ -21,6 +22,7 @@ import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Lazy;
import redis.clients.jedis.JedisPooled; import redis.clients.jedis.JedisPooled;
/** /**
@ -82,7 +84,8 @@ public class YudaoAiAutoConfiguration {
// ========== rag 相关 ========== // ========== rag 相关 ==========
@Bean @Bean
public TransformersEmbeddingModel transformersEmbeddingClient() { @Lazy // TODO 芋艿临时注释避免无法启动
public EmbeddingModel transformersEmbeddingClient() {
return new TransformersEmbeddingModel(MetadataMode.EMBED); return new TransformersEmbeddingModel(MetadataMode.EMBED);
} }
@ -90,6 +93,7 @@ public class YudaoAiAutoConfiguration {
* 我们启动有加载很多 Embedding 模型不晓得取哪个好 new TransformersEmbeddingModel * 我们启动有加载很多 Embedding 模型不晓得取哪个好 new TransformersEmbeddingModel
*/ */
@Bean @Bean
@Lazy // TODO 芋艿临时注释避免无法启动
public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties, public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties,
RedisProperties redisProperties) { RedisProperties redisProperties) {
var config = RedisVectorStore.RedisVectorStoreConfig.builder() var config = RedisVectorStore.RedisVectorStoreConfig.builder()
@ -105,6 +109,7 @@ public class YudaoAiAutoConfiguration {
} }
@Bean @Bean
@Lazy // TODO 芋艿临时注释避免无法启动
public TokenTextSplitter tokenTextSplitter() { public TokenTextSplitter tokenTextSplitter() {
return new TokenTextSplitter(500, 100, 5, 10000, true); return new TokenTextSplitter(500, 100, 5, 10000, true);
} }

View File

@ -22,7 +22,8 @@ public enum AiPlatformEnum {
// ========== 国外平台 ========== // ========== 国外平台 ==========
OPENAI("OpenAI", "OpenAI"), OPENAI("OpenAI", "OpenAI"), // OpenAI 官方
AZURE_OPENAI("AzureOpenAI", "AzureOpenAI"), // OpenAI 微软
OLLAMA("Ollama", "Ollama"), OLLAMA("Ollama", "Ollama"),
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI

View File

@ -21,6 +21,10 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation; import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.azure.ai.openai.OpenAIClient;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
@ -31,6 +35,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackContext;
@ -82,6 +87,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildXingHuoChatModel(apiKey); return buildXingHuoChatModel(apiKey);
case OPENAI: case OPENAI:
return buildOpenAiChatModel(apiKey, url); return buildOpenAiChatModel(apiKey, url);
case AZURE_OPENAI:
return buildAzureOpenAiChatModel(apiKey, url);
case OLLAMA: case OLLAMA:
return buildOllamaChatModel(url); return buildOllamaChatModel(url);
default: default:
@ -106,6 +113,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return SpringUtil.getBean(XingHuoChatModel.class); return SpringUtil.getBean(XingHuoChatModel.class);
case OPENAI: case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class); return SpringUtil.getBean(OpenAiChatModel.class);
case AZURE_OPENAI:
return SpringUtil.getBean(AzureOpenAiChatModel.class);
case OLLAMA: case OLLAMA:
return SpringUtil.getBean(OllamaChatModel.class); return SpringUtil.getBean(OllamaChatModel.class);
default: default:
@ -268,6 +277,21 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new OpenAiChatModel(openAiApi); return new OpenAiChatModel(openAiApi);
} }
/**
* 可参考 {@link AzureOpenAiAutoConfiguration}
*/
private static AzureOpenAiChatModel buildAzureOpenAiChatModel(String apiKey, String url) {
AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration();
// 创建 OpenAIClient 对象
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
connectionProperties.setApiKey(apiKey);
connectionProperties.setEndpoint(url);
OpenAIClient openAIClient = azureOpenAiAutoConfiguration.openAIClient(connectionProperties);
// 获取 AzureOpenAiChatProperties 对象
AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class);
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties, null, null);
}
/** /**
* 可参考 {@link OpenAiAutoConfiguration} * 可参考 {@link OpenAiAutoConfiguration}
*/ */

View File

@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.OllamaOptions;
@ -35,6 +36,9 @@ public class AiUtils {
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
case OPENAI: case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case AZURE_OPENAI:
// TODO 芋艿貌似没 model 字段
return AzureOpenAiChatOptions.builder().withDeploymentName(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA: case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
default: default:

View File

@ -0,0 +1,70 @@
package cn.iocoder.yudao.framework.ai.chat;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.ClientOptions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import static org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties.DEFAULT_DEPLOYMENT_NAME;
/**
* {@link AzureOpenAiChatModel} 集成测试
*
* @author 芋道源码
*/
public class AzureOpenAIChatModelTests {
private final OpenAIClient openAiApi = (new OpenAIClientBuilder())
.endpoint("https://eastusprejade.openai.azure.com")
.credential(new AzureKeyCredential("xxx"))
.clientOptions((new ClientOptions()).setApplicationId("spring-ai"))
.buildClient();
private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi,
AzureOpenAiChatOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build());
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}

View File

@ -1,6 +1,5 @@
package cn.iocoder.yudao.framework.ai.chat; package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
@ -17,7 +16,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
/** /**
* {@link XingHuoChatModel} 集成测试 * {@link OpenAiChatModel} 集成测试
* *
* @author 芋道源码 * @author 芋道源码
*/ */

View File

@ -162,9 +162,13 @@ spring:
secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
zhipuai: # 智谱 AI zhipuai: # 智谱 AI
api-key: 32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs api-key: 32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs
openai: openai: # OpenAI 官方
api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z
base-url: https://api.gptsapi.net base-url: https://api.gptsapi.net
azure: # OpenAI 微软
openai:
endpoint: https://eastusprejade.openai.azure.com
api-key: xxx
ollama: ollama:
base-url: http://127.0.0.1:11434 base-url: http://127.0.0.1:11434
chat: chat: