mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2024-11-26 01:01:52 +08:00
从spring-ai 迁移 chat 模块和依赖模块过来
This commit is contained in:
parent
37790d5fc1
commit
f6ea1bda76
105
yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml
Normal file
105
yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
<parent>
|
||||||
|
<groupId>cn.iocoder.boot</groupId>
|
||||||
|
<artifactId>yudao-module-ai</artifactId>
|
||||||
|
<version>${revision}</version>
|
||||||
|
</parent>
|
||||||
|
|
||||||
|
<artifactId>yudao-spring-boot-starter-ai</artifactId>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<maven.compiler.source>21</maven.compiler.source>
|
||||||
|
<maven.compiler.target>21</maven.compiler.target>
|
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
|
</properties>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework</groupId>
|
||||||
|
<artifactId>spring-core</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.core</groupId>
|
||||||
|
<artifactId>jackson-databind</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework</groupId>
|
||||||
|
<artifactId>spring-context</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>net.jodah</groupId>
|
||||||
|
<artifactId>typetools</artifactId>
|
||||||
|
<version>0.6.3</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.github.victools</groupId>
|
||||||
|
<artifactId>jsonschema-module-jackson</artifactId>
|
||||||
|
<version>4.31.1</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.github.victools</groupId>
|
||||||
|
<artifactId>jsonschema-module-swagger-2</artifactId>
|
||||||
|
<version>4.33.1</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.github.victools</groupId>
|
||||||
|
<artifactId>jsonschema-generator</artifactId>
|
||||||
|
<version>4.31.1</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>io.projectreactor</groupId>
|
||||||
|
<artifactId>reactor-core</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.cloud</groupId>
|
||||||
|
<artifactId>spring-cloud-function-context</artifactId>
|
||||||
|
<version>4.1.0</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.antlr</groupId>
|
||||||
|
<artifactId>stringtemplate</artifactId>
|
||||||
|
<version>4.0.2</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.projectlombok</groupId>
|
||||||
|
<artifactId>lombok</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework</groupId>
|
||||||
|
<artifactId>spring-web</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework</groupId>
|
||||||
|
<artifactId>spring-webflux</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.retry</groupId>
|
||||||
|
<artifactId>spring-retry</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>cn.hutool</groupId>
|
||||||
|
<artifactId>hutool-all</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.squareup.okhttp3</groupId>
|
||||||
|
<artifactId>okhttp</artifactId>
|
||||||
|
<version>4.12.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
</project>
|
@ -0,0 +1,36 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat;
|
||||||
|
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage;
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ModelClient;
|
||||||
|
|
||||||
|
@FunctionalInterface
|
||||||
|
public interface ChatClient extends ModelClient<Prompt, ChatResponse> {
|
||||||
|
|
||||||
|
default String call(String message) {
|
||||||
|
Prompt prompt = new Prompt(new UserMessage(message));
|
||||||
|
Generation generation = call(prompt).getResult();
|
||||||
|
return (generation != null) ? generation.getOutput().getContent() : "";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
ChatResponse call(Prompt prompt);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,112 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.metadata.ChatResponseMetadata;
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ModelResponse;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 人工智能提供商返回的聊天完成(例如生成)响应。
|
||||||
|
*
|
||||||
|
* The chat completion (e.g. generation) response returned by an AI provider.
|
||||||
|
*/
|
||||||
|
public class ChatResponse implements ModelResponse<Generation> {
|
||||||
|
|
||||||
|
private final ChatResponseMetadata chatResponseMetadata;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List of generated messages returned by the AI provider.
|
||||||
|
*/
|
||||||
|
private final List<Generation> generations;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct a new {@link ChatResponse} instance without metadata.
|
||||||
|
* @param generations the {@link List} of {@link Generation} returned by the AI
|
||||||
|
* provider.
|
||||||
|
*/
|
||||||
|
public ChatResponse(List<Generation> generations) {
|
||||||
|
this(generations, ChatResponseMetadata.NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct a new {@link ChatResponse} instance.
|
||||||
|
* @param generations the {@link List} of {@link Generation} returned by the AI
|
||||||
|
* provider.
|
||||||
|
* @param chatResponseMetadata {@link ChatResponseMetadata} containing information
|
||||||
|
* about the use of the AI provider's API.
|
||||||
|
*/
|
||||||
|
public ChatResponse(List<Generation> generations, ChatResponseMetadata chatResponseMetadata) {
|
||||||
|
this.chatResponseMetadata = chatResponseMetadata;
|
||||||
|
this.generations = List.copyOf(generations);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The {@link List} of {@link Generation generated outputs}.
|
||||||
|
* <p>
|
||||||
|
* It is a {@link List} of {@link List lists} because the Prompt could request
|
||||||
|
* multiple output {@link Generation generations}.
|
||||||
|
* @return the {@link List} of {@link Generation generated outputs}.
|
||||||
|
*/
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Generation> getResults() {
|
||||||
|
return this.generations;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Returns the first {@link Generation} in the generations list.
|
||||||
|
*/
|
||||||
|
public Generation getResult() {
|
||||||
|
if (CollectionUtils.isEmpty(this.generations)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return this.generations.get(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Returns {@link ChatResponseMetadata} containing information about the use
|
||||||
|
* of the AI provider's API.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public ChatResponseMetadata getMetadata() {
|
||||||
|
return this.chatResponseMetadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "ChatResponse [metadata=" + chatResponseMetadata + ", generations=" + generations + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o)
|
||||||
|
return true;
|
||||||
|
if (!(o instanceof ChatResponse that))
|
||||||
|
return false;
|
||||||
|
return Objects.equals(chatResponseMetadata, that.chatResponseMetadata)
|
||||||
|
&& Objects.equals(generations, that.generations);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(chatResponseMetadata, generations);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,83 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.AssistantMessage;
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.metadata.ChatGenerationMetadata;
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ModelResult;
|
||||||
|
import org.springframework.lang.Nullable;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 表示AI返回的响应。
|
||||||
|
*
|
||||||
|
* Represents a response returned by the AI.
|
||||||
|
*/
|
||||||
|
public class Generation implements ModelResult<AssistantMessage> {
|
||||||
|
|
||||||
|
private AssistantMessage assistantMessage;
|
||||||
|
|
||||||
|
private ChatGenerationMetadata chatGenerationMetadata;
|
||||||
|
|
||||||
|
public Generation(String text) {
|
||||||
|
this.assistantMessage = new AssistantMessage(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Generation(String text, Map<String, Object> properties) {
|
||||||
|
this.assistantMessage = new AssistantMessage(text, properties);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AssistantMessage getOutput() {
|
||||||
|
return this.assistantMessage;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatGenerationMetadata getMetadata() {
|
||||||
|
ChatGenerationMetadata chatGenerationMetadata = this.chatGenerationMetadata;
|
||||||
|
return chatGenerationMetadata != null ? chatGenerationMetadata : ChatGenerationMetadata.NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Generation withGenerationMetadata(@Nullable ChatGenerationMetadata chatGenerationMetadata) {
|
||||||
|
this.chatGenerationMetadata = chatGenerationMetadata;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o)
|
||||||
|
return true;
|
||||||
|
if (!(o instanceof Generation that))
|
||||||
|
return false;
|
||||||
|
return Objects.equals(assistantMessage, that.assistantMessage)
|
||||||
|
&& Objects.equals(chatGenerationMetadata, that.chatGenerationMetadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(assistantMessage, chatGenerationMetadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Generation{" + "assistantMessage=" + assistantMessage + ", chatGenerationMetadata="
|
||||||
|
+ chatGenerationMetadata + '}';
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,29 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.StreamingModelClient;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
@FunctionalInterface
|
||||||
|
public interface StreamingChatClient extends StreamingModelClient<Prompt, ChatResponse> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Flux<ChatResponse> stream(Prompt prompt);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,151 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.StreamUtils;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public abstract class AbstractMessage implements Message {
|
||||||
|
|
||||||
|
protected final MessageType messageType;
|
||||||
|
|
||||||
|
protected final String textContent;
|
||||||
|
|
||||||
|
protected final List<MediaData> mediaData;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Additional options for the message to influence the response, not a generative map.
|
||||||
|
*/
|
||||||
|
protected final Map<String, Object> properties;
|
||||||
|
|
||||||
|
protected AbstractMessage(MessageType messageType, String content) {
|
||||||
|
this(messageType, content, Map.of());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected AbstractMessage(MessageType messageType, String content, Map<String, Object> messageProperties) {
|
||||||
|
Assert.notNull(messageType, "Message type must not be null");
|
||||||
|
// Assert.notNull(content, "Content must not be null");
|
||||||
|
this.messageType = messageType;
|
||||||
|
this.textContent = content;
|
||||||
|
this.mediaData = new ArrayList<>();
|
||||||
|
this.properties = messageProperties;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected AbstractMessage(MessageType messageType, String textContent, List<MediaData> mediaData) {
|
||||||
|
this(messageType, textContent, mediaData, Map.of());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected AbstractMessage(MessageType messageType, String textContent, List<MediaData> mediaData,
|
||||||
|
Map<String, Object> messageProperties) {
|
||||||
|
|
||||||
|
Assert.notNull(messageType, "Message type must not be null");
|
||||||
|
Assert.notNull(textContent, "Content must not be null");
|
||||||
|
Assert.notNull(mediaData, "media data must not be null");
|
||||||
|
|
||||||
|
this.messageType = messageType;
|
||||||
|
this.textContent = textContent;
|
||||||
|
this.mediaData = new ArrayList<>(mediaData);
|
||||||
|
this.properties = messageProperties;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected AbstractMessage(MessageType messageType, Resource resource) {
|
||||||
|
this(messageType, resource, Collections.emptyMap());
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("null")
|
||||||
|
protected AbstractMessage(MessageType messageType, Resource resource, Map<String, Object> messageProperties) {
|
||||||
|
Assert.notNull(messageType, "Message type must not be null");
|
||||||
|
Assert.notNull(resource, "Resource must not be null");
|
||||||
|
|
||||||
|
this.messageType = messageType;
|
||||||
|
this.properties = messageProperties;
|
||||||
|
this.mediaData = new ArrayList<>();
|
||||||
|
|
||||||
|
try (InputStream inputStream = resource.getInputStream()) {
|
||||||
|
this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
|
||||||
|
}
|
||||||
|
catch (IOException ex) {
|
||||||
|
throw new RuntimeException("Failed to read resource", ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getContent() {
|
||||||
|
return this.textContent;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<MediaData> getMediaData() {
|
||||||
|
return this.mediaData;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, Object> getProperties() {
|
||||||
|
return this.properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MessageType getMessageType() {
|
||||||
|
return this.messageType;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime * result + ((mediaData == null) ? 0 : mediaData.hashCode());
|
||||||
|
result = prime * result + ((properties == null) ? 0 : properties.hashCode());
|
||||||
|
result = prime * result + ((messageType == null) ? 0 : messageType.hashCode());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
AbstractMessage other = (AbstractMessage) obj;
|
||||||
|
if (mediaData == null) {
|
||||||
|
if (other.mediaData != null)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (!mediaData.equals(other.mediaData))
|
||||||
|
return false;
|
||||||
|
if (properties == null) {
|
||||||
|
if (other.properties != null)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (!properties.equals(other.properties))
|
||||||
|
return false;
|
||||||
|
if (messageType != other.messageType)
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,47 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 让生成人员知道内容是作为对用户的响应生成的。
|
||||||
|
* 此角色指示生成者先前在会话中生成的消息。
|
||||||
|
* 通过包括该系列中的辅助消息,您可以为生成的关于提供上下文之前在谈话中的交流。
|
||||||
|
*
|
||||||
|
* Lets the generative know the content was generated as a response to the user. This role
|
||||||
|
* indicates messages that the generative has previously generated in the conversation. By
|
||||||
|
* including assistant messages in the series, you provide context to the generative about
|
||||||
|
* prior exchanges in the conversation.
|
||||||
|
*/
|
||||||
|
public class AssistantMessage extends AbstractMessage {
|
||||||
|
|
||||||
|
public AssistantMessage(String content) {
|
||||||
|
super(MessageType.ASSISTANT, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AssistantMessage(String content, Map<String, Object> properties) {
|
||||||
|
super(MessageType.ASSISTANT, content, properties);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "AssistantMessage{" + "content='" + getContent() + '\'' + ", properties=" + properties + ", messageType="
|
||||||
|
+ messageType + '}';
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,39 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class ChatMessage extends AbstractMessage {
|
||||||
|
|
||||||
|
public ChatMessage(String role, String content) {
|
||||||
|
super(MessageType.valueOf(role), content);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatMessage(String role, String content, Map<String, Object> properties) {
|
||||||
|
super(MessageType.valueOf(role), content, properties);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatMessage(MessageType messageType, String content) {
|
||||||
|
super(messageType, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatMessage(MessageType messageType, String content, Map<String, Object> properties) {
|
||||||
|
super(messageType, content, properties);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class FunctionMessage extends AbstractMessage {
|
||||||
|
|
||||||
|
public FunctionMessage(String content) {
|
||||||
|
super(MessageType.SYSTEM, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionMessage(String content, Map<String, Object> properties) {
|
||||||
|
super(MessageType.SYSTEM, content, properties);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "FunctionMessage{" + "content='" + getContent() + '\'' + ", properties=" + properties + ", messageType="
|
||||||
|
+ messageType + '}';
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,46 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.MimeType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public class MediaData {
|
||||||
|
|
||||||
|
private final MimeType mimeType;
|
||||||
|
|
||||||
|
private final Object data;
|
||||||
|
|
||||||
|
public MediaData(MimeType mimeType, Object data) {
|
||||||
|
Assert.notNull(mimeType, "MimeType must not be null");
|
||||||
|
// Assert.notNull(data, "Data must not be null");
|
||||||
|
this.mimeType = mimeType;
|
||||||
|
this.data = data;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MimeType getMimeType() {
|
||||||
|
return this.mimeType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Object getData() {
|
||||||
|
return this.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,32 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public interface Message {
|
||||||
|
|
||||||
|
String getContent();
|
||||||
|
|
||||||
|
List<MediaData> getMediaData();
|
||||||
|
|
||||||
|
Map<String, Object> getProperties();
|
||||||
|
|
||||||
|
MessageType getMessageType();
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
public enum MessageType {
|
||||||
|
|
||||||
|
// 用户消息
|
||||||
|
USER("user"),
|
||||||
|
|
||||||
|
// 之前会话的消息
|
||||||
|
ASSISTANT("assistant"),
|
||||||
|
|
||||||
|
// 根据注释说明:您可以使用系统消息来指示具有生成性,表现得像某个角色或以特定的方式提供答案总体安排
|
||||||
|
// 简单理解:在对话前,发送一条具有角色的信息让模型理解(如:你现在是一个10年拍摄经验的导演,拥有丰富的经验。 这样你就可以去问他,怎么拍一个短视频可以在抖音上火)
|
||||||
|
SYSTEM("system"),
|
||||||
|
|
||||||
|
// 函数?根据引用现在不支持,会抛出一个异常 ---> throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
|
||||||
|
FUNCTION("function");
|
||||||
|
|
||||||
|
private final String value;
|
||||||
|
|
||||||
|
MessageType(String value) {
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MessageType fromValue(String value) {
|
||||||
|
for (MessageType messageType : MessageType.values()) {
|
||||||
|
if (messageType.getValue().equals(value)) {
|
||||||
|
return messageType;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("Invalid MessageType value: " + value);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,48 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 作为输入传递的“system”类型的消息。系统消息给出高级别对话说明。
|
||||||
|
* 此角色通常提供高级说明对话。
|
||||||
|
* 例如,您可以使用系统消息来指示具有生成性,表现得像某个角色或以特定的方式提供答案总体安排
|
||||||
|
*
|
||||||
|
* A message of the type 'system' passed as input. The system message gives high level
|
||||||
|
* instructions for the conversation. This role typically provides high-level instructions
|
||||||
|
* for the conversation. For example, you might use a system message to instruct the
|
||||||
|
* generative to behave like a certain character or to provide answers in a specific
|
||||||
|
* format.
|
||||||
|
*/
|
||||||
|
public class SystemMessage extends AbstractMessage {
|
||||||
|
|
||||||
|
public SystemMessage(String content) {
|
||||||
|
super(MessageType.SYSTEM, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SystemMessage(Resource resource) {
|
||||||
|
super(MessageType.SYSTEM, resource);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "SystemMessage{" + "content='" + getContent() + '\'' + ", properties=" + properties + ", messageType="
|
||||||
|
+ messageType + '}';
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,51 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.messages;
|
||||||
|
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 作为输入传递的“user”类型的消息具有用户角色的消息来自最终用户或开发者。
|
||||||
|
* 它们表示问题、提示或您想要的任何输入产生反应的。
|
||||||
|
*
|
||||||
|
* A message of the type 'user' passed as input Messages with the user role are from the
|
||||||
|
* end-user or developer. They represent questions, prompts, or any input that you want
|
||||||
|
* the generative to respond to.
|
||||||
|
*/
|
||||||
|
public class UserMessage extends AbstractMessage {
|
||||||
|
|
||||||
|
public UserMessage(String message) {
|
||||||
|
super(MessageType.USER, message);
|
||||||
|
}
|
||||||
|
|
||||||
|
public UserMessage(Resource resource) {
|
||||||
|
super(MessageType.USER, resource);
|
||||||
|
}
|
||||||
|
|
||||||
|
public UserMessage(String textContent, List<MediaData> mediaDataList) {
|
||||||
|
super(MessageType.USER, textContent, mediaDataList);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "UserMessage{" + "content='" + getContent() + '\'' + ", properties=" + properties + ", messageType="
|
||||||
|
+ messageType + '}';
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.metadata;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ResultMetadata;
|
||||||
|
import org.springframework.lang.Nullable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract Data Type (ADT) encapsulating information on the completion choices in the AI
|
||||||
|
* response.
|
||||||
|
*
|
||||||
|
* @author John Blum
|
||||||
|
* @since 0.7.0
|
||||||
|
*/
|
||||||
|
public interface ChatGenerationMetadata extends ResultMetadata {
|
||||||
|
|
||||||
|
ChatGenerationMetadata NULL = ChatGenerationMetadata.from(null, null);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factory method used to construct a new {@link ChatGenerationMetadata} from the
|
||||||
|
* given {@link String finish reason} and content filter metadata.
|
||||||
|
* @param finishReason {@link String} contain the reason for the choice completion.
|
||||||
|
* @param contentFilterMetadata underlying AI provider metadata for filtering applied
|
||||||
|
* to generation content.
|
||||||
|
* @return a new {@link ChatGenerationMetadata} from the given {@link String finish
|
||||||
|
* reason} and content filter metadata.
|
||||||
|
*/
|
||||||
|
static ChatGenerationMetadata from(String finishReason, Object contentFilterMetadata) {
|
||||||
|
return new ChatGenerationMetadata() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public <T> T getContentFilterMetadata() {
|
||||||
|
return (T) contentFilterMetadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getFinishReason() {
|
||||||
|
return finishReason;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the underlying AI provider metadata for filtering applied to generation
|
||||||
|
* content.
|
||||||
|
* @param <T> {@link Class Type} used to cast the filtered content metadata into the
|
||||||
|
* AI provider-specific type.
|
||||||
|
* @return the underlying AI provider metadata for filtering applied to generation
|
||||||
|
* content.
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
<T> T getContentFilterMetadata();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the {@link String reason} this choice completed for the generation.
|
||||||
|
* @return the {@link String reason} this choice completed for the generation.
|
||||||
|
*/
|
||||||
|
String getFinishReason();
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,58 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.metadata;
|
||||||
|
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ResponseMetadata;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract Data Type (ADT) modeling common AI provider metadata returned in an AI
|
||||||
|
* response.
|
||||||
|
*
|
||||||
|
* 抽象数据类型(ADT)建模AI响应中返回的常见AI提供者元数据。
|
||||||
|
*
|
||||||
|
* @author John Blum
|
||||||
|
* @since 0.7.0
|
||||||
|
*/
|
||||||
|
public interface ChatResponseMetadata extends ResponseMetadata {
|
||||||
|
|
||||||
|
ChatResponseMetadata NULL = new ChatResponseMetadata() {
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns AI provider specific metadata on rate limits.
|
||||||
|
* @return AI provider specific metadata on rate limits.
|
||||||
|
* @see RateLimit
|
||||||
|
*/
|
||||||
|
default RateLimit getRateLimit() {
|
||||||
|
return new EmptyRateLimit();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns AI provider specific metadata on API usage.
|
||||||
|
* @return AI provider specific metadata on API usage.
|
||||||
|
* @see Usage
|
||||||
|
*/
|
||||||
|
default Usage getUsage() {
|
||||||
|
return new EmptyUsage();
|
||||||
|
}
|
||||||
|
|
||||||
|
default PromptMetadata getPromptMetadata() {
|
||||||
|
return PromptMetadata.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,59 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.metadata;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A RateLimit implementation that returns zero for all property getters
|
||||||
|
*
|
||||||
|
* @author John Blum
|
||||||
|
* @since 0.7.0
|
||||||
|
*/
|
||||||
|
public class EmptyRateLimit implements RateLimit {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Long getRequestsLimit() {
|
||||||
|
return 0L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Long getRequestsRemaining() {
|
||||||
|
return 0L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Duration getRequestsReset() {
|
||||||
|
return Duration.ZERO;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Long getTokensLimit() {
|
||||||
|
return 0L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Long getTokensRemaining() {
|
||||||
|
return 0L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Duration getTokensReset() {
|
||||||
|
return Duration.ZERO;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.metadata;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A EmpytUsage implementation that returns zero for all property getters
|
||||||
|
*
|
||||||
|
* @author John Blum
|
||||||
|
* @since 0.7.0
|
||||||
|
*/
|
||||||
|
public class EmptyUsage implements Usage {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Long getPromptTokens() {
|
||||||
|
return 0L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Long getGenerationTokens() {
|
||||||
|
return 0L;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,136 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.metadata;
|
||||||
|
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.stream.StreamSupport;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract Data Type (ADT) modeling metadata gathered by the AI during request
|
||||||
|
* processing.
|
||||||
|
*
|
||||||
|
* @author John Blum
|
||||||
|
* @since 0.7.0
|
||||||
|
*/
|
||||||
|
@FunctionalInterface
|
||||||
|
public interface PromptMetadata extends Iterable<PromptMetadata.PromptFilterMetadata> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factory method used to create empty {@link PromptMetadata} when the information is
|
||||||
|
* not supplied by the AI provider.
|
||||||
|
* @return empty {@link PromptMetadata}.
|
||||||
|
*/
|
||||||
|
static PromptMetadata empty() {
|
||||||
|
return of();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factory method used to create a new {@link PromptMetadata} composed of an array of
|
||||||
|
* {@link PromptFilterMetadata}.
|
||||||
|
* @param array array of {@link PromptFilterMetadata} used to compose the
|
||||||
|
* {@link PromptMetadata}.
|
||||||
|
* @return a new {@link PromptMetadata} composed of an array of
|
||||||
|
* {@link PromptFilterMetadata}.
|
||||||
|
*/
|
||||||
|
static <T> PromptMetadata of(PromptFilterMetadata... array) {
|
||||||
|
return of(Arrays.asList(array));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factory method used to create a new {@link PromptMetadata} composed of an
|
||||||
|
* {@link Iterable} of {@link PromptFilterMetadata}.
|
||||||
|
* @param iterable {@link Iterable} of {@link PromptFilterMetadata} used to compose
|
||||||
|
* the {@link PromptMetadata}.
|
||||||
|
* @return a new {@link PromptMetadata} composed of an {@link Iterable} of
|
||||||
|
* {@link PromptFilterMetadata}.
|
||||||
|
*/
|
||||||
|
static PromptMetadata of(Iterable<PromptFilterMetadata> iterable) {
|
||||||
|
Assert.notNull(iterable, "An Iterable of PromptFilterMetadata must not be null");
|
||||||
|
return iterable::iterator;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an {@link Optional} {@link PromptFilterMetadata} at the given index.
|
||||||
|
* @param promptIndex index of the {@link PromptFilterMetadata} contained in this
|
||||||
|
* {@link PromptMetadata}.
|
||||||
|
* @return {@link Optional} {@link PromptFilterMetadata} at the given index.
|
||||||
|
* @throws IllegalArgumentException if the prompt index is less than 0.
|
||||||
|
*/
|
||||||
|
default Optional<PromptFilterMetadata> findByPromptIndex(int promptIndex) {
|
||||||
|
|
||||||
|
Assert.isTrue(promptIndex > -1, "Prompt index [%d] must be greater than equal to 0".formatted(promptIndex));
|
||||||
|
|
||||||
|
return StreamSupport.stream(this.spliterator(), false)
|
||||||
|
.filter(promptFilterMetadata -> promptFilterMetadata.getPromptIndex() == promptIndex)
|
||||||
|
.findFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract Data Type (ADT) modeling filter metadata for all prompts sent during an AI
|
||||||
|
* request.
|
||||||
|
*/
|
||||||
|
interface PromptFilterMetadata {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factory method used to construct a new {@link PromptFilterMetadata} with the
|
||||||
|
* given prompt index and content filter metadata.
|
||||||
|
* @param promptIndex index of the prompt filter metadata contained in the AI
|
||||||
|
* response.
|
||||||
|
* @param contentFilterMetadata underlying AI provider metadata for filtering
|
||||||
|
* applied to prompt content.
|
||||||
|
* @return a new instance of {@link PromptFilterMetadata} with the given prompt
|
||||||
|
* index and content filter metadata.
|
||||||
|
*/
|
||||||
|
static PromptFilterMetadata from(int promptIndex, Object contentFilterMetadata) {
|
||||||
|
|
||||||
|
return new PromptFilterMetadata() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getPromptIndex() {
|
||||||
|
return promptIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public <T> T getContentFilterMetadata() {
|
||||||
|
return (T) contentFilterMetadata;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Index of the prompt filter metadata contained in the AI response.
|
||||||
|
* @return an {@link Integer index} fo the prompt filter metadata contained in the
|
||||||
|
* AI response.
|
||||||
|
*/
|
||||||
|
int getPromptIndex();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the underlying AI provider metadata for filtering applied to prompt
|
||||||
|
* content.
|
||||||
|
* @param <T> {@link Class Type} used to cast the filtered content metadata into
|
||||||
|
* the AI provider-specific type.
|
||||||
|
* @return the underlying AI provider metadata for filtering applied to prompt
|
||||||
|
* content.
|
||||||
|
*/
|
||||||
|
<T> T getContentFilterMetadata();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,84 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.metadata;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract Data Type (ADT) encapsulating metadata from an AI provider's API rate limits
|
||||||
|
* granted to the API key in use and the API key's current balance.
|
||||||
|
*
|
||||||
|
* @author John Blum
|
||||||
|
* @since 0.7.0
|
||||||
|
*/
|
||||||
|
public interface RateLimit {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the maximum number of requests that are permitted before exhausting the
|
||||||
|
* rate limit.
|
||||||
|
* @return an {@link Long} with the maximum number of requests that are permitted
|
||||||
|
* before exhausting the rate limit.
|
||||||
|
* @see #getRequestsRemaining()
|
||||||
|
*/
|
||||||
|
Long getRequestsLimit();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the remaining number of requests that are permitted before exhausting the
|
||||||
|
* {@link #getRequestsLimit() rate limit}.
|
||||||
|
* @return an {@link Long} with the remaining number of requests that are permitted
|
||||||
|
* before exhausting the {@link #getRequestsLimit() rate limit}.
|
||||||
|
* @see #getRequestsLimit()
|
||||||
|
*/
|
||||||
|
Long getRequestsRemaining();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link Duration time} until the rate limit (based on requests) resets
|
||||||
|
* to its {@link #getRequestsLimit() initial state}.
|
||||||
|
* @return a {@link Duration} representing the time until the rate limit (based on
|
||||||
|
* requests) resets to its {@link #getRequestsLimit() initial state}.
|
||||||
|
* @see #getRequestsLimit()
|
||||||
|
*/
|
||||||
|
Duration getRequestsReset();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the maximum number of tokens that are permitted before exhausting the rate
|
||||||
|
* limit.
|
||||||
|
* @return an {@link Long} with the maximum number of tokens that are permitted before
|
||||||
|
* exhausting the rate limit.
|
||||||
|
* @see #getTokensRemaining()
|
||||||
|
*/
|
||||||
|
Long getTokensLimit();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the remaining number of tokens that are permitted before exhausting the
|
||||||
|
* {@link #getTokensLimit() rate limit}.
|
||||||
|
* @return an {@link Long} with the remaining number of tokens that are permitted
|
||||||
|
* before exhausting the {@link #getTokensLimit() rate limit}.
|
||||||
|
* @see #getTokensLimit()
|
||||||
|
*/
|
||||||
|
Long getTokensRemaining();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link Duration time} until the rate limit (based on tokens) resets to
|
||||||
|
* its {@link #getTokensLimit() initial state}.
|
||||||
|
* @return a {@link Duration} with the time until the rate limit (based on tokens)
|
||||||
|
* resets to its {@link #getTokensLimit() initial state}.
|
||||||
|
* @see #getTokensLimit()
|
||||||
|
*/
|
||||||
|
Duration getTokensReset();
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,66 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.metadata;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 抽象数据类型(ADT)封装关于人工智能提供商API使用的元数据根据AI请求。
|
||||||
|
*
|
||||||
|
* Abstract Data Type (ADT) encapsulating metadata on the usage of an AI provider's API
|
||||||
|
* per AI request.
|
||||||
|
*
|
||||||
|
* @author John Blum
|
||||||
|
* @since 0.7.0
|
||||||
|
*/
|
||||||
|
public interface Usage {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 返回AI请求的{@literal prompt}中使用的令牌数。
|
||||||
|
* @返回一个{@link Long},其中包含在的{@literal提示符}中使用的令牌数AI请求。
|
||||||
|
*
|
||||||
|
* Returns the number of tokens used in the {@literal prompt} of the AI request.
|
||||||
|
* @return an {@link Long} with the number of tokens used in the {@literal prompt} of
|
||||||
|
* the AI request.
|
||||||
|
* @see #getGenerationTokens()
|
||||||
|
*/
|
||||||
|
Long getPromptTokens();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the number of tokens returned in the {@literal generation (aka completion)}
|
||||||
|
* of the AI's response.
|
||||||
|
* @return an {@link Long} with the number of tokens returned in the
|
||||||
|
* {@literal generation (aka completion)} of the AI's response.
|
||||||
|
* @see #getPromptTokens()
|
||||||
|
*/
|
||||||
|
Long getGenerationTokens();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the total number of tokens from both the {@literal prompt} of an AI request
|
||||||
|
* and {@literal generation} of the AI's response.
|
||||||
|
* @return the total number of tokens from both the {@literal prompt} of an AI request
|
||||||
|
* and {@literal generation} of the AI's response.
|
||||||
|
* @see #getPromptTokens()
|
||||||
|
* @see #getGenerationTokens()
|
||||||
|
*/
|
||||||
|
default Long getTotalTokens() {
|
||||||
|
Long promptTokens = getPromptTokens();
|
||||||
|
promptTokens = promptTokens != null ? promptTokens : 0;
|
||||||
|
Long completionTokens = getGenerationTokens();
|
||||||
|
completionTokens = completionTokens != null ? completionTokens : 0;
|
||||||
|
return promptTokens + completionTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,14 @@
|
|||||||
|
/**
|
||||||
|
* The org.sf.ai.chat package represents the bounded context for the Chat Model within the
|
||||||
|
* AI generative model domain. This package extends the core domain defined in
|
||||||
|
* org.sf.ai.generative, providing implementations specific to chat-based generative AI
|
||||||
|
* interactions.
|
||||||
|
*
|
||||||
|
* In line with Domain-Driven Design principles, this package includes implementations of
|
||||||
|
* entities and value objects specific to the chat context, such as ChatPrompt and
|
||||||
|
* ChatResponse, adhering to the ubiquitous language of chat interactions in AI models.
|
||||||
|
*
|
||||||
|
* This bounded context is designed to encapsulate all aspects of chat-based AI
|
||||||
|
* functionalities, maintaining a clear boundary from other contexts within the AI domain.
|
||||||
|
*/
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat;
|
@ -0,0 +1,55 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.AssistantMessage;
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class AssistantPromptTemplate extends PromptTemplate {
|
||||||
|
|
||||||
|
public AssistantPromptTemplate(String template) {
|
||||||
|
super(template);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AssistantPromptTemplate(Resource resource) {
|
||||||
|
super(resource);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Prompt create() {
|
||||||
|
return new Prompt(new AssistantMessage(render()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Prompt create(Map<String, Object> model) {
|
||||||
|
return new Prompt(new AssistantMessage(render(model)));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Message createMessage() {
|
||||||
|
return new AssistantMessage(render());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Message createMessage(Map<String, Object> model) {
|
||||||
|
return new AssistantMessage(render(model));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ModelOptions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 聊天选项代表了常见的选项,可在不同的聊天模式中移植。
|
||||||
|
*
|
||||||
|
* The ChatOptions represent the common options, portable across different chat models.
|
||||||
|
*/
|
||||||
|
public interface ChatOptions extends ModelOptions {
|
||||||
|
|
||||||
|
Float getTemperature();
|
||||||
|
|
||||||
|
void setTemperature(Float temperature);
|
||||||
|
|
||||||
|
Float getTopP();
|
||||||
|
|
||||||
|
void setTopP(Float topP);
|
||||||
|
|
||||||
|
Integer getTopK();
|
||||||
|
|
||||||
|
void setTopK(Integer topK);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
public class ChatOptionsBuilder {
|
||||||
|
|
||||||
|
private class ChatOptionsImpl implements ChatOptions {
|
||||||
|
|
||||||
|
private Float temperature;
|
||||||
|
|
||||||
|
private Float topP;
|
||||||
|
|
||||||
|
private Integer topK;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Float getTemperature() {
|
||||||
|
return temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setTemperature(Float temperature) {
|
||||||
|
this.temperature = temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Float getTopP() {
|
||||||
|
return topP;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setTopP(Float topP) {
|
||||||
|
this.topP = topP;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer getTopK() {
|
||||||
|
return topK;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setTopK(Integer topK) {
|
||||||
|
this.topK = topK;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private final ChatOptionsImpl options = new ChatOptionsImpl();
|
||||||
|
|
||||||
|
private ChatOptionsBuilder() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ChatOptionsBuilder builder() {
|
||||||
|
return new ChatOptionsBuilder();
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatOptionsBuilder withTemperature(Float temperature) {
|
||||||
|
options.setTemperature(temperature);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatOptionsBuilder withTopP(Float topP) {
|
||||||
|
options.setTopP(topP);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatOptionsBuilder withTopK(Integer topK) {
|
||||||
|
options.setTopK(topK);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatOptions build() {
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,87 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* PromptTemplate,用于将角色指定为字符串实现及其角色不足以满足您的需求。
|
||||||
|
*
|
||||||
|
* A PromptTemplate that lets you specify the role as a string should the current
|
||||||
|
* implementations and their roles not suffice for your needs.
|
||||||
|
*/
|
||||||
|
public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplateChatActions {
|
||||||
|
|
||||||
|
private final List<PromptTemplate> promptTemplates;
|
||||||
|
|
||||||
|
public ChatPromptTemplate(List<PromptTemplate> promptTemplates) {
|
||||||
|
this.promptTemplates = promptTemplates;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String render() {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
for (PromptTemplate promptTemplate : promptTemplates) {
|
||||||
|
sb.append(promptTemplate.render());
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String render(Map<String, Object> model) {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
for (PromptTemplate promptTemplate : promptTemplates) {
|
||||||
|
sb.append(promptTemplate.render(model));
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Message> createMessages() {
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
for (PromptTemplate promptTemplate : promptTemplates) {
|
||||||
|
messages.add(promptTemplate.createMessage());
|
||||||
|
}
|
||||||
|
return messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Message> createMessages(Map<String, Object> model) {
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
for (PromptTemplate promptTemplate : promptTemplates) {
|
||||||
|
messages.add(promptTemplate.createMessage(model));
|
||||||
|
}
|
||||||
|
return messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Prompt create() {
|
||||||
|
List<Message> messages = createMessages();
|
||||||
|
return new Prompt(messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Prompt create(Map<String, Object> model) {
|
||||||
|
List<Message> messages = createMessages(model);
|
||||||
|
return new Prompt(messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,27 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
public class FunctionPromptTemplate extends PromptTemplate {
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
public FunctionPromptTemplate(String template) {
|
||||||
|
super(template);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,99 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage;
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ModelOptions;
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ModelRequest;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文字内容
|
||||||
|
*/
|
||||||
|
public class Prompt implements ModelRequest<List<Message>> {
|
||||||
|
|
||||||
|
private final List<Message> messages;
|
||||||
|
|
||||||
|
private ChatOptions modelOptions;
|
||||||
|
|
||||||
|
public Prompt(String contents) {
|
||||||
|
this(new UserMessage(contents));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Prompt(Message message) {
|
||||||
|
this(Collections.singletonList(message));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Prompt(List<Message> messages) {
|
||||||
|
this.messages = messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Prompt(String contents, ChatOptions modelOptions) {
|
||||||
|
this(new UserMessage(contents), modelOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Prompt(Message message, ChatOptions modelOptions) {
|
||||||
|
this(Collections.singletonList(message), modelOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Prompt(List<Message> messages, ChatOptions modelOptions) {
|
||||||
|
this.messages = messages;
|
||||||
|
this.modelOptions = modelOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getContents() {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
for (Message message : getInstructions()) {
|
||||||
|
sb.append(message.getContent());
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ModelOptions getOptions() {
|
||||||
|
return this.modelOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Message> getInstructions() {
|
||||||
|
return this.messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Prompt{" + "messages=" + this.messages + ", modelOptions=" + this.modelOptions + '}';
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o)
|
||||||
|
return true;
|
||||||
|
if (!(o instanceof Prompt prompt))
|
||||||
|
return false;
|
||||||
|
return Objects.equals(this.messages, prompt.messages) && Objects.equals(this.modelOptions, prompt.modelOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(this.messages, this.modelOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,218 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage;
|
||||||
|
import cn.iocoder.yudao.framework.ai.parser.OutputParser;
|
||||||
|
import org.antlr.runtime.Token;
|
||||||
|
import org.antlr.runtime.TokenStream;
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
|
import org.springframework.util.StreamUtils;
|
||||||
|
import org.stringtemplate.v4.ST;
|
||||||
|
import org.stringtemplate.v4.compiler.STLexer;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 用户输入的提示内容模板
|
||||||
|
*
|
||||||
|
* 实现:提示词模板操作 提示词模板message相关操作
|
||||||
|
*/
|
||||||
|
public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {
|
||||||
|
|
||||||
|
private ST st;
|
||||||
|
|
||||||
|
private Map<String, Object> dynamicModel = new HashMap<>();
|
||||||
|
|
||||||
|
protected String template;
|
||||||
|
|
||||||
|
protected TemplateFormat templateFormat = TemplateFormat.ST;
|
||||||
|
|
||||||
|
private OutputParser outputParser;
|
||||||
|
|
||||||
|
public PromptTemplate(Resource resource) {
|
||||||
|
try (InputStream inputStream = resource.getInputStream()) {
|
||||||
|
this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
|
||||||
|
}
|
||||||
|
catch (IOException ex) {
|
||||||
|
throw new RuntimeException("Failed to read resource", ex);
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
this.st = new ST(this.template, '{', '}');
|
||||||
|
}
|
||||||
|
catch (Exception ex) {
|
||||||
|
throw new IllegalArgumentException("The template string is not valid.", ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public PromptTemplate(String template) {
|
||||||
|
this.template = template;
|
||||||
|
// If the template string is not valid, an exception will be thrown
|
||||||
|
try {
|
||||||
|
this.st = new ST(this.template, '{', '}');
|
||||||
|
}
|
||||||
|
catch (Exception ex) {
|
||||||
|
throw new IllegalArgumentException("The template string is not valid.", ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public PromptTemplate(String template, Map<String, Object> model) {
|
||||||
|
this.template = template;
|
||||||
|
// If the template string is not valid, an exception will be thrown
|
||||||
|
try {
|
||||||
|
this.st = new ST(this.template, '{', '}');
|
||||||
|
for (Entry<String, Object> entry : model.entrySet()) {
|
||||||
|
add(entry.getKey(), entry.getValue());
|
||||||
|
dynamicModel.put(entry.getKey(), entry.getValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (Exception ex) {
|
||||||
|
throw new IllegalArgumentException("The template string is not valid.", ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public PromptTemplate(Resource resource, Map<String, Object> model) {
|
||||||
|
try (InputStream inputStream = resource.getInputStream()) {
|
||||||
|
this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
|
||||||
|
}
|
||||||
|
catch (IOException ex) {
|
||||||
|
throw new RuntimeException("Failed to read resource", ex);
|
||||||
|
}
|
||||||
|
// If the template string is not valid, an exception will be thrown
|
||||||
|
try {
|
||||||
|
this.st = new ST(this.template, '{', '}');
|
||||||
|
for (Entry<String, Object> entry : model.entrySet()) {
|
||||||
|
add(entry.getKey(), entry.getValue());
|
||||||
|
dynamicModel.put(entry.getKey(), entry.getValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (Exception ex) {
|
||||||
|
throw new IllegalArgumentException("The template string is not valid.", ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public OutputParser getOutputParser() {
|
||||||
|
return outputParser;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setOutputParser(OutputParser outputParser) {
|
||||||
|
Objects.requireNonNull(outputParser, "Output Parser can not be null");
|
||||||
|
this.outputParser = outputParser;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void add(String name, Object value) {
|
||||||
|
this.st.add(name, value);
|
||||||
|
this.dynamicModel.put(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getTemplate() {
|
||||||
|
return this.template;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TemplateFormat getTemplateFormat() {
|
||||||
|
return this.templateFormat;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render Methods
|
||||||
|
@Override
|
||||||
|
public String render() {
|
||||||
|
validate(this.dynamicModel);
|
||||||
|
return st.render();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String render(Map<String, Object> model) {
|
||||||
|
validate(model);
|
||||||
|
for (Entry<String, Object> entry : model.entrySet()) {
|
||||||
|
if (st.getAttribute(entry.getKey()) != null) {
|
||||||
|
st.remove(entry.getKey());
|
||||||
|
}
|
||||||
|
if (entry.getValue() instanceof Resource) {
|
||||||
|
st.add(entry.getKey(), renderResource((Resource) entry.getValue()));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
st.add(entry.getKey(), entry.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return st.render();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String renderResource(Resource resource) {
|
||||||
|
try {
|
||||||
|
return resource.getContentAsString(Charset.defaultCharset());
|
||||||
|
}
|
||||||
|
catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
// try (InputStream inputStream = resource.getInputStream()) {
|
||||||
|
// return StreamUtils.copyToString(inputStream, Charset.defaultCharset());
|
||||||
|
// }
|
||||||
|
// catch (IOException ex) {
|
||||||
|
// throw new RuntimeException(ex);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Message createMessage() {
|
||||||
|
return new UserMessage(render());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Message createMessage(Map<String, Object> model) {
|
||||||
|
return new UserMessage(render(model));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Prompt create() {
|
||||||
|
return new Prompt(render(new HashMap<>()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Prompt create(Map<String, Object> model) {
|
||||||
|
return new Prompt(render(model));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<String> getInputVariables() {
|
||||||
|
TokenStream tokens = this.st.impl.tokens;
|
||||||
|
return IntStream.range(0, tokens.range())
|
||||||
|
.mapToObj(tokens::get)
|
||||||
|
.filter(token -> token.getType() == STLexer.ID)
|
||||||
|
.map(Token::getText)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void validate(Map<String, Object> model) {
|
||||||
|
Set<String> dynamicVariableNames = new HashSet<>(this.dynamicModel.keySet());
|
||||||
|
Set<String> modelVariables = new HashSet<>(model.keySet());
|
||||||
|
modelVariables.addAll(dynamicVariableNames);
|
||||||
|
Set<String> missingEntries = new HashSet<>(getInputVariables());
|
||||||
|
missingEntries.removeAll(modelVariables);
|
||||||
|
if (!missingEntries.isEmpty()) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"All template variables were not replaced. Missing variable names are " + missingEntries);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 提示词模板操作
|
||||||
|
*/
|
||||||
|
public interface PromptTemplateActions extends PromptTemplateStringActions {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Prompt
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Prompt create();
|
||||||
|
|
||||||
|
Prompt create(Map<String, Object> model);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,18 @@
|
|||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 聊天操作
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public interface PromptTemplateChatActions {
|
||||||
|
|
||||||
|
List<Message> createMessages();
|
||||||
|
|
||||||
|
List<Message> createMessages(Map<String, Object> model);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,24 @@
|
|||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 用户输入的提示内容 模板信息操作
|
||||||
|
*/
|
||||||
|
public interface PromptTemplateMessageActions {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建一个 message
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Message createMessage();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建一个 message
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Message createMessage(Map<String, Object> model);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,14 @@
|
|||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 提示次模板字符串操作
|
||||||
|
*/
|
||||||
|
public interface PromptTemplateStringActions {
|
||||||
|
|
||||||
|
String render();
|
||||||
|
|
||||||
|
String render(Map<String, Object> model);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,55 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.messages.SystemMessage;
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class SystemPromptTemplate extends PromptTemplate {
|
||||||
|
|
||||||
|
public SystemPromptTemplate(String template) {
|
||||||
|
super(template);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SystemPromptTemplate(Resource resource) {
|
||||||
|
super(resource);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Message createMessage() {
|
||||||
|
return new SystemMessage(render());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Message createMessage(Map<String, Object> model) {
|
||||||
|
return new SystemMessage(render(model));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Prompt create() {
|
||||||
|
return new Prompt(new SystemMessage(render()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Prompt create(Map<String, Object> model) {
|
||||||
|
return new Prompt(new SystemMessage(render(model)));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,42 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.chat.prompt;
|
||||||
|
|
||||||
|
public enum TemplateFormat {
|
||||||
|
|
||||||
|
ST("ST");
|
||||||
|
|
||||||
|
private final String value;
|
||||||
|
|
||||||
|
TemplateFormat(String value) {
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TemplateFormat fromValue(String value) {
|
||||||
|
for (TemplateFormat templateFormat : TemplateFormat.values()) {
|
||||||
|
if (templateFormat.getValue().equals(value)) {
|
||||||
|
return templateFormat;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("Invalid TemplateFormat value: " + value);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The ModelClient interface provides a generic API for invoking AI models. It is designed
|
||||||
|
* to handle the interaction with various types of AI models by abstracting the process of
|
||||||
|
* sending requests and receiving responses. The interface uses Java generics to
|
||||||
|
* accommodate different types of requests and responses, enhancing flexibility and
|
||||||
|
* adaptability across different AI model implementations.
|
||||||
|
*
|
||||||
|
* @param <TReq> the generic type of the request to the AI model
|
||||||
|
* @param <TRes> the generic type of the response from the AI model
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public interface ModelClient<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Executes a method call to the AI model.
|
||||||
|
* @param request the request object to be sent to the AI model
|
||||||
|
* @return the response from the AI model
|
||||||
|
*/
|
||||||
|
TRes call(TReq request) throws Throwable;
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,31 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface representing the customizable options for AI model interactions. This marker
|
||||||
|
* interface allows for the specification of various settings and parameters that can
|
||||||
|
* influence the behavior and output of AI models. It is designed to provide flexibility
|
||||||
|
* and adaptability in different AI scenarios, ensuring that the AI models can be
|
||||||
|
* fine-tuned according to specific requirements.
|
||||||
|
*
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public interface ModelOptions {
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,387 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.github.victools.jsonschema.generator.*;
|
||||||
|
import com.github.victools.jsonschema.module.jackson.JacksonModule;
|
||||||
|
import com.github.victools.jsonschema.module.jackson.JacksonOption;
|
||||||
|
import com.github.victools.jsonschema.module.swagger2.Swagger2Module;
|
||||||
|
import org.springframework.beans.BeanWrapper;
|
||||||
|
import org.springframework.beans.BeanWrapperImpl;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.beans.PropertyDescriptor;
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility class for manipulating {@link ModelOptions} objects.
|
||||||
|
*
|
||||||
|
* @author Christian Tzolov
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public final class ModelOptionsUtils {
|
||||||
|
|
||||||
|
private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper()
|
||||||
|
.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
|
||||||
|
.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);
|
||||||
|
|
||||||
|
private final static List<String> BEAN_MERGE_FIELD_EXCISIONS = List.of("class");
|
||||||
|
|
||||||
|
private static ConcurrentHashMap<Class<?>, List<String>> REQUEST_FIELD_NAMES_PER_CLASS = new ConcurrentHashMap<Class<?>, List<String>>();
|
||||||
|
|
||||||
|
private static AtomicReference<SchemaGenerator> SCHEMA_GENERATOR_CACHE = new AtomicReference<>();
|
||||||
|
|
||||||
|
private ModelOptionsUtils() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the given JSON string to a Map of String and Object.
|
||||||
|
* @param json the JSON string to convert to a Map.
|
||||||
|
* @return the converted Map.
|
||||||
|
*/
|
||||||
|
public static Map<String, Object> jsonToMap(String json) {
|
||||||
|
try {
|
||||||
|
return OBJECT_MAPPER.readValue(json, MAP_TYPE_REF);
|
||||||
|
}
|
||||||
|
catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static TypeReference<HashMap<String, Object>> MAP_TYPE_REF = new TypeReference<HashMap<String, Object>>() {
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the given JSON string to an Object of the given type.
|
||||||
|
* @param <T> the type of the object to return.
|
||||||
|
* @param json the JSON string to convert to an object.
|
||||||
|
* @param type the type of the object to return.
|
||||||
|
* @return Object instance of the given type.
|
||||||
|
*/
|
||||||
|
public static <T> T jsonToObject(String json, Class<T> type) {
|
||||||
|
try {
|
||||||
|
return OBJECT_MAPPER.readValue(json, type);
|
||||||
|
}
|
||||||
|
catch (Exception e) {
|
||||||
|
throw new RuntimeException("Failed to json: " + json, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the given object to a JSON string.
|
||||||
|
* @param object the object to convert to a JSON string.
|
||||||
|
* @return the JSON string.
|
||||||
|
*/
|
||||||
|
public static String toJsonString(Object object) {
|
||||||
|
try {
|
||||||
|
return OBJECT_MAPPER.writeValueAsString(object);
|
||||||
|
}
|
||||||
|
catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Merges the source object into the target object and returns an object represented
|
||||||
|
* by the given class. The JSON property names are used to match the fields to merge.
|
||||||
|
* The source non-null values override the target values with the same field name. The
|
||||||
|
* source null values are ignored. If the acceptedFieldNames is not empty, only the
|
||||||
|
* fields with the given names are merged and returned. If the acceptedFieldNames is
|
||||||
|
* empty, use the {@code @JsonProperty} names, inferred from the provided clazz.
|
||||||
|
* @param <T> they type of the class to return.
|
||||||
|
* @param source the source object to merge.
|
||||||
|
* @param target the target object to merge into.
|
||||||
|
* @param clazz the class to return.
|
||||||
|
* @param acceptedFieldNames the list of field names accepted for the target object.
|
||||||
|
* @return the merged object represented by the given class.
|
||||||
|
*/
|
||||||
|
public static <T> T merge(Object source, Object target, Class<T> clazz, List<String> acceptedFieldNames) {
|
||||||
|
|
||||||
|
if (source == null) {
|
||||||
|
source = Map.of();
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> requestFieldNames = CollectionUtils.isEmpty(acceptedFieldNames)
|
||||||
|
? REQUEST_FIELD_NAMES_PER_CLASS.computeIfAbsent(clazz, ModelOptionsUtils::getJsonPropertyValues)
|
||||||
|
: acceptedFieldNames;
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(requestFieldNames)) {
|
||||||
|
throw new IllegalArgumentException("No @JsonProperty fields found in the " + clazz.getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Object> sourceMap = ModelOptionsUtils.objectToMap(source);
|
||||||
|
Map<String, Object> targetMap = ModelOptionsUtils.objectToMap(target);
|
||||||
|
|
||||||
|
targetMap.putAll(sourceMap.entrySet()
|
||||||
|
.stream()
|
||||||
|
.filter(e -> e.getValue() != null)
|
||||||
|
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())));
|
||||||
|
|
||||||
|
targetMap = targetMap.entrySet()
|
||||||
|
.stream()
|
||||||
|
.filter(e -> requestFieldNames.contains(e.getKey()))
|
||||||
|
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()));
|
||||||
|
|
||||||
|
return ModelOptionsUtils.mapToClass(targetMap, clazz);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Merges the source object into the target object and returns an object represented
|
||||||
|
* by the given class. The JSON property names are used to match the fields to merge.
|
||||||
|
* The source non-null values override the target values with the same field name. The
|
||||||
|
* source null values are ignored. Returns the only field names that match the
|
||||||
|
* {@code @JsonProperty} names, inferred from the provided clazz.
|
||||||
|
* @param <T> they type of the class to return.
|
||||||
|
* @param source the source object to merge.
|
||||||
|
* @param target the target object to merge into.
|
||||||
|
* @param clazz the class to return.
|
||||||
|
* @return the merged object represented by the given class.
|
||||||
|
*/
|
||||||
|
public static <T> T merge(Object source, Object target, Class<T> clazz) {
|
||||||
|
return ModelOptionsUtils.merge(source, target, clazz, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the given object to a Map.
|
||||||
|
* @param source the object to convert to a Map.
|
||||||
|
* @return the converted Map.
|
||||||
|
*/
|
||||||
|
public static Map<String, Object> objectToMap(Object source) {
|
||||||
|
if (source == null) {
|
||||||
|
return new HashMap<>();
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
String json = OBJECT_MAPPER.writeValueAsString(source);
|
||||||
|
return OBJECT_MAPPER.readValue(json, new TypeReference<Map<String, Object>>() {
|
||||||
|
})
|
||||||
|
.entrySet()
|
||||||
|
.stream()
|
||||||
|
.filter(e -> e.getValue() != null)
|
||||||
|
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()));
|
||||||
|
}
|
||||||
|
catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the given Map to the given class.
|
||||||
|
* @param <T> the type of the class to return.
|
||||||
|
* @param source the Map to convert to the given class.
|
||||||
|
* @param clazz the class to convert the Map to.
|
||||||
|
* @return the converted class.
|
||||||
|
*/
|
||||||
|
public static <T> T mapToClass(Map<String, Object> source, Class<T> clazz) {
|
||||||
|
try {
|
||||||
|
String json = OBJECT_MAPPER.writeValueAsString(source);
|
||||||
|
return OBJECT_MAPPER.readValue(json, clazz);
|
||||||
|
}
|
||||||
|
catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the list of name values of the {@link JsonProperty} annotations.
|
||||||
|
* @param clazz the class that contains fields annotated with {@link JsonProperty}.
|
||||||
|
* @return the list of values of the {@link JsonProperty} annotations.
|
||||||
|
*/
|
||||||
|
public static List<String> getJsonPropertyValues(Class<?> clazz) {
|
||||||
|
List<String> values = new ArrayList<>();
|
||||||
|
Field[] fields = clazz.getDeclaredFields();
|
||||||
|
for (Field field : fields) {
|
||||||
|
JsonProperty jsonPropertyAnnotation = field.getAnnotation(JsonProperty.class);
|
||||||
|
if (jsonPropertyAnnotation != null) {
|
||||||
|
values.add(jsonPropertyAnnotation.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a new instance of the targetBeanClazz that copies the bean values from the
|
||||||
|
* sourceBean instance.
|
||||||
|
* @param sourceBean the source bean to copy the values from.
|
||||||
|
* @param sourceInterfaceClazz the source interface class. Only the fields with the
|
||||||
|
* same name as the interface methods are copied. This allow the source object to be a
|
||||||
|
* subclass of the source interface with additional, non-interface fields.
|
||||||
|
* @param targetBeanClazz the target class, a subclass of the ChatOptions, to convert
|
||||||
|
* into.
|
||||||
|
* @param <T> the target class type.
|
||||||
|
* @return a new instance of the targetBeanClazz with the values from the sourceBean
|
||||||
|
* instance.
|
||||||
|
*/
|
||||||
|
public static <I, S extends I, T extends S> T copyToTarget(S sourceBean, Class<I> sourceInterfaceClazz,
|
||||||
|
Class<T> targetBeanClazz) {
|
||||||
|
|
||||||
|
Assert.notNull(sourceInterfaceClazz, "SourceOptionsClazz must not be null");
|
||||||
|
Assert.notNull(targetBeanClazz, "TargetOptionsClazz must not be null");
|
||||||
|
|
||||||
|
if (sourceBean == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sourceBean.getClass().isAssignableFrom(targetBeanClazz)) {
|
||||||
|
return (T) sourceBean;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
T targetOptions = targetBeanClazz.getConstructor().newInstance();
|
||||||
|
|
||||||
|
ModelOptionsUtils.mergeBeans(sourceBean, targetOptions, sourceInterfaceClazz, true);
|
||||||
|
|
||||||
|
return targetOptions;
|
||||||
|
}
|
||||||
|
catch (Exception e) {
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Failed to convert the " + sourceInterfaceClazz.getName() + " into " + targetBeanClazz.getName(),
|
||||||
|
e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Merges the source object into the target object. The source null values are
|
||||||
|
* ignored. Only objects with Getter and Setter methods are supported.
|
||||||
|
* @param <T> the type of the source and target object.
|
||||||
|
* @param source the source object to merge.
|
||||||
|
* @param target the target object to merge into.
|
||||||
|
* @param sourceInterfaceClazz the source interface class. Only the fields with the
|
||||||
|
* same name as the interface methods are merged. This allow the source object to be a
|
||||||
|
* subclass of the source interface with additional, non-interface fields.
|
||||||
|
* @param overrideNonNullTargetValues if true, the source non-null values override the
|
||||||
|
* target values with the same field name. If false, the source non-null values are
|
||||||
|
* ignored.
|
||||||
|
* @return the merged target object.
|
||||||
|
*/
|
||||||
|
public static <I, S extends I, T extends S> T mergeBeans(S source, T target, Class<I> sourceInterfaceClazz,
|
||||||
|
boolean overrideNonNullTargetValues) {
|
||||||
|
Assert.notNull(source, "Source object must not be null");
|
||||||
|
Assert.notNull(target, "Target object must not be null");
|
||||||
|
|
||||||
|
BeanWrapper sourceBeanWrap = new BeanWrapperImpl(source);
|
||||||
|
BeanWrapper targetBeanWrap = new BeanWrapperImpl(target);
|
||||||
|
|
||||||
|
List<String> interfaceNames = Arrays.stream(sourceInterfaceClazz.getMethods()).map(m -> m.getName()).toList();
|
||||||
|
|
||||||
|
for (PropertyDescriptor descriptor : sourceBeanWrap.getPropertyDescriptors()) {
|
||||||
|
|
||||||
|
if (!BEAN_MERGE_FIELD_EXCISIONS.contains(descriptor.getName())
|
||||||
|
&& interfaceNames.contains(toGetName(descriptor.getName()))) {
|
||||||
|
|
||||||
|
String propertyName = descriptor.getName();
|
||||||
|
Object value = sourceBeanWrap.getPropertyValue(propertyName);
|
||||||
|
|
||||||
|
// Copy value to the target object
|
||||||
|
if (value != null) {
|
||||||
|
var targetValue = targetBeanWrap.getPropertyValue(propertyName);
|
||||||
|
|
||||||
|
if (targetValue == null || overrideNonNullTargetValues) {
|
||||||
|
targetBeanWrap.setPropertyValue(propertyName, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return target;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String toGetName(String name) {
|
||||||
|
return "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates JSON Schema (version 2020_12) for the given class.
|
||||||
|
* @param clazz the class to generate JSON Schema for.
|
||||||
|
* @param toUpperCaseTypeValues if true, the type values are converted to upper case.
|
||||||
|
* @return the generated JSON Schema as a String.
|
||||||
|
*/
|
||||||
|
public static String getJsonSchema(Class<?> clazz, boolean toUpperCaseTypeValues) {
|
||||||
|
|
||||||
|
if (SCHEMA_GENERATOR_CACHE.get() == null) {
|
||||||
|
|
||||||
|
JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED);
|
||||||
|
Swagger2Module swaggerModule = new Swagger2Module();
|
||||||
|
|
||||||
|
SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12,
|
||||||
|
OptionPreset.PLAIN_JSON)
|
||||||
|
.with(Option.EXTRA_OPEN_API_FORMAT_VALUES)
|
||||||
|
.with(Option.PLAIN_DEFINITION_KEYS)
|
||||||
|
.with(swaggerModule)
|
||||||
|
.with(jacksonModule);
|
||||||
|
|
||||||
|
SchemaGeneratorConfig config = configBuilder.build();
|
||||||
|
SchemaGenerator generator = new SchemaGenerator(config);
|
||||||
|
SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator);
|
||||||
|
}
|
||||||
|
|
||||||
|
ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz);
|
||||||
|
if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
|
||||||
|
// version of it).
|
||||||
|
toUpperCaseTypeValues(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
return node.toPrettyString();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void toUpperCaseTypeValues(ObjectNode node) {
|
||||||
|
if (node == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (node.isObject()) {
|
||||||
|
node.fields().forEachRemaining(entry -> {
|
||||||
|
JsonNode value = entry.getValue();
|
||||||
|
if (value.isObject()) {
|
||||||
|
toUpperCaseTypeValues((ObjectNode) value);
|
||||||
|
}
|
||||||
|
else if (value.isArray()) {
|
||||||
|
((ArrayNode) value).elements().forEachRemaining(element -> {
|
||||||
|
if (element.isObject() || element.isArray()) {
|
||||||
|
toUpperCaseTypeValues((ObjectNode) element);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
else if (value.isTextual() && entry.getKey().equals("type")) {
|
||||||
|
String oldValue = ((ObjectNode) node).get("type").asText();
|
||||||
|
((ObjectNode) node).put("type", oldValue.toUpperCase());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
else if (node.isArray()) {
|
||||||
|
node.elements().forEachRemaining(element -> {
|
||||||
|
if (element.isObject() || element.isArray()) {
|
||||||
|
toUpperCaseTypeValues((ObjectNode) element);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,51 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 表示对AI模型的请求的接口。此接口封装了 与人工智能模型交互所需的必要信息,包括指令或 输入(通用类型T)和附加模型选项。它提供了一种标准化的方式
|
||||||
|
* 向人工智能模型发送请求,确保包括所有必要的细节,并且可以易于管理。
|
||||||
|
*
|
||||||
|
* Interface representing a request to an AI model. This interface encapsulates the
|
||||||
|
* necessary information required to interact with an AI model, including instructions or
|
||||||
|
* inputs (of generic type T) and additional model options. It provides a standardized way
|
||||||
|
* to send requests to AI models, ensuring that all necessary details are included and can
|
||||||
|
* be easily managed.
|
||||||
|
*
|
||||||
|
* @param <T> the type of instructions or input required by the AI model
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public interface ModelRequest<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索AI模型所需的指令或输入。 返回AI模型所需的指令或输入
|
||||||
|
*
|
||||||
|
* Retrieves the instructions or input required by the AI model.
|
||||||
|
* @return the instructions or input required by the AI model
|
||||||
|
*/
|
||||||
|
T getInstructions(); // required input
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索人工智能模型交互的可自定义选项。 返回AI模型交互的自定义选项
|
||||||
|
*
|
||||||
|
* Retrieves the customizable options for AI model interactions.
|
||||||
|
* @return the customizable options for AI model interactions
|
||||||
|
*/
|
||||||
|
ModelOptions getOptions();
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,62 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* 表示从AI模型接收到的响应的接口。此接口提供 访问AI模型生成的主要结果或结果列表的方法,以及 以及响应元数据。它是封装和管理的标准化方式
|
||||||
|
* 人工智能模型的输出,确保轻松检索和处理生成的信息
|
||||||
|
*
|
||||||
|
* Interface representing the response received from an AI model. This interface provides
|
||||||
|
* methods to access the main result or a list of results generated by the AI model, along
|
||||||
|
* with the response metadata. It serves as a standardized way to encapsulate and manage
|
||||||
|
* the output from AI models, ensuring easy retrieval and processing of the generated
|
||||||
|
* information.
|
||||||
|
*
|
||||||
|
* @param <T> the type of the result(s) provided by the AI model
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public interface ModelResponse<T extends ModelResult<?>> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索AI模型的结果。
|
||||||
|
*
|
||||||
|
* Retrieves the result of the AI model.
|
||||||
|
* @return the result generated by the AI model
|
||||||
|
*/
|
||||||
|
T getResult();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索AI模型生成的输出列表。
|
||||||
|
*
|
||||||
|
* Retrieves the list of generated outputs by the AI model.
|
||||||
|
* @return the list of generated outputs
|
||||||
|
*/
|
||||||
|
List<T> getResults();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索与AI模型的响应相关联的响应元数据。
|
||||||
|
*
|
||||||
|
* Retrieves the response metadata associated with the AI model's response.
|
||||||
|
* @return the response metadata
|
||||||
|
*/
|
||||||
|
ResponseMetadata getMetadata();
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,43 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This interface provides methods to access the main output of the AI model and the
|
||||||
|
* metadata associated with this result. It is designed to offer a standardized and
|
||||||
|
* comprehensive way to handle and interpret the outputs generated by AI models, catering
|
||||||
|
* to diverse AI applications and use cases.
|
||||||
|
*
|
||||||
|
* @param <T> the type of the output generated by the AI model
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public interface ModelResult<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves the output generated by the AI model.
|
||||||
|
* @return the output generated by the AI model
|
||||||
|
*/
|
||||||
|
T getOutput();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves the metadata associated with the result of an AI model.
|
||||||
|
* @return the metadata associated with the result
|
||||||
|
*/
|
||||||
|
ResultMetadata getMetadata();
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 表示与AI模型的响应相关联的元数据的接口。此接口 旨在提供有关人工智能生成反应的附加信息 模型,包括处理细节和模型特定数据。它是一种价值
|
||||||
|
* 核心领域内的对象,增强对人工智能模型的理解和管理 在各种应用中的响应。
|
||||||
|
*
|
||||||
|
* Interface representing metadata associated with an AI model's response. This interface
|
||||||
|
* is designed to provide additional information about the generative response from an AI
|
||||||
|
* model, including processing details and model-specific data. It serves as a value
|
||||||
|
* object within the core domain, enhancing the understanding and management of AI model
|
||||||
|
* responses in various applications.
|
||||||
|
*
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public interface ResponseMetadata {
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,31 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface representing metadata associated with the results of an AI model. This
|
||||||
|
* interface focuses on providing additional context and insights into the results
|
||||||
|
* generated by AI models. It could include information like computation time, model
|
||||||
|
* version, or other relevant details that enhance understanding and management of AI
|
||||||
|
* model outputs in various applications.
|
||||||
|
*
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public interface ResultMetadata {
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,43 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
||||||
|
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The StreamingModelClient interface provides a generic API for invoking a AI models with
|
||||||
|
* streaming response. It abstracts the process of sending requests and receiving a
|
||||||
|
* streaming responses. The interface uses Java generics to accommodate different types of
|
||||||
|
* requests and responses, enhancing flexibility and adaptability across different AI
|
||||||
|
* model implementations.
|
||||||
|
*
|
||||||
|
* @param <TReq> the generic type of the request to the AI model
|
||||||
|
* @param <TResChunk> the generic type of a single item in the streaming response from the
|
||||||
|
* AI model
|
||||||
|
* @author Christian Tzolov
|
||||||
|
* @since 0.8.0
|
||||||
|
*/
|
||||||
|
public interface StreamingModelClient<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Executes a method call to the AI model.
|
||||||
|
* @param request the request object to be sent to the AI model
|
||||||
|
* @return the streaming response from the AI model
|
||||||
|
*/
|
||||||
|
Flux<TResChunk> stream(TReq request);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,158 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model.function;
|
||||||
|
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public abstract class AbstractFunctionCallSupport<Msg, Req, Resp> {
|
||||||
|
|
||||||
|
protected final static boolean IS_RUNTIME_CALL = true;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The function callback register is used to resolve the function callbacks by name.
|
||||||
|
*/
|
||||||
|
protected final Map<String, FunctionCallback> functionCallbackRegister = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The function callback context is used to resolve the function callbacks by name
|
||||||
|
* from the Spring context. It is optional and usually used with Spring
|
||||||
|
* auto-configuration.
|
||||||
|
*/
|
||||||
|
protected final FunctionCallbackContext functionCallbackContext;
|
||||||
|
|
||||||
|
public AbstractFunctionCallSupport(FunctionCallbackContext functionCallbackContext) {
|
||||||
|
this.functionCallbackContext = functionCallbackContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, FunctionCallback> getFunctionCallbackRegister() {
|
||||||
|
return this.functionCallbackRegister;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Set<String> handleFunctionCallbackConfigurations(FunctionCallingOptions options, boolean isRuntimeCall) {
|
||||||
|
|
||||||
|
Set<String> functionToCall = new HashSet<>();
|
||||||
|
|
||||||
|
if (options != null) {
|
||||||
|
if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) {
|
||||||
|
options.getFunctionCallbacks().stream().forEach(functionCallback -> {
|
||||||
|
|
||||||
|
// Register the tool callback.
|
||||||
|
if (isRuntimeCall) {
|
||||||
|
this.functionCallbackRegister.put(functionCallback.getName(), functionCallback);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
this.functionCallbackRegister.putIfAbsent(functionCallback.getName(), functionCallback);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Automatically enable the function, usually from prompt callback.
|
||||||
|
if (isRuntimeCall) {
|
||||||
|
functionToCall.add(functionCallback.getName());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the explicitly enabled functions.
|
||||||
|
if (!CollectionUtils.isEmpty(options.getFunctions())) {
|
||||||
|
functionToCall.addAll(options.getFunctions());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return functionToCall;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resolve the function callbacks by name. Retrieve them from the registry or try to
|
||||||
|
* resolve them from the Application Context.
|
||||||
|
* @param functionNames Name of function callbacks to retrieve.
|
||||||
|
* @return list of resolved FunctionCallbacks.
|
||||||
|
*/
|
||||||
|
protected List<FunctionCallback> resolveFunctionCallbacks(Set<String> functionNames) {
|
||||||
|
|
||||||
|
List<FunctionCallback> retrievedFunctionCallbacks = new ArrayList<>();
|
||||||
|
|
||||||
|
for (String functionName : functionNames) {
|
||||||
|
if (!this.functionCallbackRegister.containsKey(functionName)) {
|
||||||
|
|
||||||
|
if (this.functionCallbackContext != null) {
|
||||||
|
FunctionCallback functionCallback = this.functionCallbackContext.getFunctionCallback(functionName,
|
||||||
|
null);
|
||||||
|
if (functionCallback != null) {
|
||||||
|
this.functionCallbackRegister.put(functionName, functionCallback);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"No function callback [" + functionName + "] fund in tht FunctionCallbackContext");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
throw new IllegalStateException("No function callback found for name: " + functionName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FunctionCallback functionCallback = this.functionCallbackRegister.get(functionName);
|
||||||
|
|
||||||
|
retrievedFunctionCallbacks.add(functionCallback);
|
||||||
|
}
|
||||||
|
|
||||||
|
return retrievedFunctionCallbacks;
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
protected Resp callWithFunctionSupport(Req request) {
|
||||||
|
Resp response = this.doChatCompletion(request);
|
||||||
|
return this.handleFunctionCallOrReturn(request, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Resp handleFunctionCallOrReturn(Req request, Resp response) {
|
||||||
|
|
||||||
|
if (!this.isToolFunctionCall(response)) {
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The chat completion tool call requires the complete conversation
|
||||||
|
// history. Including the initial user message.
|
||||||
|
List<Msg> conversationHistory = new ArrayList<>();
|
||||||
|
|
||||||
|
conversationHistory.addAll(this.doGetUserMessages(request));
|
||||||
|
|
||||||
|
Msg responseMessage = this.doGetToolResponseMessage(response);
|
||||||
|
|
||||||
|
// Add the assistant response to the message conversation history.
|
||||||
|
conversationHistory.add(responseMessage);
|
||||||
|
|
||||||
|
Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory);
|
||||||
|
|
||||||
|
return this.callWithFunctionSupport(newRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
|
||||||
|
List<Msg> conversationHistory);
|
||||||
|
|
||||||
|
abstract protected List<Msg> doGetUserMessages(Req request);
|
||||||
|
|
||||||
|
abstract protected Msg doGetToolResponseMessage(Resp response);
|
||||||
|
|
||||||
|
abstract protected Resp doChatCompletion(Req request);
|
||||||
|
|
||||||
|
abstract protected boolean isToolFunctionCall(Resp response);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,159 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model.function;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract implementation of the {@link FunctionCallback} for interacting with the
|
||||||
|
* Model's function calling protocol and a {@link Function} wrapping the interaction with
|
||||||
|
* the 3rd party service/function.
|
||||||
|
*
|
||||||
|
* Implement the {@code O apply(I request) } method to implement the interaction with the
|
||||||
|
* 3rd party service/function.
|
||||||
|
*
|
||||||
|
* The {@link #responseConverter} function is responsible to convert the 3rd party
|
||||||
|
* function's output type into a string expected by the LLM model.
|
||||||
|
*
|
||||||
|
* @param <I> the 3rd party service input type.
|
||||||
|
* @param <O> the 3rd party service output type.
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
abstract class AbstractFunctionCallback<I, O> implements Function<I, O>, FunctionCallback {
|
||||||
|
|
||||||
|
private final String name;
|
||||||
|
|
||||||
|
private final String description;
|
||||||
|
|
||||||
|
private final Class<I> inputType;
|
||||||
|
|
||||||
|
private final String inputTypeSchema;
|
||||||
|
|
||||||
|
private final ObjectMapper objectMapper;
|
||||||
|
|
||||||
|
private final Function<O, String> responseConverter;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a new {@link AbstractFunctionCallback} with the given name, description,
|
||||||
|
* input type and default object mapper.
|
||||||
|
* @param name Function name. Should be unique within the ChatClient's function
|
||||||
|
* registry.
|
||||||
|
* @param description Function description. Used as a "system prompt" by the model to
|
||||||
|
* decide if the function should be called.
|
||||||
|
* @param inputTypeSchema Used to compute, the argument's Schema (such as JSON Schema
|
||||||
|
* or OpenAPI Schema)required by the Model's function calling protocol.
|
||||||
|
* @param inputType Used to compute, the argument's JSON schema required by the
|
||||||
|
* Model's function calling protocol.
|
||||||
|
* @param responseConverter Used to convert the function's output type to a string.
|
||||||
|
* @param objectMapper Used to convert the function's input and output types to and
|
||||||
|
* from JSON.
|
||||||
|
*/
|
||||||
|
protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, Class<I> inputType,
|
||||||
|
Function<O, String> responseConverter, ObjectMapper objectMapper) {
|
||||||
|
Assert.notNull(name, "Name must not be null");
|
||||||
|
Assert.notNull(description, "Description must not be null");
|
||||||
|
Assert.notNull(inputType, "InputType must not be null");
|
||||||
|
Assert.notNull(inputTypeSchema, "InputTypeSchema must not be null");
|
||||||
|
Assert.notNull(responseConverter, "ResponseConverter must not be null");
|
||||||
|
Assert.notNull(objectMapper, "ObjectMapper must not be null");
|
||||||
|
this.name = name;
|
||||||
|
this.description = description;
|
||||||
|
this.inputType = inputType;
|
||||||
|
this.inputTypeSchema = inputTypeSchema;
|
||||||
|
this.responseConverter = responseConverter;
|
||||||
|
this.objectMapper = objectMapper;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return this.name;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getDescription() {
|
||||||
|
return this.description;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getInputTypeSchema() {
|
||||||
|
return this.inputTypeSchema;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String call(String functionArguments) {
|
||||||
|
|
||||||
|
// Convert the tool calls JSON arguments into a Java function request object.
|
||||||
|
I request = fromJson(functionArguments, inputType);
|
||||||
|
|
||||||
|
// extend conversation with function response.
|
||||||
|
return this.andThen(this.responseConverter).apply(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
private <T> T fromJson(String json, Class<T> targetClass) {
|
||||||
|
try {
|
||||||
|
return this.objectMapper.readValue(json, targetClass);
|
||||||
|
}
|
||||||
|
catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
final int prime = 31;
|
||||||
|
int result = 1;
|
||||||
|
result = prime * result + ((name == null) ? 0 : name.hashCode());
|
||||||
|
result = prime * result + ((description == null) ? 0 : description.hashCode());
|
||||||
|
result = prime * result + ((inputType == null) ? 0 : inputType.hashCode());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj)
|
||||||
|
return true;
|
||||||
|
if (obj == null)
|
||||||
|
return false;
|
||||||
|
if (getClass() != obj.getClass())
|
||||||
|
return false;
|
||||||
|
AbstractFunctionCallback other = (AbstractFunctionCallback) obj;
|
||||||
|
if (name == null) {
|
||||||
|
if (other.name != null)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (!name.equals(other.name))
|
||||||
|
return false;
|
||||||
|
if (description == null) {
|
||||||
|
if (other.description != null)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (!description.equals(other.description))
|
||||||
|
return false;
|
||||||
|
if (inputType == null) {
|
||||||
|
if (other.inputType != null)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (!inputType.equals(other.inputType))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,53 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model.function;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a model function call handler. Implementations are registered with the
|
||||||
|
* Models and called on prompts that trigger the function call.
|
||||||
|
*
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public interface FunctionCallback {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Returns the Function name. Unique within the model.
|
||||||
|
*/
|
||||||
|
public String getName();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Returns the function description. This description is used by the model do
|
||||||
|
* decide if the function should be called or not.
|
||||||
|
*/
|
||||||
|
public String getDescription();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Returns the JSON schema of the function input type.
|
||||||
|
*/
|
||||||
|
public String getInputTypeSchema();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when a model detects and triggers a function call. The model is responsible
|
||||||
|
* to pass the function arguments in the pre-configured JSON schema format.
|
||||||
|
* @param functionInput JSON string with the function arguments to be passed to the
|
||||||
|
* function. The arguments are defined as JSON schema usually registered with the the
|
||||||
|
* model.
|
||||||
|
* @return String containing the function call response.
|
||||||
|
*/
|
||||||
|
public String call(String functionInput);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,124 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package cn.iocoder.yudao.framework.ai.model.function;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonClassDescription;
|
||||||
|
import org.springframework.beans.BeansException;
|
||||||
|
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
|
||||||
|
import org.springframework.cloud.function.context.config.FunctionContextUtils;
|
||||||
|
import org.springframework.context.ApplicationContext;
|
||||||
|
import org.springframework.context.ApplicationContextAware;
|
||||||
|
import org.springframework.context.annotation.Description;
|
||||||
|
import org.springframework.context.support.GenericApplicationContext;
|
||||||
|
import org.springframework.lang.NonNull;
|
||||||
|
import org.springframework.lang.Nullable;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
|
import java.lang.reflect.Type;
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Spring {@link ApplicationContextAware} implementation that provides a way to retrieve
|
||||||
|
* a {@link Function} from the Spring context and wrap it into a {@link FunctionCallback}.
|
||||||
|
*
|
||||||
|
* The name of the function is determined by the bean name.
|
||||||
|
*
|
||||||
|
* The description of the function is determined by the following rules:
|
||||||
|
* <ul>
|
||||||
|
* <li>Provided as a default description</li>
|
||||||
|
* <li>Provided as a {@code @Description} annotation on the bean</li>
|
||||||
|
* <li>Provided as a {@code @JsonClassDescription} annotation on the input class</li>
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @author Christian Tzolov
|
||||||
|
* @author Christopher Smith
|
||||||
|
*/
|
||||||
|
public class FunctionCallbackContext implements ApplicationContextAware {
|
||||||
|
|
||||||
|
private GenericApplicationContext applicationContext;
|
||||||
|
|
||||||
|
private FunctionCallbackWrapper.Builder.SchemaType schemaType = FunctionCallbackWrapper.Builder.SchemaType.JSON_SCHEMA;
|
||||||
|
|
||||||
|
public void setSchemaType(FunctionCallbackWrapper.Builder.SchemaType schemaType) {
|
||||||
|
this.schemaType = schemaType;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setApplicationContext(@NonNull ApplicationContext applicationContext) throws BeansException {
|
||||||
|
this.applicationContext = (GenericApplicationContext) applicationContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings({ "rawtypes", "unchecked" })
|
||||||
|
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {
|
||||||
|
|
||||||
|
Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);
|
||||||
|
|
||||||
|
if (beanType == null) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Functional bean with name: " + beanName + " does not exist in the context.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Function call Bean must be of type Function. Found: " + beanType.getTypeName());
|
||||||
|
}
|
||||||
|
|
||||||
|
Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);
|
||||||
|
|
||||||
|
Class<?> functionInputClass = FunctionTypeUtils.getRawType(functionInputType);
|
||||||
|
String functionName = beanName;
|
||||||
|
String functionDescription = defaultDescription;
|
||||||
|
|
||||||
|
if (!StringUtils.hasText(functionDescription)) {
|
||||||
|
// Look for a Description annotation on the bean
|
||||||
|
Description descriptionAnnotation = applicationContext.findAnnotationOnBean(beanName, Description.class);
|
||||||
|
|
||||||
|
if (descriptionAnnotation != null) {
|
||||||
|
functionDescription = descriptionAnnotation.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!StringUtils.hasText(functionDescription)) {
|
||||||
|
// Look for a JsonClassDescription annotation on the input class
|
||||||
|
JsonClassDescription jsonClassDescriptionAnnotation = functionInputClass
|
||||||
|
.getAnnotation(JsonClassDescription.class);
|
||||||
|
if (jsonClassDescriptionAnnotation != null) {
|
||||||
|
functionDescription = jsonClassDescriptionAnnotation.value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!StringUtils.hasText(functionDescription)) {
|
||||||
|
throw new IllegalStateException("Could not determine function description."
|
||||||
|
+ "Please provide a description either as a default parameter, via @Description annotation on the bean "
|
||||||
|
+ "or @JsonClassDescription annotation on the input class.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Object bean = this.applicationContext.getBean(beanName);
|
||||||
|
|
||||||
|
if (bean instanceof Function<?, ?> function) {
|
||||||
|
return FunctionCallbackWrapper.builder(function)
|
||||||
|
.withName(functionName)
|
||||||
|
.withSchemaType(this.schemaType)
|
||||||
|
.withDescription(functionDescription)
|
||||||
|
.withInputType(functionInputClass)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
throw new IllegalArgumentException("Bean must be of type Function");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,140 @@
|
|||||||
|
package cn.iocoder.yudao.framework.ai.model.function;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.model.ModelOptionsUtils;
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Note that the underlying function is responsible for converting the output into format
|
||||||
|
* that can be consumed by the Model. The default implementation converts the output into
|
||||||
|
* String before sending it to the Model. Provide a custom function responseConverter
|
||||||
|
* implementation to override this.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class FunctionCallbackWrapper<I, O> extends AbstractFunctionCallback<I, O> {
|
||||||
|
|
||||||
|
private final Function<I, O> function;
|
||||||
|
|
||||||
|
private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Class<I> inputType,
|
||||||
|
Function<O, String> responseConverter, Function<I, O> function) {
|
||||||
|
super(name, description, inputTypeSchema, inputType, responseConverter,
|
||||||
|
new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false));
|
||||||
|
Assert.notNull(function, "Function must not be null");
|
||||||
|
this.function = function;
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private static <I, O> Class<I> resolveInputType(Function<I, O> function) {
|
||||||
|
return (Class<I>) TypeResolverHelper.getFunctionInputClass((Class<Function<I, O>>) function.getClass());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public O apply(I input) {
|
||||||
|
return this.function.apply(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <I, O> Builder<I, O> builder(Function<I, O> function) {
|
||||||
|
return new Builder<>(function);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder<I, O> {
|
||||||
|
|
||||||
|
public enum SchemaType {
|
||||||
|
|
||||||
|
JSON_SCHEMA, OPEN_API_SCHEMA
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private String description;
|
||||||
|
|
||||||
|
private Class<I> inputType;
|
||||||
|
|
||||||
|
private final Function<I, O> function;
|
||||||
|
|
||||||
|
private SchemaType schemaType = SchemaType.JSON_SCHEMA;
|
||||||
|
|
||||||
|
public Builder(Function<I, O> function) {
|
||||||
|
Assert.notNull(function, "Function must not be null");
|
||||||
|
this.function = function;
|
||||||
|
}
|
||||||
|
|
||||||
|
// By default the response is converted to a JSON string.
|
||||||
|
private Function<O, String> responseConverter = (response) -> ModelOptionsUtils.toJsonString(response);
|
||||||
|
|
||||||
|
private String inputTypeSchema;
|
||||||
|
|
||||||
|
private ObjectMapper objectMapper = new ObjectMapper()
|
||||||
|
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||||
|
|
||||||
|
public Builder<I, O> withName(String name) {
|
||||||
|
Assert.hasText(name, "Name must not be empty");
|
||||||
|
this.name = name;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder<I, O> withDescription(String description) {
|
||||||
|
Assert.hasText(description, "Description must not be empty");
|
||||||
|
this.description = description;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public Builder<I, O> withInputType(Class<?> inputType) {
|
||||||
|
this.inputType = (Class<I>) inputType;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder<I, O> withResponseConverter(Function<O, String> responseConverter) {
|
||||||
|
Assert.notNull(responseConverter, "ResponseConverter must not be null");
|
||||||
|
this.responseConverter = responseConverter;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder<I, O> withInputTypeSchema(String inputTypeSchema) {
|
||||||
|
Assert.hasText(inputTypeSchema, "InputTypeSchema must not be empty");
|
||||||
|
this.inputTypeSchema = inputTypeSchema;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder<I, O> withObjectMapper(ObjectMapper objectMapper) {
|
||||||
|
Assert.notNull(objectMapper, "ObjectMapper must not be null");
|
||||||
|
this.objectMapper = objectMapper;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder<I, O> withSchemaType(SchemaType schemaType) {
|
||||||
|
Assert.notNull(schemaType, "SchemaType must not be null");
|
||||||
|
this.schemaType = schemaType;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionCallbackWrapper<I, O> build() {
|
||||||
|
|
||||||
|
Assert.hasText(this.name, "Name must not be empty");
|
||||||
|
Assert.hasText(this.description, "Description must not be empty");
|
||||||
|
// Assert.notNull(this.inputType, "InputType must not be null");
|
||||||
|
Assert.notNull(this.function, "Function must not be null");
|
||||||
|
Assert.notNull(this.responseConverter, "ResponseConverter must not be null");
|
||||||
|
Assert.notNull(this.objectMapper, "ObjectMapper must not be null");
|
||||||
|
|
||||||
|
if (this.inputType == null) {
|
||||||
|
this.inputType = resolveInputType(this.function);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.inputTypeSchema == null) {
|
||||||
|
boolean upperCaseTypeValues = this.schemaType == SchemaType.OPEN_API_SCHEMA;
|
||||||
|
this.inputTypeSchema = ModelOptionsUtils.getJsonSchema(this.inputType, upperCaseTypeValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new FunctionCallbackWrapper<>(this.name, this.description, this.inputTypeSchema, this.inputType,
|
||||||
|
this.responseConverter, this.function);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,66 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model.function;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public interface FunctionCallingOptions {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Function Callbacks to be registered with the ChatClient. For Prompt Options the
|
||||||
|
* functionCallbacks are automatically enabled for the duration of the prompt
|
||||||
|
* execution. For Default Options the FunctionCallbacks are registered but disabled by
|
||||||
|
* default. You have to use "functions" property to list the function names from the
|
||||||
|
* ChatClient registry to be used in the chat completion requests.
|
||||||
|
* @return Return the Function Callbacks to be registered with the ChatClient.
|
||||||
|
*/
|
||||||
|
List<FunctionCallback> getFunctionCallbacks();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the Function Callbacks to be registered with the ChatClient.
|
||||||
|
* @param functionCallbacks the Function Callbacks to be registered with the
|
||||||
|
* ChatClient.
|
||||||
|
*/
|
||||||
|
void setFunctionCallbacks(List<FunctionCallback> functionCallbacks);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return List of function names from the ChatClient registry to be used in the next
|
||||||
|
* chat completion requests.
|
||||||
|
*/
|
||||||
|
Set<String> getFunctions();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the list of function names from the ChatClient registry to be used in the next
|
||||||
|
* chat completion requests.
|
||||||
|
* @param functions the list of function names from the ChatClient registry to be used
|
||||||
|
* in the next chat completion requests.
|
||||||
|
*/
|
||||||
|
void setFunctions(Set<String> functions);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Returns FunctionCallingOptionsBuilder to create a new instance of
|
||||||
|
* FunctionCallingOptions.
|
||||||
|
*/
|
||||||
|
public static FunctionCallingOptionsBuilder builder() {
|
||||||
|
return new FunctionCallingOptionsBuilder();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,150 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model.function;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builder for {@link FunctionCallingOptions}. Using the {@link FunctionCallingOptions}
|
||||||
|
* permits options portability between different AI providers that support
|
||||||
|
* function-calling.
|
||||||
|
*
|
||||||
|
* @author Christian Tzolov
|
||||||
|
* @since 0.8.1
|
||||||
|
*/
|
||||||
|
public class FunctionCallingOptionsBuilder {
|
||||||
|
|
||||||
|
private final PortableFunctionCallingOptions options;
|
||||||
|
|
||||||
|
public FunctionCallingOptionsBuilder() {
|
||||||
|
this.options = new PortableFunctionCallingOptions();
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionCallingOptionsBuilder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
|
||||||
|
this.options.setFunctionCallbacks(functionCallbacks);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionCallingOptionsBuilder withFunctionCallback(FunctionCallback functionCallback) {
|
||||||
|
Assert.notNull(functionCallback, "FunctionCallback must not be null");
|
||||||
|
this.options.getFunctionCallbacks().add(functionCallback);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionCallingOptionsBuilder withFunctions(Set<String> functions) {
|
||||||
|
this.options.setFunctions(functions);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionCallingOptionsBuilder withFunction(String function) {
|
||||||
|
Assert.notNull(function, "Function must not be null");
|
||||||
|
this.options.getFunctions().add(function);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionCallingOptionsBuilder withTemperature(Float temperature) {
|
||||||
|
this.options.setTemperature(temperature);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionCallingOptionsBuilder withTopP(Float topP) {
|
||||||
|
this.options.setTopP(topP);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionCallingOptionsBuilder withTopK(Integer topK) {
|
||||||
|
this.options.setTopK(topK);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PortableFunctionCallingOptions build() {
|
||||||
|
return this.options;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class PortableFunctionCallingOptions implements FunctionCallingOptions, ChatOptions {
|
||||||
|
|
||||||
|
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
|
||||||
|
|
||||||
|
private Set<String> functions = new HashSet<>();
|
||||||
|
|
||||||
|
private Float temperature;
|
||||||
|
|
||||||
|
private Float topP;
|
||||||
|
|
||||||
|
private Integer topK;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<FunctionCallback> getFunctionCallbacks() {
|
||||||
|
return this.functionCallbacks;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
|
||||||
|
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
|
||||||
|
this.functionCallbacks = functionCallbacks;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<String> getFunctions() {
|
||||||
|
return this.functions;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setFunctions(Set<String> functions) {
|
||||||
|
Assert.notNull(functions, "Functions must not be null");
|
||||||
|
this.functions = functions;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Float getTemperature() {
|
||||||
|
return this.temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setTemperature(Float temperature) {
|
||||||
|
this.temperature = temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Float getTopP() {
|
||||||
|
return this.topP;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setTopP(Float topP) {
|
||||||
|
this.topP = topP;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer getTopK() {
|
||||||
|
return this.topK;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setTopK(Integer topK) {
|
||||||
|
this.topK = topK;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,87 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024-2024 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.model.function;
|
||||||
|
|
||||||
|
import net.jodah.typetools.TypeResolver;
|
||||||
|
|
||||||
|
import java.lang.reflect.GenericArrayType;
|
||||||
|
import java.lang.reflect.ParameterizedType;
|
||||||
|
import java.lang.reflect.Type;
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public class TypeResolverHelper {
|
||||||
|
|
||||||
|
public static Class<?> getFunctionInputClass(Class<? extends Function<?, ?>> functionClass) {
|
||||||
|
return getFunctionArgumentClass(functionClass, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Class<?> getFunctionOutputClass(Class<? extends Function<?, ?>> functionClass) {
|
||||||
|
return getFunctionArgumentClass(functionClass, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Class<?> getFunctionArgumentClass(Class<? extends Function<?, ?>> functionClass, int argumentIndex) {
|
||||||
|
Type type = TypeResolver.reify(Function.class, functionClass);
|
||||||
|
|
||||||
|
var argumentType = type instanceof ParameterizedType
|
||||||
|
? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class;
|
||||||
|
|
||||||
|
return toRawClass(argumentType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Type getFunctionInputType(Class<? extends Function<?, ?>> functionClass) {
|
||||||
|
return getFunctionArgumentType(functionClass, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Type getFunctionOutputType(Class<? extends Function<?, ?>> functionClass) {
|
||||||
|
return getFunctionArgumentType(functionClass, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Type getFunctionArgumentType(Class<? extends Function<?, ?>> functionClass, int argumentIndex) {
|
||||||
|
Type functionType = TypeResolver.reify(Function.class, functionClass);
|
||||||
|
return getFunctionArgumentType(functionType, argumentIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Type getFunctionArgumentType(Type functionType, int argumentIndex) {
|
||||||
|
var argumentType = functionType instanceof ParameterizedType
|
||||||
|
? ((ParameterizedType) functionType).getActualTypeArguments()[argumentIndex] : Object.class;
|
||||||
|
|
||||||
|
return argumentType;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Effectively converts {@link Type} which could be {@link ParameterizedType} to raw
|
||||||
|
* Class (no generics).
|
||||||
|
* @param type actual {@link Type} instance
|
||||||
|
* @return instance of {@link Class} as raw representation of the provided
|
||||||
|
* {@link Type}
|
||||||
|
*/
|
||||||
|
public static Class<?> toRawClass(Type type) {
|
||||||
|
return type != null
|
||||||
|
? TypeResolver.resolveRawClass(type instanceof GenericArrayType ? type : TypeResolver.reify(type), null)
|
||||||
|
: null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// public static void main(String[] args) {
|
||||||
|
// Class<? extends Function<?, ?>> clazz = MockWeatherService.class;
|
||||||
|
// System.out.println(getFunctionInputType(clazz));
|
||||||
|
|
||||||
|
// }
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,11 @@
|
|||||||
|
/**
|
||||||
|
* Provides a set of interfaces and classes for a generic API designed to interact with
|
||||||
|
* various AI models. This package includes interfaces for handling AI model calls,
|
||||||
|
* requests, responses, results, and associated metadata. It is designed to offer a
|
||||||
|
* flexible and adaptable framework for interacting with different types of AI models,
|
||||||
|
* abstracting the complexities involved in model invocation and result processing. The
|
||||||
|
* use of generics enhances the API's capability to work with a wide range of models,
|
||||||
|
* ensuring a broad applicability across diverse AI scenarios.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
package cn.iocoder.yudao.framework.ai.model;
|
@ -0,0 +1,5 @@
|
|||||||
|
/**
|
||||||
|
* author: fansili
|
||||||
|
* time: 2024/3/12 20:29
|
||||||
|
*/
|
||||||
|
package cn.iocoder.yudao.framework.ai;
|
@ -0,0 +1,42 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.parser;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.support.DefaultConversionService;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract {@link OutputParser} implementation that uses a pre-configured
|
||||||
|
* {@link DefaultConversionService} to convert the LLM output into the desired type
|
||||||
|
* format.
|
||||||
|
*
|
||||||
|
* @param <T> Specifies the desired response type.
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public abstract class AbstractConversionServiceOutputParser<T> implements OutputParser<T> {
|
||||||
|
|
||||||
|
private final DefaultConversionService conversionService;
|
||||||
|
|
||||||
|
public AbstractConversionServiceOutputParser(DefaultConversionService conversionService) {
|
||||||
|
this.conversionService = conversionService;
|
||||||
|
}
|
||||||
|
|
||||||
|
public DefaultConversionService getConversionService() {
|
||||||
|
return conversionService;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,41 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.parser;
|
||||||
|
|
||||||
|
import org.springframework.messaging.converter.MessageConverter;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract {@link OutputParser} implementation that uses a pre-configured
|
||||||
|
* {@link MessageConverter} to convert the LLM output into the desired type format.
|
||||||
|
*
|
||||||
|
* @param <T> Specifies the desired response type.
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public abstract class AbstractMessageConverterOutputParser<T> implements OutputParser<T> {
|
||||||
|
|
||||||
|
private MessageConverter messageConverter;
|
||||||
|
|
||||||
|
public AbstractMessageConverterOutputParser(MessageConverter messageConverter) {
|
||||||
|
this.messageConverter = messageConverter;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageConverter getMessageConverter() {
|
||||||
|
return this.messageConverter;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,166 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.parser;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.core.util.DefaultIndenter;
|
||||||
|
import com.fasterxml.jackson.core.util.DefaultPrettyPrinter;
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectWriter;
|
||||||
|
import com.github.victools.jsonschema.generator.SchemaGenerator;
|
||||||
|
import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
|
||||||
|
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
|
||||||
|
import com.github.victools.jsonschema.module.jackson.JacksonModule;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON;
|
||||||
|
import static com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An implementation of {@link OutputParser} that transforms the LLM output to a specific
|
||||||
|
* object type using JSON schema. This parser works by generating a JSON schema based on a
|
||||||
|
* given Java class, which is then used to validate and transform the LLM output into the
|
||||||
|
* desired type.
|
||||||
|
*
|
||||||
|
* @param <T> The target type to which the output will be converted.
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @author Christian Tzolov
|
||||||
|
* @author Sebastian Ullrich
|
||||||
|
* @author Kirk Lund
|
||||||
|
*/
|
||||||
|
public class BeanOutputParser<T> implements OutputParser<T> {
|
||||||
|
|
||||||
|
/** Holds the generated JSON schema for the target type. */
|
||||||
|
private String jsonSchema;
|
||||||
|
|
||||||
|
/** The Java class representing the target type. */
|
||||||
|
@SuppressWarnings({ "FieldMayBeFinal", "rawtypes" })
|
||||||
|
private Class<T> clazz;
|
||||||
|
|
||||||
|
/** The object mapper used for deserialization and other JSON operations. */
|
||||||
|
@SuppressWarnings("FieldMayBeFinal")
|
||||||
|
private ObjectMapper objectMapper;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize with the target type's class.
|
||||||
|
* @param clazz The target type's class.
|
||||||
|
*/
|
||||||
|
public BeanOutputParser(Class<T> clazz) {
|
||||||
|
this(clazz, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize with the target type's class, a custom object mapper, and
|
||||||
|
* a line endings normalizer to ensure consistent line endings on any platform.
|
||||||
|
* @param clazz The target type's class.
|
||||||
|
* @param objectMapper Custom object mapper for JSON operations. endings.
|
||||||
|
*/
|
||||||
|
public BeanOutputParser(Class<T> clazz, ObjectMapper objectMapper) {
|
||||||
|
Objects.requireNonNull(clazz, "Java Class cannot be null;");
|
||||||
|
this.clazz = clazz;
|
||||||
|
this.objectMapper = objectMapper != null ? objectMapper : getObjectMapper();
|
||||||
|
generateSchema();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates the JSON schema for the target type.
|
||||||
|
*/
|
||||||
|
private void generateSchema() {
|
||||||
|
JacksonModule jacksonModule = new JacksonModule();
|
||||||
|
SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(DRAFT_2020_12, PLAIN_JSON)
|
||||||
|
.with(jacksonModule);
|
||||||
|
SchemaGeneratorConfig config = configBuilder.build();
|
||||||
|
SchemaGenerator generator = new SchemaGenerator(config);
|
||||||
|
JsonNode jsonNode = generator.generateSchema(this.clazz);
|
||||||
|
ObjectWriter objectWriter = new ObjectMapper()
|
||||||
|
.writer(new DefaultPrettyPrinter().withObjectIndenter(new DefaultIndenter().withLinefeed("\n")));
|
||||||
|
try {
|
||||||
|
this.jsonSchema = objectWriter.writeValueAsString(jsonNode);
|
||||||
|
}
|
||||||
|
catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException("Could not pretty print json schema for " + this.clazz, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
/**
|
||||||
|
* Parses the given text to transform it to the desired target type.
|
||||||
|
* @param text The LLM output in string format.
|
||||||
|
* @return The parsed output in the desired target type.
|
||||||
|
*/
|
||||||
|
public T parse(String text) {
|
||||||
|
try {
|
||||||
|
// If the response is a JSON Schema, extract the properties and use them as
|
||||||
|
// the response.
|
||||||
|
text = this.jsonSchemaToInstance(text);
|
||||||
|
return (T) this.objectMapper.readValue(text, this.clazz);
|
||||||
|
}
|
||||||
|
catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a JSON Schema to an instance based on a given text.
|
||||||
|
* @param text The JSON Schema in string format.
|
||||||
|
* @return The JSON instance generated from the JSON Schema, or the original text if
|
||||||
|
* the input is not a JSON Schema.
|
||||||
|
*/
|
||||||
|
private String jsonSchemaToInstance(String text) {
|
||||||
|
try {
|
||||||
|
Map<String, Object> map = this.objectMapper.readValue(text, Map.class);
|
||||||
|
if (map.containsKey("$schema")) {
|
||||||
|
return this.objectMapper.writeValueAsString(map.get("properties"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (Exception e) {
|
||||||
|
}
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configures and returns an object mapper for JSON operations.
|
||||||
|
* @return Configured object mapper.
|
||||||
|
*/
|
||||||
|
protected ObjectMapper getObjectMapper() {
|
||||||
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
|
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||||
|
return mapper;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provides the expected format of the response, instructing that it should adhere to
|
||||||
|
* the generated JSON schema.
|
||||||
|
* @return The instruction format string.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public String getFormat() {
|
||||||
|
String template = """
|
||||||
|
Your response should be in JSON format.
|
||||||
|
Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
|
||||||
|
Do not include markdown code blocks in your response.
|
||||||
|
Here is the JSON Schema instance your output must adhere to:
|
||||||
|
```%s```
|
||||||
|
""";
|
||||||
|
return String.format(template, this.jsonSchema);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,33 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.parser;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implementations of this interface provides instructions for how the output of a
|
||||||
|
* language generative should be formatted.
|
||||||
|
*
|
||||||
|
* @author Mark Pollack
|
||||||
|
*/
|
||||||
|
public interface FormatProvider {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Returns a string containing instructions for how the output of a language
|
||||||
|
* generative should be formatted.
|
||||||
|
*/
|
||||||
|
String getFormat();
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,48 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package cn.iocoder.yudao.framework.ai.parser;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.support.DefaultConversionService;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link OutputParser} implementation that uses a {@link DefaultConversionService} to
|
||||||
|
* convert the LLM output into a {@link List} instance.
|
||||||
|
*
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public class ListOutputParser extends AbstractConversionServiceOutputParser<List<String>> {
|
||||||
|
|
||||||
|
public ListOutputParser(DefaultConversionService defaultConversionService) {
|
||||||
|
super(defaultConversionService);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getFormat() {
|
||||||
|
return """
|
||||||
|
Your response should be a list of comma separated values
|
||||||
|
eg: `foo, bar, baz`
|
||||||
|
""";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> parse(String text) {
|
||||||
|
return getConversionService().convert(text, List.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.parser;
|
||||||
|
|
||||||
|
import org.springframework.messaging.Message;
|
||||||
|
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
|
||||||
|
import org.springframework.messaging.support.MessageBuilder;
|
||||||
|
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link OutputParser} implementation that uses a pre-configured
|
||||||
|
* {@link MappingJackson2MessageConverter} to convert the LLM output into a
|
||||||
|
* java.util.Map<String, Object> instance.
|
||||||
|
*
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public class MapOutputParser extends AbstractMessageConverterOutputParser<Map<String, Object>> {
|
||||||
|
|
||||||
|
public MapOutputParser() {
|
||||||
|
super(new MappingJackson2MessageConverter());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, Object> parse(String text) {
|
||||||
|
Message<?> message = MessageBuilder.withPayload(text.getBytes(StandardCharsets.UTF_8)).build();
|
||||||
|
return (Map) getMessageConverter().fromMessage(message, HashMap.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getFormat() {
|
||||||
|
String raw = """
|
||||||
|
Your response should be in JSON format.
|
||||||
|
The data structure for the JSON should match this Java class: %s
|
||||||
|
Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
|
||||||
|
""";
|
||||||
|
return String.format(raw, "java.util.HashMap");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,30 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.parser;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the (raw) LLM output into a structured responses of type. The
|
||||||
|
* {@link FormatProvider#getFormat()} method should provide the LLM prompt description of
|
||||||
|
* the desired format.
|
||||||
|
*
|
||||||
|
* @param <T> Specifies the desired response type.
|
||||||
|
* @author Mark Pollack
|
||||||
|
* @author Christian Tzolov
|
||||||
|
*/
|
||||||
|
public interface OutputParser<T> extends Parser<T>, FormatProvider {
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,24 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cn.iocoder.yudao.framework.ai.parser;
|
||||||
|
|
||||||
|
@FunctionalInterface
|
||||||
|
public interface Parser<T> {
|
||||||
|
|
||||||
|
T parse(String text);
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,12 @@
|
|||||||
|
# Output Parsing
|
||||||
|
|
||||||
|
* [Documentation](https://docs.spring.io/spring-ai/reference/concepts.html#_output_parsing)
|
||||||
|
* [Usage examples](https://github.com/spring-projects/spring-ai/blob/main/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java)
|
||||||
|
|
||||||
|
The output of AI models traditionally arrives as a java.util.String, even if you ask for the reply to be in JSON. It may be the correct JSON, but it isn’t a JSON data structure. It is just a string. Also, asking "for JSON" as part of the prompt isn’t 100% accurate.
|
||||||
|
|
||||||
|
This intricacy has led to the emergence of a specialized field involving the creation of prompts to yield the intended output, followed by parsing the resulting simple string into a usable data structure for application integration.
|
||||||
|
|
||||||
|
Output parsing employs meticulously crafted prompts, often necessitating multiple interactions with the model to achieve the desired formatting.
|
||||||
|
|
||||||
|
This challenge has prompted OpenAI to introduce 'OpenAI Functions' as a means to specify the desired output format from the model precisely.
|
Loading…
Reference in New Issue
Block a user