1、增加 chat service

2、可动态传入 modal,选择模型
This commit is contained in:
cherishsince 2024-04-14 17:24:02 +08:00
parent 7024c5ab60
commit ef701167b7
6 changed files with 154 additions and 9 deletions

View File

@ -25,4 +25,13 @@ public enum AiClientNameEnum {
private String name;
private String message;
/**
 * Looks up the enum constant whose {@code name} property equals the given value.
 *
 * @param name the client name to resolve; compared with {@link String#equals(Object)}
 *             against each constant's {@code getName()}
 * @return the first constant whose name matches
 * @throws IllegalArgumentException if no constant has the given name (including {@code null})
 */
public static AiClientNameEnum valueOfName(String name) {
    for (AiClientNameEnum nameEnum : AiClientNameEnum.values()) {
        if (nameEnum.getName().equals(name)) {
            return nameEnum;
        }
    }
    // The message previously named "MessageType", a different enum; report the real one.
    throw new IllegalArgumentException("Invalid AiClientNameEnum name: " + name);
}
}

View File

@ -2,20 +2,20 @@ package cn.iocoder.yudao.module.ai.controller;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.service.ChatService;
import cn.iocoder.yudao.module.ai.vo.ChatReq;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
@ -43,17 +43,16 @@ public class ChatController {
// Blocking chat endpoint (GET /chat). Per its @Operation description, the reply is returned
// only after the whole answer has been produced, so callers may wait a while.
// NOTE(review): this span is a diff hunk with its +/- markers stripped. The @RequestParam
// signature and the two aiClient lines below it appear to be the removed version; the
// ChatReq signature that delegates to chatService.chat(req) appears to be the replacement.
// Confirm against the real file before editing.
@Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!")
@GetMapping("/chat")
public CommonResult<String> chat(@RequestParam("prompt") String prompt) {
ChatResponse callRes = aiClient.call(new Prompt(prompt), AiClientNameEnum.QIAN_WEN.getName());
return CommonResult.success(callRes.getResult().getOutput().getContent());
// Request parameters are bound to ChatReq and validated (@Validated) before the call.
public CommonResult<String> chat(@Validated @ModelAttribute ChatReq req) {
return CommonResult.success(chatService.chat(req));
}
// TODO @芋艿调用这个方法异常Unable to handle the Spring Security Exception because the response is already committed.
@Operation(summary = "聊天-stream", description = "这里跟通义千问一样采用的是 Server-Sent Events (SSE) 通讯模式")
@GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter chatStream(@RequestParam("prompt") String prompt) {
public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) {
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
Flux<ChatResponse> streamResponse = aiClient.stream(new Prompt(prompt), AiClientNameEnum.QIAN_WEN.getName());
Flux<ChatResponse> streamResponse = chatService.chatStream(req);
streamResponse.subscribe(
new Consumer<ChatResponse>() {
@Override

View File

@ -17,7 +17,7 @@ import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/chat-role")
@AllArgsConstructor
public class AiChatRoleController {
public class ChatRoleController {
private final ChatRoleService chatRoleService;

View File

@ -0,0 +1,33 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.vo.ChatReq;
import reactor.core.publisher.Flux;
/**
 * Chat service: entry points for a one-shot chat call and a streaming chat call.
 *
 * @author fansili
 * @time 2024/4/14 15:55
 * @since 1.0
 */
public interface ChatService {
/**
 * Performs a chat call and returns the reply as a single string.
 *
 * @param req the chat request
 * @return the reply text
 */
String chat(ChatReq req);
/**
 * Performs a chat call and returns the reply as a stream of responses.
 *
 * @param req the chat request
 * @return a {@link Flux} of {@link ChatResponse} items
 */
Flux<ChatResponse> chatStream(ChatReq req);
}

View File

@ -0,0 +1,62 @@
package cn.iocoder.yudao.module.ai.service.impl;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.service.ChatService;
import cn.iocoder.yudao.module.ai.vo.ChatReq;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
/**
 * {@link ChatService} implementation that forwards requests to the {@link AiClient},
 * selecting the target client by the request's {@code modal} value.
 *
 * @author fansili
 * @time 2024/4/14 15:55
 * @since 1.0
 */
@Slf4j
@Service
@AllArgsConstructor
public class ChatServiceImpl implements ChatService {

    private final AiClient aiClient;

    /**
     * Sends the prompt to the selected client and returns the reply text.
     *
     * @param req the chat request; {@code modal} selects the client, {@code prompt} is the input
     * @return the content of the output of the response's result
     * @throws IllegalArgumentException if {@code modal} does not match any {@link AiClientNameEnum} name
     */
    @Override
    public String chat(ChatReq req) {
        // Resolve the client first so an unknown name fails before anything else is built.
        AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
        ChatResponse response = aiClient.call(buildPrompt(req), clientNameEnum.getName());
        return response.getResult().getOutput().getContent();
    }

    /**
     * Sends the prompt to the selected client and returns the streamed responses.
     *
     * @param req the chat request; {@code modal} selects the client, {@code prompt} is the input
     * @return the response stream produced by the client
     * @throws IllegalArgumentException if {@code modal} does not match any {@link AiClientNameEnum} name
     */
    @Override
    public Flux<ChatResponse> chatStream(ChatReq req) {
        AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
        return aiClient.stream(buildPrompt(req), clientNameEnum.getName());
    }

    /** Builds the {@link Prompt} sent to the client from the request's prompt text. */
    private Prompt buildPrompt(ChatReq req) {
        // NOTE(review): temperature / topP / topK are carried by ChatReq but are not forwarded
        // to the client here. The earlier req.setX(req.getX()) calls were self-assignments with
        // no effect and have been removed. Wire the options through once the client supports them.
        return new Prompt(req.getPrompt());
    }
}

View File

@ -0,0 +1,42 @@
package cn.iocoder.yudao.module.ai.vo;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Data;
import lombok.experimental.Accessors;
/**
 * Request object for the chat endpoints: the prompt text, optional sampling parameters,
 * and the name of the client/model to use (see {@link AiClientNameEnum}).
 *
 * @author fansili
 * @time 2024/4/14 16:12
 * @since 1.0
 */
@Data
@Accessors(chain = true)
public class ChatReq {

    /** Prompt text; required, at most 3000 characters. */
    @NotNull(message = "提示词不能为空!")
    @Size(max = 3000, message = "提示词最大3000个字符!")
    // The description was copied from an unrelated field ("1 issues, 2 pr"); describe the prompt.
    @Schema(description = "提示词")
    private String prompt;

    /** Temperature parameter controlling randomness and diversity; optional. */
    @Schema(description = "用于控制随机性和多样性的温度参数")
    private Float temperature;

    /** Nucleus-sampling probability threshold; optional. */
    // Javadoc residue ("\n * ") and the broken interval "0,1.0)" were removed from the description.
    @Schema(description = "生成时核采样方法的概率阈值。例如取值为0.8时仅保留累计概率之和大于等于0.8的概率分布中的token"
            + "作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。"
            + "默认值为0.8。注意取值不要大于等于1")
    private Float topP;

    /** Top-K sampling size, i.e. the size of the candidate set considered during generation; optional. */
    @Schema(description = "在生成消息时采用的Top-K采样大小表示模型生成回复时考虑的候选项集合的大小")
    private Integer topK;

    /**
     * Name of the AI client/model; required, at most 30 characters.
     * The field name {@code modal} is kept as-is because it is the request parameter name.
     */
    @Schema(description = "ai模型(查看 AiClientNameEnum)")
    @NotNull(message = "模型不能为空!")
    @Size(max = 30, message = "模型字符最大30个字符!")
    private String modal;
}