diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java
index fae550da8..d95e9fb99 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java
@@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.chatyiyan;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.chat.*;
+import cn.iocoder.yudao.framework.ai.chat.messages.Message;
+import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
@@ -9,6 +11,7 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletion;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j;
+import org.jetbrains.annotations.NotNull;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
@@ -18,10 +21,11 @@ import reactor.core.publisher.Flux;
import java.time.Duration;
import java.util.List;
+import java.util.stream.Collectors;
/**
* 文心一言
- *
+ *
* author: fansili
* time: 2024/3/8 19:11
*/
@@ -52,7 +56,9 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
public void onError(RetryContext context,
RetryCallback callback, Throwable throwable) {
log.warn("重试异常:" + context.getRetryCount(), throwable);
- };
+ }
+
+ ;
})
.build();
@@ -92,6 +98,42 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
}
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
+ // 获取配置
+ YiYanOptions useOptions = getYiYanOptions(prompt);
+ // 创建 request
+
+ // tip: 百度的 system 不在 message 里面
+ // tip:百度的 message 只有 user 和 assistant
+ // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+
+ // 获取 user 和 assistant
+ List messageList = prompt.getInstructions().stream()
+ // 过滤 system
+ .filter(msg -> MessageType.SYSTEM != msg.getMessageType())
+ .map(msg -> new YiYanChatCompletionRequest.Message()
+ .setRole(msg.getMessageType().getValue())
+ .setContent(msg.getContent())
+ ).toList();
+ // 获取 system
+ String systemPrompt = prompt.getInstructions().stream()
+ .filter(msg -> MessageType.SYSTEM == msg.getMessageType())
+ .map(Message::getContent)
+ .collect(Collectors.joining());
+
+ YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
+ // 复制 qianWenOptions 属性取 request(这里 options 属性和 request 基本保持一致)
+ // top: 由于遵循 spring-ai规范,支持在构建client的时候传入默认的 chatOptions
+ BeanUtil.copyProperties(useOptions, request);
+ request.setTop_p(useOptions.getTopP());
+ request.setMax_output_tokens(useOptions.getMaxOutputTokens());
+ request.setTemperature(useOptions.getTemperature());
+ request.setSystem(systemPrompt);
+ // 设置 stream
+ request.setStream(stream);
+ return request;
+ }
+
+ private @NotNull YiYanOptions getYiYanOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (yiYanOptions == null && prompt.getOptions() == null) {
throw new ChatException("ChatOptions 未配置参数!");
@@ -106,19 +148,7 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
throw new ChatException("Prompt 传入的不是 YiYanOptions!");
}
// 转换 YiYanOptions
- YiYanOptions qianWenOptions = (YiYanOptions) options;
- // 创建 request
- List messageList = prompt.getInstructions().stream().map(
- msg -> new YiYanChatCompletionRequest.Message()
- .setRole(msg.getMessageType().getValue())
- .setContent(msg.getContent())
- ).toList();
- YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
- // 复制 qianWenOptions 属性取 request(这里 options 属性和 request 基本保持一致)
- // top: 由于遵循 spring-ai规范,支持在构建client的时候传入默认的 chatOptions
- BeanUtil.copyProperties(qianWenOptions, request);
- // 设置 stream
- request.setStream(stream);
- return request;
+ YiYanOptions useOptions = (YiYanOptions) options;
+ return useOptions;
}
}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java
index 4b1bc3fe6..84f0ced4c 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java
@@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.chatyiyan;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
-import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import lombok.experimental.Accessors;
@@ -40,7 +39,7 @@ public class YiYanOptions implements ChatOptions {
* (2)默认0.8,取值范围 [0, 1.0]
* 必填:否
*/
- private Float top_p;
+ private Float topP;
/**
* 通过对已生成的token增加惩罚,减少重复生成的现象。说明:
* (1)值越大表示惩罚越大
@@ -84,7 +83,7 @@ public class YiYanOptions implements ChatOptions {
* 指定模型最大输出token数,范围[2, 2048]
* 必填:否
*/
- private Integer max_output_tokens;
+ private Integer maxOutputTokens;
/**
* 指定响应内容的格式,说明:
* (1)可选值:
@@ -122,12 +121,12 @@ public class YiYanOptions implements ChatOptions {
@Override
public Float getTopP() {
- return top_p;
+ return topP;
}
@Override
public void setTopP(Float topP) {
- this.top_p = topP;
+ this.topP = topP;
}
// 百度么有 topK
@@ -139,6 +138,5 @@ public class YiYanOptions implements ChatOptions {
@Override
public void setTopK(Integer topK) {
-
}
}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java
index 326f52552..df98541fc 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java
@@ -1,5 +1,8 @@
package cn.iocoder.yudao.framework.ai.chat;
+import cn.iocoder.yudao.framework.ai.chat.messages.Message;
+import cn.iocoder.yudao.framework.ai.chat.messages.SystemMessage;
+import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
@@ -9,11 +12,13 @@ import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Flux;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Scanner;
/**
* chat 文心一言
- *
+ *
* author: fansili
* time: 2024/3/12 20:59
*/
@@ -29,18 +34,36 @@ public class YiYanChatTests {
YiYanChatModel.ERNIE4_3_5_8K,
86400
);
- yiYanChatClient = new YiYanChatClient(yiYanApi, new YiYanOptions().setMax_output_tokens(2048));
+ YiYanOptions yiYanOptions = new YiYanOptions();
+ yiYanOptions.setMaxOutputTokens(2048);
+ yiYanOptions.setTopP(0.6f);
+ yiYanOptions.setTemperature(0.85f);
+ yiYanChatClient = new YiYanChatClient(
+ yiYanApi,
+ yiYanOptions
+ );
}
@Test
public void callTest() {
- ChatResponse call = yiYanChatClient.call(new Prompt("什么编程语言最好?"));
+
+ // tip: 百度的message 有特殊规则(最后一个message为当前请求的信息,前面的message为历史对话信息)
+ // tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
+ List messages = new ArrayList<>();
+ messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
+ messages.add(new UserMessage("长沙怎么样?"));
+
+ ChatResponse call = yiYanChatClient.call(new Prompt(messages));
System.err.println(call.getResult());
}
@Test
public void streamTest() {
- Flux fluxResponse = yiYanChatClient.stream(new Prompt("用java帮我写一个快排算法?"));
+ List messages = new ArrayList<>();
+ messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
+ messages.add(new UserMessage("长沙怎么样?"));
+
+ Flux fluxResponse = yiYanChatClient.stream(new Prompt(messages));
fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
// 阻止退出
Scanner scanner = new Scanner(System.in);