【调整】调整Ai model

This commit is contained in:
cherishsince 2024-05-07 11:47:58 +08:00
parent f0a1666e84
commit 929f3597fd
10 changed files with 163 additions and 167 deletions

View File

@ -1,26 +0,0 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* ai modal disable
*
* @author fansili
* @time 2024/4/24 20:15
* @since 1.0
*/
@AllArgsConstructor
@Getter
public enum AiChatModalDisableEnum {
NO(0, "未禁用"),
YES(1, "禁用"),
;
private Integer value;
private String name;
}

View File

@ -2,19 +2,19 @@ package cn.iocoder.yudao.module.ai.controller.admin.model;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
// TODO @fan调整下接口相关 vo 的命名等等modal => model
/**
* ai 模型
*
@ -24,7 +24,7 @@ import org.springframework.web.multipart.MultipartFile;
*/
@Tag(name = "A6-AI模型")
@RestController
@RequestMapping("/ai/chat")
@RequestMapping("/ai/chat/modal")
@Slf4j
@AllArgsConstructor
public class AiChatModalController {
@ -32,37 +32,29 @@ public class AiChatModalController {
private final AiChatModalService aiChatModalService;
@Operation(summary = "ai模型 - 模型列表")
@GetMapping("/modal/list")
public PageResult<AiChatModalListRes> list(@ModelAttribute AiChatModalListReq req) {
@GetMapping("/list")
public PageResult<AiChatModalListRes> list(@ModelAttribute AiChatModalListReqVO req) {
return aiChatModalService.list(req);
}
@Operation(summary = "ai模型 - 添加")
@PutMapping("/modal")
public CommonResult<Void> add(@RequestBody @Validated AiChatModalAddReq req) {
@PutMapping("/add")
public CommonResult<Void> add(@RequestBody @Validated AiChatModalAddReqVO req) {
aiChatModalService.add(req);
return CommonResult.success(null);
}
@Operation(summary = "ai模型 - 模型照片上传")
@PostMapping("/modal/{id}/updateImage")
public CommonResult<Void> updateImage(@PathVariable("id") Long id,
MultipartFile file) {
// todo yunai 文件上传这里放哪里
return CommonResult.success(null);
}
@Operation(summary = "ai模型 - 修改")
@PostMapping("/modal/{id}")
public CommonResult<Void> update(@PathVariable Long id,
@RequestBody @Validated AiChatModalAddReq req) {
@PostMapping("/update")
public CommonResult<Void> update(@RequestParam("id") Long id,
@RequestBody @Validated AiChatModalAddReqVO req) {
aiChatModalService.update(id, req);
return CommonResult.success(null);
}
@Operation(summary = "ai模型 - 删除")
@DeleteMapping("/modal/{id}")
public CommonResult<Void> delete(@PathVariable Long id) {
@DeleteMapping("/delete")
public CommonResult<Void> delete(@RequestParam("id") Long id) {
aiChatModalService.delete(id);
return CommonResult.success(null);
}

View File

@ -6,8 +6,6 @@ import jakarta.validation.constraints.Size;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.Map;
/**
* ai chat modal
*
@ -17,32 +15,40 @@ import java.util.Map;
*/
@Data
@Accessors(chain = true)
public class AiChatModalAddReq {
public class AiChatModalAddReqVO {
@Schema(description = "API 秘钥编号")
@Size(max = 32, message = "API 秘钥编号最大32个字符")
@NotNull(message = "API 秘钥编号不能为空!")
private Long keyId;
@Schema(description = "模型名字")
@Size(max = 60, message = "模型名字最大60个字符")
@NotNull(message = "模型名字不能为空!")
private String name;
@Schema(description = "模型类型(qianwen、yiyan、xinghuo、openai)")
@Size(max = 32, message = "模型类型最大32个字符")
@NotNull(message = "model模型不能为空!")
private String model;
@Size(max = 32, message = "模型平台最大32个字符")
@Schema(description = "模型平台 参考 AiPlatformEnum")
@NotNull(message = "平台不能为空!")
private String platform;
@Schema(description = "模型类型(qianwen、yiyan、xinghuo、openai)")
@Size(max = 32, message = "模型类型最大32个字符")
@NotNull(message = "modal模型不能为空!")
private String modal;
@Schema(description = "模型照片")
@Size(max = 256, message = "模型照片地址最大256个字符")
private String imageUrl;
@Schema(description = "排序")
@NotNull(message = "sort排序不能为空!")
private Integer sort;
@Schema(description = "模型配置JSON")
// @Size(max = 1024, message = "模型配置最大1024个字符")
private Map<String, Object> config;
// ========== 会话配置 ==========
@Schema(description = "温度参数")
private Integer temperature;
@Schema(description = "单条回复的最大 Token 数量")
private Integer maxTokens;
@Schema(description = "上下文的最大 Message 数量")
private Integer maxContexts;
}

View File

@ -14,7 +14,7 @@ import lombok.experimental.Accessors;
*/
@Data
@Accessors(chain = true)
public class AiChatModalListReq extends PageParam {
public class AiChatModalListReqVO extends PageParam {
@Schema(description = "名字搜搜")
private String search;

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.Size;
import lombok.Data;
import lombok.experimental.Accessors;
@ -15,27 +16,35 @@ import lombok.experimental.Accessors;
@Accessors(chain = true)
public class AiChatModalRes {
@Schema(description = "id")
@Schema(description = "编号")
private Long id;
@Schema(description = "模型平台 参考 AiPlatformEnum")
private String platform;
@Schema(description = "模型类型 参考 YiYanChatModel、XingHuoChatModel")
private String modal;
@Schema(description = "API 秘钥编号")
private Long keyId;
@Schema(description = "模型名字")
private String name;
@Schema(description = "模型照片")
private String image;
@Schema(description = "模型类型(qianwen、yiyan、xinghuo、openai)")
private String model;
@Schema(description = "禁用 0、正常 1、禁用")
private Integer disable;
@Size(max = 32, message = "模型平台最大32个字符")
private String platform;
@Schema(description = "排序 asc 排序")
@Schema(description = "排序")
private Integer sort;
@Schema(description = "modal 配置")
private String config;
@Schema(description = "状态")
private Integer status;
// ========== 会话配置 ==========
@Schema(description = "温度参数")
private Integer temperature;
@Schema(description = "单条回复的最大 Token 数量")
private Integer maxTokens;
@Schema(description = "上下文的最大 Message 数量")
private Integer maxContexts;
}

View File

@ -0,0 +1,63 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* ai chat modal
*
* @author fansili
* @time 2024/4/24 19:47
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiChatModalUpdateReqVO {
@Schema(description = "编号")
@Size(max = 32, message = "编号最大32个字符")
@NotNull(message = "编号不能为空")
private Long id;
@Schema(description = "API 秘钥编号")
@Size(max = 32, message = "API 秘钥编号最大32个字符")
@NotNull(message = "API 秘钥编号不能为空!")
private Long keyId;
@Schema(description = "模型名字")
@Size(max = 60, message = "模型名字最大60个字符")
@NotNull(message = "模型名字不能为空!")
private String name;
@Schema(description = "模型类型(qianwen、yiyan、xinghuo、openai)")
@Size(max = 32, message = "模型类型最大32个字符")
@NotNull(message = "model模型不能为空!")
private String model;
@Size(max = 32, message = "模型平台最大32个字符")
@Schema(description = "模型平台 参考 AiPlatformEnum")
@NotNull(message = "平台不能为空!")
private String platform;
@Schema(description = "排序")
@NotNull(message = "sort排序不能为空!")
private Integer sort;
@Schema(description = "状态")
@NotNull(message = "状态不能为空!")
private Integer status;
// ========== 会话配置 ==========
@Schema(description = "温度参数")
private Integer temperature;
@Schema(description = "单条回复的最大 Token 数量")
private Integer maxTokens;
@Schema(description = "上下文的最大 Message 数量")
private Integer maxContexts;
}

View File

@ -1,12 +1,10 @@
package cn.iocoder.yudao.module.ai.convert;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import org.mapstruct.Mapper;
import org.mapstruct.Mapping;
import org.mapstruct.Mappings;
import org.mapstruct.factory.Mappers;
import java.util.List;
@ -37,10 +35,7 @@ public interface AiChatModalConvert {
* @param req
* @return
*/
@Mappings({
@Mapping(target = "config", ignore = true)
})
AiChatModalDO convertAiChatModalDO(AiChatModalAddReq req);
AiChatModalDO convertAiChatModalDO(AiChatModalAddReqVO req);
/**

View File

@ -1,10 +1,11 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
/**
* ai modal
@ -21,14 +22,14 @@ public interface AiChatModalService {
* @param req
* @return
*/
PageResult<AiChatModalListRes> list(AiChatModalListReq req);
PageResult<AiChatModalListRes> list(AiChatModalListReqVO req);
/**
* ai modal - 添加
*
* @param req
*/
void add(AiChatModalAddReq req);
void add(AiChatModalAddReqVO req);
/**
* ai modal - 更新
@ -36,7 +37,7 @@ public interface AiChatModalService {
* @param id
* @param req
*/
void update(Long id, AiChatModalAddReq req);
void update(Long id, AiChatModalAddReqVO req);
/**
* ai modal - 删除
@ -53,6 +54,14 @@ public interface AiChatModalService {
*/
AiChatModalRes getChatModalOfValidate(Long modalId);
/**
* 校验 - 是否存在
*
* @param id
* @return
*/
AiChatModalDO validateExists(Long id);
/**
* 校验 - 校验是否可用
*

View File

@ -3,26 +3,20 @@ package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.validation.ValidationUtil;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.convert.AiChatModalConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalChatConfigVO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalDallConfigVO;
import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.convert.AiChatModalConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import jakarta.validation.ConstraintViolation;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@ -46,10 +40,10 @@ public class AiChatModalServiceImpl implements AiChatModalService {
private final AiChatModalMapper aiChatModalMapper;
@Override
public PageResult<AiChatModalListRes> list(AiChatModalListReq req) {
public PageResult<AiChatModalListRes> list(AiChatModalListReqVO req) {
LambdaQueryWrapperX<AiChatModalDO> queryWrapperX = new LambdaQueryWrapperX<>();
// 查询的都是未禁用的模型
queryWrapperX.eq(AiChatModalDO::getDisable, AiChatModalDisableEnum.NO.getValue());
queryWrapperX.eq(AiChatModalDO::getStatus, CommonStatusEnum.ENABLE.getStatus());
// search
if (!StrUtil.isBlank(req.getSearch())) {
queryWrapperX.like(AiChatModalDO::getName, req.getSearch().trim());
@ -64,39 +58,26 @@ public class AiChatModalServiceImpl implements AiChatModalService {
}
@Override
public void add(AiChatModalAddReq req) {
public void add(AiChatModalAddReqVO req) {
// 校验 platformtype
validatePlatform(req.getPlatform());
validateModal(req.getPlatform(), req.getModal());
// 转换config
AiChatModalConfigVO aiChatModalConfigVO = convertConfig(req);
// 校验 modal config
validateModalConfig(aiChatModalConfigVO);
// 转换 do
AiChatModalDO insertChatModalDO = AiChatModalConvert.INSTANCE.convertAiChatModalDO(req);
// 设置默认属性
insertChatModalDO.setDisable(AiChatModalDisableEnum.NO.getValue());
insertChatModalDO.setConfig(JsonUtils.toJsonString(aiChatModalConfigVO));
insertChatModalDO.setStatus(CommonStatusEnum.ENABLE.getStatus());
// 保存数据库
aiChatModalMapper.insert(insertChatModalDO);
}
@Override
public void update(Long id, AiChatModalAddReq req) {
// 校验 platformtype
public void update(Long id, AiChatModalAddReqVO req) {
// 校验 platform
validatePlatform(req.getPlatform());
validateModal(req.getPlatform(), req.getModal());
// 转换config
AiChatModalConfigVO aiChatModalConfigVO = convertConfig(req);
// 校验 modal config
validateModalConfig(aiChatModalConfigVO);
// 校验模型是否存在
validateChatModalExists(id);
validateExists(id);
// 转换 updateChatModalDO
AiChatModalDO updateChatModalDO = AiChatModalConvert.INSTANCE.convertAiChatModalDO(req);
updateChatModalDO.setId(id);
updateChatModalDO.setConfig(JsonUtils.toJsonString(aiChatModalConfigVO));
// 更新数据库
aiChatModalMapper.updateById(updateChatModalDO);
}
@ -104,7 +85,7 @@ public class AiChatModalServiceImpl implements AiChatModalService {
@Override
public void delete(Long id) {
// 检查 modal 是否存在
validateChatModalExists(id);
validateExists(id);
// 删除 delete
aiChatModalMapper.deleteById(id);
}
@ -112,19 +93,19 @@ public class AiChatModalServiceImpl implements AiChatModalService {
@Override
public AiChatModalRes getChatModalOfValidate(Long modalId) {
// 检查 modal 是否存在
AiChatModalDO aiChatModalDO = validateChatModalExists(modalId);
AiChatModalDO aiChatModalDO = validateExists(modalId);
return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
}
@Override
public void validateAvailable(AiChatModalRes chatModal) {
// 对话模型是否可用
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
if (CommonStatusEnum.ENABLE.getStatus().equals(chatModal.getStatus())) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
}
}
private AiChatModalDO validateChatModalExists(Long id) {
public AiChatModalDO validateExists(Long id) {
AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id);
if (aiChatModalDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_EXIST);
@ -132,23 +113,6 @@ public class AiChatModalServiceImpl implements AiChatModalService {
return aiChatModalDO;
}
private void validateModal(String platform, String modal) {
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(platform);
try {
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
QianWenChatModal.valueOfModel(modal);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
XingHuoChatModel.valueOfModel(modal);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
YiYanChatModel.valueOfModel(modal);
} else {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_SUPPORTED_MODAL, platform);
}
} catch (IllegalArgumentException e) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, e.getMessage());
}
}
private void validatePlatform(String platform) {
try {
AiPlatformEnum.valueOfPlatform(platform);
@ -163,20 +127,4 @@ public class AiChatModalServiceImpl implements AiChatModalService {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, constraintViolation.getMessage());
}
}
private static AiChatModalConfigVO convertConfig(AiChatModalAddReq req) {
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getPlatform());
AiChatModalConfigVO resVo = null;
if (AiPlatformEnum.CHAT_PLATFORM_LIST.contains(platformEnum)) {
resVo = JsonUtils.parseObject(JsonUtils.toJsonString(req.getConfig()), AiChatModalChatConfigVO.class);
} else if (AiPlatformEnum.OPEN_AI_DALL == platformEnum) {
resVo = JsonUtils.parseObject(JsonUtils.toJsonString(req.getConfig()), AiChatModalDallConfigVO.class);
}
if (resVo == null) {
throw new IllegalArgumentException("ai模型中config不能转换! json: " + req.getConfig());
}
resVo.setType(req.getModal());
resVo.setPlatform(req.getPlatform());
return resVo;
}
}

View File

@ -72,10 +72,10 @@ public class AiChatServiceImpl implements AiChatService {
// 校验角色是否公开
aiChatRoleService.validateIsPublic(aiChatRoleDO);
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModel());
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
chatModal.getModel(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
String content = null;
try {
@ -96,7 +96,7 @@ public class AiChatServiceImpl implements AiChatService {
} finally {
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
chatModal.getModel(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
}
return new AiChatMessageRespVO().setContent(content);
@ -150,11 +150,11 @@ public class AiChatServiceImpl implements AiChatService {
// 保存 chat message
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
chatModal.getModel(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModel());
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
@ -183,7 +183,7 @@ public class AiChatServiceImpl implements AiChatService {
sseEmitter.complete();
// 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(),
chatModal.getModel(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
}