From 10a94c3ef20aeb6767f3f5327a2552bdfa28c6da Mon Sep 17 00:00:00 2001 From: cherishsince Date: Sat, 27 Apr 2024 18:29:58 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91=E5=A4=84?= =?UTF-8?q?=E7=90=86=E7=99=BE=E5=BA=A6=20system=20=E8=A7=92=E8=89=B2?= =?UTF-8?q?=E5=AE=9A=E5=88=B6=E5=A4=B1=E6=95=88=E9=97=AE=E9=A2=98=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/chatyiyan/YiYanChatClient.java | 62 ++++++++++++++----- .../framework/ai/chatyiyan/YiYanOptions.java | 10 ++- .../framework/ai/chat/YiYanChatTests.java | 31 ++++++++-- 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java index fae550da8..d95e9fb99 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java @@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.chatyiyan; import cn.hutool.core.bean.BeanUtil; import cn.iocoder.yudao.framework.ai.chat.*; +import cn.iocoder.yudao.framework.ai.chat.messages.Message; +import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi; @@ -9,6 +11,7 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletion; import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest; import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException; import lombok.extern.slf4j.Slf4j; +import org.jetbrains.annotations.NotNull; import org.springframework.http.ResponseEntity; import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryContext; @@ -18,10 +21,11 @@ import reactor.core.publisher.Flux; import java.time.Duration; import java.util.List; +import java.util.stream.Collectors; /** * 文心一言 - * + *

* author: fansili * time: 2024/3/8 19:11 */ @@ -52,7 +56,9 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient { public void onError(RetryContext context, RetryCallback callback, Throwable throwable) { log.warn("重试异常:" + context.getRetryCount(), throwable); - }; + } + + ; }) .build(); @@ -92,6 +98,42 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient { } private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + // 获取配置 + YiYanOptions useOptions = getYiYanOptions(prompt); + // 创建 request + + // tip: 百度的 system 不在 message 里面 + // tip:百度的 message 只有 user 和 assistant + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + + // 获取 user 和 assistant + List messageList = prompt.getInstructions().stream() + // 过滤 system + .filter(msg -> MessageType.SYSTEM != msg.getMessageType()) + .map(msg -> new YiYanChatCompletionRequest.Message() + .setRole(msg.getMessageType().getValue()) + .setContent(msg.getContent()) + ).toList(); + // 获取 system + String systemPrompt = prompt.getInstructions().stream() + .filter(msg -> MessageType.SYSTEM == msg.getMessageType()) + .map(Message::getContent) + .collect(Collectors.joining()); + + YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList); + // 复制 qianWenOptions 属性取 request(这里 options 属性和 request 基本保持一致) + // top: 由于遵循 spring-ai规范,支持在构建client的时候传入默认的 chatOptions + BeanUtil.copyProperties(useOptions, request); + request.setTop_p(useOptions.getTopP()); + request.setMax_output_tokens(useOptions.getMaxOutputTokens()); + request.setTemperature(useOptions.getTemperature()); + request.setSystem(systemPrompt); + // 设置 stream + request.setStream(stream); + return request; + } + + private @NotNull YiYanOptions getYiYanOptions(Prompt prompt) { // 两个都为null 则没有配置文件 if (yiYanOptions == null && prompt.getOptions() == null) { throw new ChatException("ChatOptions 未配置参数!"); @@ -106,19 +148,7 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient { throw new ChatException("Prompt 传入的不是 YiYanOptions!"); } // 转换 YiYanOptions - YiYanOptions qianWenOptions = (YiYanOptions) options; - // 创建 request - List messageList = prompt.getInstructions().stream().map( - msg -> new YiYanChatCompletionRequest.Message() - .setRole(msg.getMessageType().getValue()) - .setContent(msg.getContent()) - ).toList(); - YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList); - // 复制 qianWenOptions 属性取 request(这里 options 属性和 request 基本保持一致) - // top: 由于遵循 spring-ai规范,支持在构建client的时候传入默认的 chatOptions - BeanUtil.copyProperties(qianWenOptions, request); - // 设置 stream - request.setStream(stream); - return request; + YiYanOptions useOptions = (YiYanOptions) options; + return useOptions; } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java index 4b1bc3fe6..84f0ced4c 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java @@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.chatyiyan; import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions; import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest; -import com.fasterxml.jackson.annotation.JsonProperty; import lombok.Data; import lombok.experimental.Accessors; @@ -40,7 +39,7 @@ public class YiYanOptions implements ChatOptions { * (2)默认0.8,取值范围 [0, 1.0] * 必填:否 */ - private Float top_p; + private Float topP; /** * 通过对已生成的token增加惩罚,减少重复生成的现象。说明: * (1)值越大表示惩罚越大 @@ -84,7 +83,7 @@ public class YiYanOptions implements ChatOptions { * 指定模型最大输出token数,范围[2, 2048] * 必填:否 */ - private Integer max_output_tokens; + private Integer maxOutputTokens; /** * 指定响应内容的格式,说明: * (1)可选值: @@ -122,12 +121,12 @@ public class YiYanOptions implements ChatOptions { @Override public Float getTopP() { - return top_p; + return topP; } @Override public void setTopP(Float topP) { - this.top_p = topP; + this.topP = topP; } // 百度么有 topK @@ -139,6 +138,5 @@ public class YiYanOptions implements ChatOptions { @Override public void setTopK(Integer topK) { - } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java index 326f52552..df98541fc 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java @@ -1,5 +1,8 @@ package cn.iocoder.yudao.framework.ai.chat; +import cn.iocoder.yudao.framework.ai.chat.messages.Message; +import cn.iocoder.yudao.framework.ai.chat.messages.SystemMessage; +import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient; import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel; @@ -9,11 +12,13 @@ import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Flux; +import java.util.ArrayList; +import java.util.List; import java.util.Scanner; /** * chat 文心一言 - * + *

* author: fansili * time: 2024/3/12 20:59 */ @@ -29,18 +34,36 @@ public class YiYanChatTests { YiYanChatModel.ERNIE4_3_5_8K, 86400 ); - yiYanChatClient = new YiYanChatClient(yiYanApi, new YiYanOptions().setMax_output_tokens(2048)); + YiYanOptions yiYanOptions = new YiYanOptions(); + yiYanOptions.setMaxOutputTokens(2048); + yiYanOptions.setTopP(0.6f); + yiYanOptions.setTemperature(0.85f); + yiYanChatClient = new YiYanChatClient( + yiYanApi, + yiYanOptions + ); } @Test public void callTest() { - ChatResponse call = yiYanChatClient.call(new Prompt("什么编程语言最好?")); + + // tip: 百度的message 有特殊规则(最后一个message为当前请求的信息,前面的message为历史对话信息) + // tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11 + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。")); + messages.add(new UserMessage("长沙怎么样?")); + + ChatResponse call = yiYanChatClient.call(new Prompt(messages)); System.err.println(call.getResult()); } @Test public void streamTest() { - Flux fluxResponse = yiYanChatClient.stream(new Prompt("用java帮我写一个快排算法?")); + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。")); + messages.add(new UserMessage("长沙怎么样?")); + + Flux fluxResponse = yiYanChatClient.stream(new Prompt(messages)); fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent())); // 阻止退出 Scanner scanner = new Scanner(System.in);