【代码评审】AI:写作实现

This commit is contained in:
YunaiV 2024-07-03 21:26:38 +08:00
parent 9ddd2eddf8
commit f20c27a7ef
7 changed files with 43 additions and 30 deletions

View File

@ -33,7 +33,7 @@ public interface ErrorCodeConstants {
// ========== API 聊天消息 1-040-004-000 ========== // ========== API 聊天消息 1-040-004-000 ==========
ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!"); ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!"); ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!");
// ========== API 绘画 1-040-005-000 ========== // ========== API 绘画 1-040-005-000 ==========
@ -48,6 +48,6 @@ public interface ErrorCodeConstants {
// ========== API 写作 1-022-007-000 ========== // ========== API 写作 1-022-007-000 ==========
ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!"); ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!");
ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "Stream 对话异常!"); ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "写作生成异常!");
} }

View File

@ -6,6 +6,7 @@ import lombok.Getter;
import java.util.Arrays; import java.util.Arrays;
// TODO @xin写作的几个不用枚举类哈直接搞字段就好了AiWriteTypeEnum 还是需要的哈
@AllArgsConstructor @AllArgsConstructor
@Getter @Getter
public enum AiLanguageEnum implements IntArrayValuable { public enum AiLanguageEnum implements IntArrayValuable {

View File

@ -1,5 +1,7 @@
package cn.iocoder.yudao.module.ai.controller.admin.write.vo; package cn.iocoder.yudao.module.ai.controller.admin.write.vo;
import cn.iocoder.yudao.framework.common.validation.InEnum;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
@ -8,6 +10,11 @@ import lombok.Data;
@Data @Data
public class AiWriteGenerateReqVO { public class AiWriteGenerateReqVO {
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@InEnum(AiWriteTypeEnum.class)
private Integer type;
// TODO @xin如果非必填可以不用写 requiredMode
@Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1.撰写田忌赛马2.回复:不批") @Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1.撰写田忌赛马2.回复:不批")
private String prompt; private String prompt;
@ -30,7 +37,4 @@ public class AiWriteGenerateReqVO {
@NotNull(message = "语言不能为空") @NotNull(message = "语言不能为空")
private Integer language; private Integer language;
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Integer type; //参见 AiWriteTypeEnum 枚举
} }

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.write; package cn.iocoder.yudao.module.ai.dal.dataobject.write;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
@ -34,6 +35,18 @@ public class AiWriteDO extends BaseDO {
*/ */
private Integer type; private Integer type;
/**
* 模型
*/
private String model;
/**
* 平台
*
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
/** /**
* 生成内容提示 * 生成内容提示
*/ */
@ -69,16 +82,6 @@ public class AiWriteDO extends BaseDO {
*/ */
private Integer language; private Integer language;
/**
* 模型
*/
private String model;
/**
* 平台
*/
private String platform;
/** /**
* 错误信息 * 错误信息
*/ */

View File

@ -11,7 +11,6 @@ import reactor.core.publisher.Flux;
*/ */
public interface AiWriteService { public interface AiWriteService {
/** /**
* 生成写作内容 * 生成写作内容
* *
@ -21,5 +20,4 @@ public interface AiWriteService {
*/ */
Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId); Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId);
} }

View File

@ -46,13 +46,12 @@ public class AiWriteServiceImpl implements AiWriteService {
@Resource @Resource
private AiChatModelService chatModalService; private AiChatModelService chatModalService;
@Resource @Resource
private AiWriteMapper writeMapper; private AiWriteMapper writeMapper; // TODO @xin上面空一行因为同类之间不要空行非同类空行
@Override @Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
//TODO 芋艿 写作的模型配置放哪好 先用千问测试
// 1.1 校验模型 // 1.1 校验模型
// TODO @xin可以约定大于配置先查询某个名字例如说写作助手然后写作助手上面是有个 model 可以使用它
AiChatModelDO model = chatModalService.validateChatModel(14L); AiChatModelDO model = chatModalService.validateChatModel(14L);
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId()); StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
@ -81,7 +80,6 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
} }
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
String template; String template;
Integer writeType = generateReqVO.getType(); Integer writeType = generateReqVO.getType();

View File

@ -9,12 +9,17 @@ import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.openai.api.ApiUtils; import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatusCode;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
/** /**
* Midjourney API * Midjourney API
@ -25,6 +30,16 @@ import java.util.Map;
@Slf4j @Slf4j
public class MidjourneyApi { public class MidjourneyApi {
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
HttpRequest request = response.request();
log.error("[midjourney-api] 调用失败!请求方式:[{}],请求地址:[{}],请求参数:[{}],响应数据: [{}]",
request.getMethod(), request.getURI(), reqParam, responseBody);
sink.error(new IllegalStateException("[midjourney-api] 调用失败!"));
});
private final WebClient webClient; private final WebClient webClient;
/** /**
@ -80,17 +95,11 @@ public class MidjourneyApi {
} }
private String post(String uri, Object body) { private String post(String uri, Object body) {
// 1发送 post 请求
return webClient.post() return webClient.post()
.uri(uri) .uri(uri)
.body(Mono.just(JsonUtils.toJsonString(body)), String.class) .body(Mono.just(JsonUtils.toJsonString(body)), String.class)
.retrieve() .retrieve()
.onStatus(status -> !status.is2xxSuccessful(), .onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(body))
response -> response.bodyToMono(String.class)
.handle((respBody, sink) -> {
log.error("【Midjourney api】调用失败resp: 【{}】", respBody);
sink.error(new IllegalStateException("【Midjourney api】调用失败"));
}))
.bodyToMono(String.class) .bodyToMono(String.class)
.block(); .block();
} }