【功能新增】AI:集成 Azure 的 OpenAI 模型

This commit is contained in:
YunaiV 2024-08-10 14:34:57 +08:00
parent b453856864
commit 83bd96d672
7 changed files with 47 additions and 6 deletions

View File

@ -7,7 +7,6 @@ import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
@ -16,7 +15,7 @@ import java.util.List;
* *
* @author xiaoxin * @author xiaoxin
*/ */
@Service //@Service // TODO 芋艿临时注释避免无法启动
@Slf4j @Slf4j
public class DocServiceImpl implements DocService { public class DocServiceImpl implements DocService {

View File

@ -23,12 +23,16 @@
<artifactId>spring-ai-zhipuai-spring-boot-starter</artifactId> <artifactId>spring-ai-zhipuai-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version> <version>${spring-ai.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId> <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version> <version>${spring-ai.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-azure-openai-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId> <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>

View File

@ -12,6 +12,7 @@ import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties; import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
import org.springframework.ai.document.MetadataMode; import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.RedisVectorStore;
@ -21,6 +22,7 @@ import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Lazy;
import redis.clients.jedis.JedisPooled; import redis.clients.jedis.JedisPooled;
/** /**
@ -82,7 +84,8 @@ public class YudaoAiAutoConfiguration {
// ========== rag 相关 ========== // ========== rag 相关 ==========
@Bean @Bean
public TransformersEmbeddingModel transformersEmbeddingClient() { @Lazy // TODO 芋艿临时注释避免无法启动
public EmbeddingModel transformersEmbeddingClient() {
return new TransformersEmbeddingModel(MetadataMode.EMBED); return new TransformersEmbeddingModel(MetadataMode.EMBED);
} }
@ -90,6 +93,7 @@ public class YudaoAiAutoConfiguration {
* 我们启动有加载很多 Embedding 模型不晓得取哪个好 new TransformersEmbeddingModel * 我们启动有加载很多 Embedding 模型不晓得取哪个好 new TransformersEmbeddingModel
*/ */
@Bean @Bean
@Lazy // TODO 芋艿临时注释避免无法启动
public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties, public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties,
RedisProperties redisProperties) { RedisProperties redisProperties) {
var config = RedisVectorStore.RedisVectorStoreConfig.builder() var config = RedisVectorStore.RedisVectorStoreConfig.builder()
@ -105,6 +109,7 @@ public class YudaoAiAutoConfiguration {
} }
@Bean @Bean
@Lazy // TODO 芋艿临时注释避免无法启动
public TokenTextSplitter tokenTextSplitter() { public TokenTextSplitter tokenTextSplitter() {
return new TokenTextSplitter(500, 100, 5, 10000, true); return new TokenTextSplitter(500, 100, 5, 10000, true);
} }

View File

@ -22,7 +22,8 @@ public enum AiPlatformEnum {
// ========== 国外平台 ========== // ========== 国外平台 ==========
OPENAI("OpenAI", "OpenAI"), OPENAI("OpenAI", "OpenAI"), // OpenAI 官方
AZURE_OPENAI("AzureOpenAI", "AzureOpenAI"), // OpenAI 微软
OLLAMA("Ollama", "Ollama"), OLLAMA("Ollama", "Ollama"),
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI

View File

@ -21,6 +21,10 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation; import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.azure.ai.openai.OpenAIClient;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
@ -31,6 +35,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackContext;
@ -82,6 +87,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildXingHuoChatModel(apiKey); return buildXingHuoChatModel(apiKey);
case OPENAI: case OPENAI:
return buildOpenAiChatModel(apiKey, url); return buildOpenAiChatModel(apiKey, url);
case AZURE_OPENAI:
return buildAzureOpenAiChatModel(apiKey, url);
case OLLAMA: case OLLAMA:
return buildOllamaChatModel(url); return buildOllamaChatModel(url);
default: default:
@ -106,6 +113,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return SpringUtil.getBean(XingHuoChatModel.class); return SpringUtil.getBean(XingHuoChatModel.class);
case OPENAI: case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class); return SpringUtil.getBean(OpenAiChatModel.class);
case AZURE_OPENAI:
return SpringUtil.getBean(AzureOpenAiChatModel.class);
case OLLAMA: case OLLAMA:
return SpringUtil.getBean(OllamaChatModel.class); return SpringUtil.getBean(OllamaChatModel.class);
default: default:
@ -268,6 +277,21 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new OpenAiChatModel(openAiApi); return new OpenAiChatModel(openAiApi);
} }
/**
* 可参考 {@link AzureOpenAiAutoConfiguration}
*/
private static AzureOpenAiChatModel buildAzureOpenAiChatModel(String apiKey, String url) {
AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration();
// 创建 OpenAIClient 对象
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
connectionProperties.setApiKey(apiKey);
connectionProperties.setEndpoint(url);
OpenAIClient openAIClient = azureOpenAiAutoConfiguration.openAIClient(connectionProperties);
// 获取 AzureOpenAiChatProperties 对象
AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class);
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties, null, null);
}
/** /**
* 可参考 {@link OpenAiAutoConfiguration} * 可参考 {@link OpenAiAutoConfiguration}
*/ */

View File

@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.OllamaOptions;
@ -35,6 +36,9 @@ public class AiUtils {
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
case OPENAI: case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case AZURE_OPENAI:
// TODO 芋艿貌似没 model 字段
return AzureOpenAiChatOptions.builder().withDeploymentName(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA: case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
default: default:

View File

@ -162,9 +162,13 @@ spring:
secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
zhipuai: # 智谱 AI zhipuai: # 智谱 AI
api-key: 32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs api-key: 32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs
openai: openai: # OpenAI 官方
api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z
base-url: https://api.gptsapi.net base-url: https://api.gptsapi.net
azure: # OpenAI 微软
openai:
endpoint: https://eastusprejade.openai.azure.com
api-key: xxx
ollama: ollama:
base-url: http://127.0.0.1:11434 base-url: http://127.0.0.1:11434
chat: chat: