【优化】处理百度 system 角色定制失效问题。

This commit is contained in:
cherishsince 2024-04-27 18:29:58 +08:00
parent c811f3a4c2
commit 10a94c3ef2
3 changed files with 77 additions and 26 deletions

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.chatyiyan;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.chat.*;
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
@ -9,6 +11,7 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletion;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
@ -18,10 +21,11 @@ import reactor.core.publisher.Flux;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
/**
* 文心一言
*
* <p>
* author: fansili
* time: 2024/3/8 19:11
*/
@ -52,7 +56,9 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
public <T extends Object, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
log.warn("重试异常:" + context.getRetryCount(), throwable);
};
}
;
})
.build();
@ -92,6 +98,42 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
}
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 获取配置
YiYanOptions useOptions = getYiYanOptions(prompt);
// 创建 request
// tip: 百度的 system 不在 message 里面
// tip百度的 message 只有 user assistant
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
// 获取 user assistant
List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream()
// 过滤 system
.filter(msg -> MessageType.SYSTEM != msg.getMessageType())
.map(msg -> new YiYanChatCompletionRequest.Message()
.setRole(msg.getMessageType().getValue())
.setContent(msg.getContent())
).toList();
// 获取 system
String systemPrompt = prompt.getInstructions().stream()
.filter(msg -> MessageType.SYSTEM == msg.getMessageType())
.map(Message::getContent)
.collect(Collectors.joining());
YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
// 复制 qianWenOptions 属性取 request这里 options 属性和 request 基本保持一致
// top: 由于遵循 spring-ai规范支持在构建client的时候传入默认的 chatOptions
BeanUtil.copyProperties(useOptions, request);
request.setTop_p(useOptions.getTopP());
request.setMax_output_tokens(useOptions.getMaxOutputTokens());
request.setTemperature(useOptions.getTemperature());
request.setSystem(systemPrompt);
// 设置 stream
request.setStream(stream);
return request;
}
private @NotNull YiYanOptions getYiYanOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (yiYanOptions == null && prompt.getOptions() == null) {
throw new ChatException("ChatOptions 未配置参数!");
@ -106,19 +148,7 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
throw new ChatException("Prompt 传入的不是 YiYanOptions!");
}
// 转换 YiYanOptions
YiYanOptions qianWenOptions = (YiYanOptions) options;
// 创建 request
List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream().map(
msg -> new YiYanChatCompletionRequest.Message()
.setRole(msg.getMessageType().getValue())
.setContent(msg.getContent())
).toList();
YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
// 复制 qianWenOptions 属性取 request这里 options 属性和 request 基本保持一致
// top: 由于遵循 spring-ai规范支持在构建client的时候传入默认的 chatOptions
BeanUtil.copyProperties(qianWenOptions, request);
// 设置 stream
request.setStream(stream);
return request;
YiYanOptions useOptions = (YiYanOptions) options;
return useOptions;
}
}

View File

@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.chatyiyan;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import lombok.experimental.Accessors;
@ -40,7 +39,7 @@ public class YiYanOptions implements ChatOptions {
* 2默认0.8取值范围 [0, 1.0]
* 必填
*/
private Float top_p;
private Float topP;
/**
* 通过对已生成的token增加惩罚减少重复生成的现象说明
* 1值越大表示惩罚越大
@ -84,7 +83,7 @@ public class YiYanOptions implements ChatOptions {
* 指定模型最大输出token数范围[2, 2048]
* 必填
*/
private Integer max_output_tokens;
private Integer maxOutputTokens;
/**
* 指定响应内容的格式说明
* 1可选值
@ -122,12 +121,12 @@ public class YiYanOptions implements ChatOptions {
@Override
public Float getTopP() {
return top_p;
return topP;
}
@Override
public void setTopP(Float topP) {
this.top_p = topP;
this.topP = topP;
}
// 百度么有 topK
@ -139,6 +138,5 @@ public class YiYanOptions implements ChatOptions {
@Override
public void setTopK(Integer topK) {
}
}

View File

@ -1,5 +1,8 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
import cn.iocoder.yudao.framework.ai.chat.messages.SystemMessage;
import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
@ -9,11 +12,13 @@ import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
/**
* chat 文心一言
*
* <p>
* author: fansili
* time: 2024/3/12 20:59
*/
@ -29,18 +34,36 @@ public class YiYanChatTests {
YiYanChatModel.ERNIE4_3_5_8K,
86400
);
yiYanChatClient = new YiYanChatClient(yiYanApi, new YiYanOptions().setMax_output_tokens(2048));
YiYanOptions yiYanOptions = new YiYanOptions();
yiYanOptions.setMaxOutputTokens(2048);
yiYanOptions.setTopP(0.6f);
yiYanOptions.setTemperature(0.85f);
yiYanChatClient = new YiYanChatClient(
yiYanApi,
yiYanOptions
);
}
@Test
public void callTest() {
ChatResponse call = yiYanChatClient.call(new Prompt("什么编程语言最好?"));
// tip: 百度的message 有特殊规则(最后一个message为当前请求的信息前面的message为历史对话信息)
// tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
messages.add(new UserMessage("长沙怎么样?"));
ChatResponse call = yiYanChatClient.call(new Prompt(messages));
System.err.println(call.getResult());
}
@Test
public void streamTest() {
Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt("用java帮我写一个快排算法"));
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
messages.add(new UserMessage("长沙怎么样?"));
Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
// 阻止退出
Scanner scanner = new Scanner(System.in);