适配讯飞星火 chatOptions

This commit is contained in:
cherishsince 2024-03-16 21:24:57 +08:00
parent f41e43713c
commit 94e9ee9590
8 changed files with 162 additions and 58 deletions

View File

@ -6,7 +6,6 @@ import cn.iocoder.yudao.framework.ai.chat.*;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions; import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanOptions;
import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException; import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
import com.aliyun.broadscope.bailian.sdk.models.*; import com.aliyun.broadscope.bailian.sdk.models.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -19,7 +18,6 @@ import reactor.core.publisher.Flux;
import java.time.Duration; import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**

View File

@ -39,22 +39,20 @@ public class XingHuoApi {
private String appKey; private String appKey;
private String secretKey; private String secretKey;
private WebClient webClient; private WebClient webClient;
private XingHuoChatModel useChatModel;
// 创建 WebSocketClient 实例 // 创建 WebSocketClient 实例
private ReactorNettyWebSocketClient socketClient = new ReactorNettyWebSocketClient(); private ReactorNettyWebSocketClient socketClient = new ReactorNettyWebSocketClient();
public XingHuoApi(String appId, String appKey, String secretKey, XingHuoChatModel useChatModel) { public XingHuoApi(String appId, String appKey, String secretKey) {
this.appId = appId; this.appId = appId;
this.appKey = appKey; this.appKey = appKey;
this.secretKey = secretKey; this.secretKey = secretKey;
this.useChatModel = useChatModel;
} }
public ResponseEntity<XingHuoChatCompletion> chatCompletionEntity(XingHuoChatCompletionRequest request) { public ResponseEntity<XingHuoChatCompletion> chatCompletionEntity(XingHuoChatCompletionRequest request, XingHuoChatModel xingHuoChatModel) {
String authUrl; String authUrl;
try { try {
authUrl = getAuthorizationUrl("spark-api.xf-yun.com", useChatModel.getUri()); // XingHuoChatModel useChatModel;
authUrl = getAuthorizationUrl("spark-api.xf-yun.com", xingHuoChatModel.getUri());
} catch (NoSuchAlgorithmException | InvalidKeyException e) { } catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -125,10 +123,10 @@ public class XingHuoApi {
return "wss://" + host + path + "?" + toParams; return "wss://" + host + path + "?" + toParams;
} }
public Flux<XingHuoChatCompletion> chatCompletionStream(XingHuoChatCompletionRequest request) { public Flux<XingHuoChatCompletion> chatCompletionStream(XingHuoChatCompletionRequest request, XingHuoChatModel xingHuoChatModel) {
String authUrl; String authUrl;
try { try {
authUrl = getAuthorizationUrl("spark-api.xf-yun.com", useChatModel.getUri()); authUrl = getAuthorizationUrl("spark-api.xf-yun.com", xingHuoChatModel.getUri());
} catch (NoSuchAlgorithmException | InvalidKeyException e) { } catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -1,13 +1,12 @@
package cn.iocoder.yudao.framework.ai.chatxinghuo; package cn.iocoder.yudao.framework.ai.chatxinghuo;
import cn.iocoder.yudao.framework.ai.chat.ChatClient; import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse; import cn.hutool.core.exceptions.ExceptionUtil;
import cn.iocoder.yudao.framework.ai.chat.Generation; import cn.iocoder.yudao.framework.ai.chat.*;
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient; import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoChatCompletion; import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoChatCompletion;
import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoChatCompletionRequest; import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoChatCompletionRequest;
import cn.iocoder.yudao.framework.ai.chatxinghuo.exception.XingHuoApiException;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryCallback;
@ -31,16 +30,19 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
private XingHuoApi xingHuoApi; private XingHuoApi xingHuoApi;
private XingHuoOptions xingHuoOptions;
public final RetryTemplate retryTemplate = RetryTemplate.builder() public final RetryTemplate retryTemplate = RetryTemplate.builder()
// 最大重试次数 10 // 最大重试次数 10
.maxAttempts(10) .maxAttempts(3)
.retryOn(XingHuoApiException.class) .retryOn(ChatException.class)
// 最大重试5次第一次间隔3000ms第二次3000ms * 2第三次3000ms * 3以此类推最大间隔3 * 60000ms // 最大重试5次第一次间隔3000ms第二次3000ms * 2第三次3000ms * 3以此类推最大间隔3 * 60000ms
.exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000)) .exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() { .withListener(new RetryListener() {
@Override @Override
public <T extends Object, E extends Throwable> void onError(RetryContext context, public <T extends Object, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) { RetryCallback<T, E> callback, Throwable throwable) {
System.err.println("正在重试... " + ExceptionUtil.getMessage(throwable));
log.warn("重试异常:" + context.getRetryCount(), throwable); log.warn("重试异常:" + context.getRetryCount(), throwable);
} }
@ -52,26 +54,67 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
this.xingHuoApi = xingHuoApi; this.xingHuoApi = xingHuoApi;
} }
public XingHuoChatClient(XingHuoApi xingHuoApi, XingHuoOptions xingHuoOptions) {
this.xingHuoApi = xingHuoApi;
this.xingHuoOptions = xingHuoOptions;
}
@Override @Override
public ChatResponse call(Prompt prompt) { public ChatResponse call(Prompt prompt) {
return this.retryTemplate.execute(ctx -> { return this.retryTemplate.execute(ctx -> {
// ctx 会有重试的信息 // ctx 会有重试的信息
// 获取 chatOptions 属性
XingHuoOptions chatOptions = this.getChatOptions(prompt);
// 创建 request 请求stream模式需要供应商支持 // 创建 request 请求stream模式需要供应商支持
XingHuoChatCompletionRequest request = this.createRequest(prompt, false); XingHuoChatCompletionRequest request = this.createRequest(prompt, chatOptions);
// 调用 callWithFunctionSupport 发送请求 // 调用 callWithFunctionSupport 发送请求
ResponseEntity<XingHuoChatCompletion> response = xingHuoApi.chatCompletionEntity(request); ResponseEntity<XingHuoChatCompletion> response = xingHuoApi.chatCompletionEntity(request, chatOptions.getDomain());
// 获取结果封装 ChatResponse // 获取结果封装 ChatResponse
return new ChatResponse(List.of(new Generation(response.getBody().getPayload().getChoices().getText().get(0).getContent()))); return new ChatResponse(List.of(new Generation(response.getBody().getPayload().getChoices().getText().get(0).getContent())));
}); });
} }
private XingHuoChatCompletionRequest createRequest(Prompt prompt, boolean b) { @Override
public Flux<ChatResponse> stream(Prompt prompt) {
// 获取 chatOptions 属性
XingHuoOptions chatOptions = this.getChatOptions(prompt);
// 创建 request 请求stream模式需要供应商支持
XingHuoChatCompletionRequest request = this.createRequest(prompt, chatOptions);
// 发送请求
Flux<XingHuoChatCompletion> response = this.xingHuoApi.chatCompletionStream(request, chatOptions.getDomain());
return response.map(res -> {
String content = res.getPayload().getChoices().getText().stream()
.map(item -> item.getContent()).collect(Collectors.joining());
return new ChatResponse(List.of(new Generation(content)));
});
}
private XingHuoOptions getChatOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (xingHuoOptions == null && prompt.getOptions() == null) {
throw new ChatException("ChatOptions 未配置参数!");
}
// 优先使用 Prompt 里面的 ChatOptions
ChatOptions options = xingHuoOptions;
if (prompt.getOptions() != null) {
options = (ChatOptions) prompt.getOptions();
}
// Prompt 里面是一个 ChatOptions用户可以随意传入这里做一下判断
if (!(options instanceof XingHuoOptions)) {
throw new ChatException("Prompt 传入的不是 XingHuoOptions!");
}
return (XingHuoOptions) options;
}
private XingHuoChatCompletionRequest createRequest(Prompt prompt, XingHuoOptions xingHuoOptions) {
// 创建 header // 创建 header
XingHuoChatCompletionRequest.Header header = new XingHuoChatCompletionRequest.Header().setApp_id(xingHuoApi.getAppId()); XingHuoChatCompletionRequest.Header header = new XingHuoChatCompletionRequest.Header().setApp_id(xingHuoApi.getAppId());
// 创建 params // 创建 params
XingHuoChatCompletionRequest.Parameter parameter = new XingHuoChatCompletionRequest.Parameter() XingHuoChatCompletionRequest.Parameter.Chat chatParameter = new XingHuoChatCompletionRequest.Parameter.Chat();
.setChat(new XingHuoChatCompletionRequest.Parameter.Chat().setDomain(xingHuoApi.getUseChatModel().getValue())); BeanUtil.copyProperties(xingHuoOptions, chatParameter);
chatParameter.setDomain(xingHuoOptions.getDomain().getValue());
XingHuoChatCompletionRequest.Parameter parameter = new XingHuoChatCompletionRequest.Parameter().setChat(chatParameter);
// 创建 payload text 信息 // 创建 payload text 信息
XingHuoChatCompletionRequest.Payload.Message.Text text = new XingHuoChatCompletionRequest.Payload.Message.Text(); XingHuoChatCompletionRequest.Payload.Message.Text text = new XingHuoChatCompletionRequest.Payload.Message.Text();
text.setRole(XingHuoChatCompletionRequest.Payload.Message.Text.Role.USER.getName()); text.setRole(XingHuoChatCompletionRequest.Payload.Message.Text.Role.USER.getName());
@ -85,17 +128,4 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
.setParameter(parameter) .setParameter(parameter)
.setPayload(payload); .setPayload(payload);
} }
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
// 创建 request 请求stream模式需要供应商支持
XingHuoChatCompletionRequest request = this.createRequest(prompt, false);
// 发送请求
Flux<XingHuoChatCompletion> response = this.xingHuoApi.chatCompletionStream(request);
return response.map(res -> {
String content = res.getPayload().getChoices().getText().stream()
.map(item -> item.getContent()).collect(Collectors.joining());
return new ChatResponse(List.of(new Generation(content)));
});
}
} }

View File

@ -0,0 +1,76 @@
package cn.iocoder.yudao.framework.ai.chatxinghuo;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* 讯飞星火
*
* author: fansili
* time: 2024/3/16 20:29
*/
@Data
@Accessors(chain = true)
public class XingHuoOptions implements ChatOptions {
/**
* https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
*
* 指定访问的领域:
* general指向V1.5版本;
* generalv2指向V2版本;
* generalv3指向V3版本;
* generalv3.5指向V3.5版本;
* 注意不同的取值对应的url也不一样
*/
private XingHuoChatModel domain = XingHuoChatModel.XING_HUO_3_5;
/**
* 取值范围 (01] 默认值0.5
*/
private Float temperature;
/**
* V1.5取值为[1,4096]
* V2.0V3.0和V3.5取值为[1,8192]默认为2048
*/
private Integer max_tokens;
/**
* 取值为[16],默认为4
*/
private Integer top_k;
/**
* 需要保障用户下的唯一性用于关联用户会话
*/
private String chat_id;
@Override
public Float getTemperature() {
return null;
}
@Override
public void setTemperature(Float temperature) {
}
@Override
public Float getTopP() {
return null;
}
@Override
public void setTopP(Float topP) {
}
@Override
public Integer getTopK() {
return null;
}
@Override
public void setTopK(Integer topK) {
}
}

View File

@ -45,9 +45,24 @@ public class XingHuoChatCompletionRequest {
* generalv3.5指向V3.5版本; * generalv3.5指向V3.5版本;
* 注意不同的取值对应的url也不一样 * 注意不同的取值对应的url也不一样
*/ */
private String domain = "general"; private String domain = "generalv3.5";
private Double temperature = 0.5; /**
private Integer max_tokens = 2048; * 取值范围 (01] 默认值0.5
*/
private Float temperature;
/**
* V1.5取值为[1,4096]
* V2.0V3.0和V3.5取值为[1,8192]默认为2048
*/
private Integer max_tokens;
/**
* 取值为[16],默认为4
*/
private Integer top_k;
/**
* 需要保障用户下的唯一性用于关联用户会话
*/
private String chat_id;
} }
} }

View File

@ -1,14 +0,0 @@
package cn.iocoder.yudao.framework.ai.chatxinghuo.exception;
/**
* 讯飞星火 exception
*
* author: fansili
* time: 2024/3/11 10:22
*/
public class XingHuoApiException extends RuntimeException {
public XingHuoApiException(String message) {
super(message);
}
}

View File

@ -23,9 +23,9 @@ public class QianWenChatClientTests {
@Before @Before
public void setup() { public void setup() {
QianWenApi qianWenApi = new QianWenApi( QianWenApi qianWenApi = new QianWenApi(
"LTAI5tNTVhXW4fLKUjMrr98z", "",
"ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP", "",
"f0c1088824594f589c8f10567ccd929f_p_efm", "",
null null
); );
qianWenChatClient = new QianWenChatClient( qianWenChatClient = new QianWenChatClient(

View File

@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoApi; import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoApi;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatClient; import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoOptions;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -28,9 +29,9 @@ public class XingHuoChatClientTests {
new XingHuoApi( new XingHuoApi(
"13c8cca6", "13c8cca6",
"cb6415c19d6162cda07b47316fcb0416", "cb6415c19d6162cda07b47316fcb0416",
"Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh", "Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh"
XingHuoChatModel.XING_HUO_3_5 ),
) new XingHuoOptions().setDomain(XingHuoChatModel.XING_HUO_3_5)
); );
} }