【优化】AI 写作:格式、语气、语言等抽枚举

This commit is contained in:
xiaoxin 2024-07-03 14:53:08 +08:00
parent b80a76d115
commit 6dfbfe5167
8 changed files with 221 additions and 31 deletions

View File

@ -0,0 +1,44 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
@AllArgsConstructor
@Getter
public enum AiLanguageEnum implements IntArrayValuable {
AUTO(1, "自动"),
CHINESE(2, "中文"),
ENGLISH(3, "英文"),
KOREAN(4, "韩语"),
JAPANESE(5, "日语");
/**
* Language code
*/
private final Integer language;
/**
* Language name
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiLanguageEnum::getLanguage).toArray();
@Override
public int[] array() {
return ARRAYS;
}
public static AiLanguageEnum valueOfLanguage(Integer language) {
for (AiLanguageEnum languageEnum : AiLanguageEnum.values()) {
if (languageEnum.getLanguage().equals(language)) {
return languageEnum;
}
}
throw new IllegalArgumentException("未知语言: " + language);
}
}

View File

@ -0,0 +1,53 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 写作类型的枚举
*
* @author xiaoxin
*/
@AllArgsConstructor
@Getter
public enum AiWriteFormatEnum implements IntArrayValuable {
AUTO(1, "自动"),
EMAIL(2, "电子邮件"),
MESSAGE(3, "消息"),
COMMENT(4, "评论"),
PARAGRAPH(5, "段落"),
ARTICLE(6, "文章"),
BLOG_POST(7, "博客文章"),
IDEA(8, "想法"),
OUTLINE(9, "大纲");
/**
* 格式
*/
private final Integer format;
/**
* 格式名
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteFormatEnum::getFormat).toArray();
@Override
public int[] array() {
return ARRAYS;
}
public static AiWriteFormatEnum valueOfFormat(Integer format) {
for (AiWriteFormatEnum formatEnum : AiWriteFormatEnum.values()) {
if (formatEnum.getFormat().equals(format)) {
return formatEnum;
}
}
throw new IllegalArgumentException("未知格式: " + format);
}
}

View File

@ -0,0 +1,47 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 写作类型的枚举
*
* @author xiaoxin
*/
@AllArgsConstructor
@Getter
public enum AiWriteLengthEnum implements IntArrayValuable {
AUTO(1, "自动"),
SHORT(2, ""),
MEDIUM(3, ""),
LONG(4, "");
/**
* 长度
*/
private final Integer length;
/**
* 长度名
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteLengthEnum::getLength).toArray();
@Override
public int[] array() {
return ARRAYS;
}
public static AiWriteLengthEnum valueOfLength(Integer length) {
for (AiWriteLengthEnum lengthEnum : AiWriteLengthEnum.values()) {
if (lengthEnum.getLength().equals(length)) {
return lengthEnum;
}
}
throw new IllegalArgumentException("未知长度: " + length);
}
}

View File

@ -0,0 +1,46 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
@AllArgsConstructor
@Getter
public enum AiWriteToneEnum implements IntArrayValuable {
AUTO(1, "自动"),
FRIENDLY(2, "友善"),
CASUAL(3, "随意"),
KIND(4, "友好"),
PROFESSIONAL(5, "专业"),
HUMOROUS(6, "谈谐"),
INTERESTING(7, "有趣"),
FORMAL(8, "正式");
/**
* 语气
*/
private final Integer tone;
/**
* 语气名
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteToneEnum::getTone).toArray();
@Override
public int[] array() {
return ARRAYS;
}
public static AiWriteToneEnum valueOfTone(Integer tone) {
for (AiWriteToneEnum toneEnum : AiWriteToneEnum.values()) {
if (toneEnum.getTone().equals(tone)) {
return toneEnum;
}
}
throw new IllegalArgumentException("未知语气: " + tone);
}
}

View File

@ -28,7 +28,7 @@ public class AiWriteController {
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@PermitAll @PermitAll
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快") @Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
public Flux<CommonResult<String>> generateComposition(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) { public Flux<CommonResult<String>> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
return writeService.generateWriteContent(generateReqVO, getLoginUserId()); return writeService.generateWriteContent(generateReqVO, getLoginUserId());
} }
} }

View File

@ -1,36 +1,36 @@
package cn.iocoder.yudao.module.ai.controller.admin.write.vo; package cn.iocoder.yudao.module.ai.controller.admin.write.vo;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
@Schema(description = "管理后台 - AI 写作生成 Request VO") @Schema(description = "管理后台 - AI 写作生成 Request VO")
@Data @Data
public class AiWriteGenerateReqVO { public class AiWriteGenerateReqVO {
@Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马") @Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1.撰写:田忌赛马2.回复:不批")
private String contentPrompt; private String prompt;
@Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职") @Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职")
private String originalContent; private String originalContent;
@Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "中等") @Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotBlank(message = "长度不能为空") @NotNull(message = "长度不能为空")
private String length; private Integer length;
@Schema(description = "格式", requiredMode = Schema.RequiredMode.REQUIRED, example = "文章") @Schema(description = "格式", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotBlank(message = "格式不能为空") @NotNull(message = "格式不能为空")
private String format; private Integer format;
@Schema(description = "语气", requiredMode = Schema.RequiredMode.REQUIRED, example = "随意") @Schema(description = "语气", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotBlank(message = "语气不能为空") @NotNull(message = "语气不能为空")
private String tone; private Integer tone;
@Schema(description = "语言", requiredMode = Schema.RequiredMode.REQUIRED, example = "中文") @Schema(description = "语言", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotBlank(message = "语言不能为空") @NotNull(message = "语言不能为空")
private String language; private Integer language;
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Integer writeType; //参见 AiWriteTypeEnum 枚举 private Integer type; //参见 AiWriteTypeEnum 枚举
} }

View File

@ -32,12 +32,12 @@ public class AiWriteDO extends BaseDO {
* <p> * <p>
* 枚举 {@link AiWriteTypeEnum} * 枚举 {@link AiWriteTypeEnum}
*/ */
private Integer writeType; private Integer type;
/** /**
* 生成内容提示 * 生成内容提示
*/ */
private String contentPrompt; private String prompt;
/** /**
* 生成的内容 * 生成的内容
@ -52,22 +52,22 @@ public class AiWriteDO extends BaseDO {
/** /**
* 长度提示词 * 长度提示词
*/ */
private String length; private Integer length;
/** /**
* 格式提示词 * 格式提示词
*/ */
private String format; private Integer format;
/** /**
* 语气提示词 * 语气提示词
*/ */
private String tone; private Integer tone;
/** /**
* 语言提示词 * 语言提示词
*/ */
private String language; private Integer language;
/** /**
* 模型 * 模型

View File

@ -11,7 +11,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; import cn.iocoder.yudao.module.ai.enums.write.*;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
@ -84,17 +84,17 @@ public class AiWriteServiceImpl implements AiWriteService {
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
String template; String template;
Integer writeType = generateReqVO.getWriteType(); Integer writeType = generateReqVO.getType();
String format = generateReqVO.getFormat(); String format = AiWriteFormatEnum.valueOfFormat(generateReqVO.getFormat()).getName();
String tone = generateReqVO.getTone(); String tone = AiWriteToneEnum.valueOfTone(generateReqVO.getTone()).getName();
String language = generateReqVO.getLanguage(); String language = AiLanguageEnum.valueOfLanguage(generateReqVO.getLanguage()).getName();
String length = generateReqVO.getLength(); String length = AiWriteLengthEnum.valueOfLength(generateReqVO.getLength()).getName();
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) { if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
return StrUtil.format(template, generateReqVO.getContentPrompt(), format, tone, language, length); return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) { } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
template = "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; template = "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
return StrUtil.format(template, generateReqVO.getOriginalContent(), generateReqVO.getContentPrompt(), format, tone, language, length); return StrUtil.format(template, generateReqVO.getOriginalContent(), generateReqVO.getPrompt(), format, tone, language, length);
} else { } else {
throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", writeType)); throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", writeType));
} }