【调整】调整Ai model

This commit is contained in:
cherishsince 2024-05-07 11:47:58 +08:00
parent f0a1666e84
commit 929f3597fd
10 changed files with 163 additions and 167 deletions

View File

@ -1,26 +0,0 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* ai modal disable
*
* @author fansili
* @time 2024/4/24 20:15
* @since 1.0
*/
@AllArgsConstructor
@Getter
public enum AiChatModalDisableEnum {
NO(0, "未禁用"),
YES(1, "禁用"),
;
private Integer value;
private String name;
}

View File

@ -2,19 +2,19 @@ package cn.iocoder.yudao.module.ai.controller.admin.model;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.service.AiChatModalService; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReq; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
// TODO @fan调整下接口相关 vo 的命名等等modal => model // TODO @fan调整下接口相关 vo 的命名等等modal => model
/** /**
* ai 模型 * ai 模型
* *
@ -24,7 +24,7 @@ import org.springframework.web.multipart.MultipartFile;
*/ */
@Tag(name = "A6-AI模型") @Tag(name = "A6-AI模型")
@RestController @RestController
@RequestMapping("/ai/chat") @RequestMapping("/ai/chat/modal")
@Slf4j @Slf4j
@AllArgsConstructor @AllArgsConstructor
public class AiChatModalController { public class AiChatModalController {
@ -32,37 +32,29 @@ public class AiChatModalController {
private final AiChatModalService aiChatModalService; private final AiChatModalService aiChatModalService;
@Operation(summary = "ai模型 - 模型列表") @Operation(summary = "ai模型 - 模型列表")
@GetMapping("/modal/list") @GetMapping("/list")
public PageResult<AiChatModalListRes> list(@ModelAttribute AiChatModalListReq req) { public PageResult<AiChatModalListRes> list(@ModelAttribute AiChatModalListReqVO req) {
return aiChatModalService.list(req); return aiChatModalService.list(req);
} }
@Operation(summary = "ai模型 - 添加") @Operation(summary = "ai模型 - 添加")
@PutMapping("/modal") @PutMapping("/add")
public CommonResult<Void> add(@RequestBody @Validated AiChatModalAddReq req) { public CommonResult<Void> add(@RequestBody @Validated AiChatModalAddReqVO req) {
aiChatModalService.add(req); aiChatModalService.add(req);
return CommonResult.success(null); return CommonResult.success(null);
} }
@Operation(summary = "ai模型 - 模型照片上传")
@PostMapping("/modal/{id}/updateImage")
public CommonResult<Void> updateImage(@PathVariable("id") Long id,
MultipartFile file) {
// todo yunai 文件上传这里放哪里
return CommonResult.success(null);
}
@Operation(summary = "ai模型 - 修改") @Operation(summary = "ai模型 - 修改")
@PostMapping("/modal/{id}") @PostMapping("/update")
public CommonResult<Void> update(@PathVariable Long id, public CommonResult<Void> update(@RequestParam("id") Long id,
@RequestBody @Validated AiChatModalAddReq req) { @RequestBody @Validated AiChatModalAddReqVO req) {
aiChatModalService.update(id, req); aiChatModalService.update(id, req);
return CommonResult.success(null); return CommonResult.success(null);
} }
@Operation(summary = "ai模型 - 删除") @Operation(summary = "ai模型 - 删除")
@DeleteMapping("/modal/{id}") @DeleteMapping("/delete")
public CommonResult<Void> delete(@PathVariable Long id) { public CommonResult<Void> delete(@RequestParam("id") Long id) {
aiChatModalService.delete(id); aiChatModalService.delete(id);
return CommonResult.success(null); return CommonResult.success(null);
} }

View File

@ -6,8 +6,6 @@ import jakarta.validation.constraints.Size;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
import java.util.Map;
/** /**
* ai chat modal * ai chat modal
* *
@ -17,32 +15,40 @@ import java.util.Map;
*/ */
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public class AiChatModalAddReq { public class AiChatModalAddReqVO {
@Schema(description = "API 秘钥编号")
@Size(max = 32, message = "API 秘钥编号最大32个字符")
@NotNull(message = "API 秘钥编号不能为空!")
private Long keyId;
@Schema(description = "模型名字") @Schema(description = "模型名字")
@Size(max = 60, message = "模型名字最大60个字符") @Size(max = 60, message = "模型名字最大60个字符")
@NotNull(message = "模型名字不能为空!") @NotNull(message = "模型名字不能为空!")
private String name; private String name;
@Schema(description = "模型类型(qianwen、yiyan、xinghuo、openai)")
@Size(max = 32, message = "模型类型最大32个字符")
@NotNull(message = "model模型不能为空!")
private String model;
@Size(max = 32, message = "模型平台最大32个字符") @Size(max = 32, message = "模型平台最大32个字符")
@Schema(description = "模型平台 参考 AiPlatformEnum") @Schema(description = "模型平台 参考 AiPlatformEnum")
@NotNull(message = "平台不能为空!") @NotNull(message = "平台不能为空!")
private String platform; private String platform;
@Schema(description = "模型类型(qianwen、yiyan、xinghuo、openai)")
@Size(max = 32, message = "模型类型最大32个字符")
@NotNull(message = "modal模型不能为空!")
private String modal;
@Schema(description = "模型照片")
@Size(max = 256, message = "模型照片地址最大256个字符")
private String imageUrl;
@Schema(description = "排序") @Schema(description = "排序")
@NotNull(message = "sort排序不能为空!") @NotNull(message = "sort排序不能为空!")
private Integer sort; private Integer sort;
@Schema(description = "模型配置JSON") // ========== 会话配置 ==========
// @Size(max = 1024, message = "模型配置最大1024个字符")
private Map<String, Object> config; @Schema(description = "温度参数")
private Integer temperature;
@Schema(description = "单条回复的最大 Token 数量")
private Integer maxTokens;
@Schema(description = "上下文的最大 Message 数量")
private Integer maxContexts;
} }

View File

@ -14,7 +14,7 @@ import lombok.experimental.Accessors;
*/ */
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public class AiChatModalListReq extends PageParam { public class AiChatModalListReqVO extends PageParam {
@Schema(description = "名字搜搜") @Schema(description = "名字搜搜")
private String search; private String search;

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model; package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.Size;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
@ -15,27 +16,35 @@ import lombok.experimental.Accessors;
@Accessors(chain = true) @Accessors(chain = true)
public class AiChatModalRes { public class AiChatModalRes {
@Schema(description = "id") @Schema(description = "编号")
private Long id; private Long id;
@Schema(description = "模型平台 参考 AiPlatformEnum") @Schema(description = "API 秘钥编号")
private String platform; private Long keyId;
@Schema(description = "模型类型 参考 YiYanChatModel、XingHuoChatModel")
private String modal;
@Schema(description = "模型名字") @Schema(description = "模型名字")
private String name; private String name;
@Schema(description = "模型照片") @Schema(description = "模型类型(qianwen、yiyan、xinghuo、openai)")
private String image; private String model;
@Schema(description = "禁用 0、正常 1、禁用") @Size(max = 32, message = "模型平台最大32个字符")
private Integer disable; private String platform;
@Schema(description = "排序 asc 排序") @Schema(description = "排序")
private Integer sort; private Integer sort;
@Schema(description = "modal 配置") @Schema(description = "状态")
private String config; private Integer status;
// ========== 会话配置 ==========
@Schema(description = "温度参数")
private Integer temperature;
@Schema(description = "单条回复的最大 Token 数量")
private Integer maxTokens;
@Schema(description = "上下文的最大 Message 数量")
private Integer maxContexts;
} }

View File

@ -0,0 +1,63 @@
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* ai chat modal
*
* @author fansili
* @time 2024/4/24 19:47
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class AiChatModalUpdateReqVO {
@Schema(description = "编号")
@Size(max = 32, message = "编号最大32个字符")
@NotNull(message = "编号不能为空")
private Long id;
@Schema(description = "API 秘钥编号")
@Size(max = 32, message = "API 秘钥编号最大32个字符")
@NotNull(message = "API 秘钥编号不能为空!")
private Long keyId;
@Schema(description = "模型名字")
@Size(max = 60, message = "模型名字最大60个字符")
@NotNull(message = "模型名字不能为空!")
private String name;
@Schema(description = "模型类型(qianwen、yiyan、xinghuo、openai)")
@Size(max = 32, message = "模型类型最大32个字符")
@NotNull(message = "model模型不能为空!")
private String model;
@Size(max = 32, message = "模型平台最大32个字符")
@Schema(description = "模型平台 参考 AiPlatformEnum")
@NotNull(message = "平台不能为空!")
private String platform;
@Schema(description = "排序")
@NotNull(message = "sort排序不能为空!")
private Integer sort;
@Schema(description = "状态")
@NotNull(message = "状态不能为空!")
private Integer status;
// ========== 会话配置 ==========
@Schema(description = "温度参数")
private Integer temperature;
@Schema(description = "单条回复的最大 Token 数量")
private Integer maxTokens;
@Schema(description = "上下文的最大 Message 数量")
private Integer maxContexts;
}

View File

@ -1,12 +1,10 @@
package cn.iocoder.yudao.module.ai.convert; package cn.iocoder.yudao.module.ai.convert;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import org.mapstruct.Mapper; import org.mapstruct.Mapper;
import org.mapstruct.Mapping;
import org.mapstruct.Mappings;
import org.mapstruct.factory.Mappers; import org.mapstruct.factory.Mappers;
import java.util.List; import java.util.List;
@ -37,10 +35,7 @@ public interface AiChatModalConvert {
* @param req * @param req
* @return * @return
*/ */
@Mappings({ AiChatModalDO convertAiChatModalDO(AiChatModalAddReqVO req);
@Mapping(target = "config", ignore = true)
})
AiChatModalDO convertAiChatModalDO(AiChatModalAddReq req);
/** /**

View File

@ -1,10 +1,11 @@
package cn.iocoder.yudao.module.ai.service; package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReq; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReq; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
/** /**
* ai modal * ai modal
@ -21,14 +22,14 @@ public interface AiChatModalService {
* @param req * @param req
* @return * @return
*/ */
PageResult<AiChatModalListRes> list(AiChatModalListReq req); PageResult<AiChatModalListRes> list(AiChatModalListReqVO req);
/** /**
* ai modal - 添加 * ai modal - 添加
* *
* @param req * @param req
*/ */
void add(AiChatModalAddReq req); void add(AiChatModalAddReqVO req);
/** /**
* ai modal - 更新 * ai modal - 更新
@ -36,7 +37,7 @@ public interface AiChatModalService {
* @param id * @param id
* @param req * @param req
*/ */
void update(Long id, AiChatModalAddReq req); void update(Long id, AiChatModalAddReqVO req);
/** /**
* ai modal - 删除 * ai modal - 删除
@ -53,6 +54,14 @@ public interface AiChatModalService {
*/ */
AiChatModalRes getChatModalOfValidate(Long modalId); AiChatModalRes getChatModalOfValidate(Long modalId);
/**
* 校验 - 是否存在
*
* @param id
* @return
*/
AiChatModalDO validateExists(Long id);
/** /**
* 校验 - 校验是否可用 * 校验 - 校验是否可用
* *

View File

@ -3,26 +3,20 @@ package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.validation.ValidationUtil; import cn.hutool.extra.validation.ValidationUtil;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.convert.AiChatModalConvert; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReqVO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalChatConfigVO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalDallConfigVO;
import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalAddReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListReq;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalListRes;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
import cn.iocoder.yudao.module.ai.convert.AiChatModalConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
import cn.iocoder.yudao.module.ai.service.AiChatModalService;
import jakarta.validation.ConstraintViolation; import jakarta.validation.ConstraintViolation;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -46,10 +40,10 @@ public class AiChatModalServiceImpl implements AiChatModalService {
private final AiChatModalMapper aiChatModalMapper; private final AiChatModalMapper aiChatModalMapper;
@Override @Override
public PageResult<AiChatModalListRes> list(AiChatModalListReq req) { public PageResult<AiChatModalListRes> list(AiChatModalListReqVO req) {
LambdaQueryWrapperX<AiChatModalDO> queryWrapperX = new LambdaQueryWrapperX<>(); LambdaQueryWrapperX<AiChatModalDO> queryWrapperX = new LambdaQueryWrapperX<>();
// 查询的都是未禁用的模型 // 查询的都是未禁用的模型
queryWrapperX.eq(AiChatModalDO::getDisable, AiChatModalDisableEnum.NO.getValue()); queryWrapperX.eq(AiChatModalDO::getStatus, CommonStatusEnum.ENABLE.getStatus());
// search // search
if (!StrUtil.isBlank(req.getSearch())) { if (!StrUtil.isBlank(req.getSearch())) {
queryWrapperX.like(AiChatModalDO::getName, req.getSearch().trim()); queryWrapperX.like(AiChatModalDO::getName, req.getSearch().trim());
@ -64,39 +58,26 @@ public class AiChatModalServiceImpl implements AiChatModalService {
} }
@Override @Override
public void add(AiChatModalAddReq req) { public void add(AiChatModalAddReqVO req) {
// 校验 platformtype // 校验 platformtype
validatePlatform(req.getPlatform()); validatePlatform(req.getPlatform());
validateModal(req.getPlatform(), req.getModal());
// 转换config
AiChatModalConfigVO aiChatModalConfigVO = convertConfig(req);
// 校验 modal config
validateModalConfig(aiChatModalConfigVO);
// 转换 do // 转换 do
AiChatModalDO insertChatModalDO = AiChatModalConvert.INSTANCE.convertAiChatModalDO(req); AiChatModalDO insertChatModalDO = AiChatModalConvert.INSTANCE.convertAiChatModalDO(req);
// 设置默认属性 // 设置默认属性
insertChatModalDO.setDisable(AiChatModalDisableEnum.NO.getValue()); insertChatModalDO.setStatus(CommonStatusEnum.ENABLE.getStatus());
insertChatModalDO.setConfig(JsonUtils.toJsonString(aiChatModalConfigVO));
// 保存数据库 // 保存数据库
aiChatModalMapper.insert(insertChatModalDO); aiChatModalMapper.insert(insertChatModalDO);
} }
@Override @Override
public void update(Long id, AiChatModalAddReq req) { public void update(Long id, AiChatModalAddReqVO req) {
// 校验 platformtype // 校验 platform
validatePlatform(req.getPlatform()); validatePlatform(req.getPlatform());
validateModal(req.getPlatform(), req.getModal());
// 转换config
AiChatModalConfigVO aiChatModalConfigVO = convertConfig(req);
// 校验 modal config
validateModalConfig(aiChatModalConfigVO);
// 校验模型是否存在 // 校验模型是否存在
validateChatModalExists(id); validateExists(id);
// 转换 updateChatModalDO // 转换 updateChatModalDO
AiChatModalDO updateChatModalDO = AiChatModalConvert.INSTANCE.convertAiChatModalDO(req); AiChatModalDO updateChatModalDO = AiChatModalConvert.INSTANCE.convertAiChatModalDO(req);
updateChatModalDO.setId(id); updateChatModalDO.setId(id);
updateChatModalDO.setConfig(JsonUtils.toJsonString(aiChatModalConfigVO));
// 更新数据库 // 更新数据库
aiChatModalMapper.updateById(updateChatModalDO); aiChatModalMapper.updateById(updateChatModalDO);
} }
@ -104,7 +85,7 @@ public class AiChatModalServiceImpl implements AiChatModalService {
@Override @Override
public void delete(Long id) { public void delete(Long id) {
// 检查 modal 是否存在 // 检查 modal 是否存在
validateChatModalExists(id); validateExists(id);
// 删除 delete // 删除 delete
aiChatModalMapper.deleteById(id); aiChatModalMapper.deleteById(id);
} }
@ -112,19 +93,19 @@ public class AiChatModalServiceImpl implements AiChatModalService {
@Override @Override
public AiChatModalRes getChatModalOfValidate(Long modalId) { public AiChatModalRes getChatModalOfValidate(Long modalId) {
// 检查 modal 是否存在 // 检查 modal 是否存在
AiChatModalDO aiChatModalDO = validateChatModalExists(modalId); AiChatModalDO aiChatModalDO = validateExists(modalId);
return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO); return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
} }
@Override @Override
public void validateAvailable(AiChatModalRes chatModal) { public void validateAvailable(AiChatModalRes chatModal) {
// 对话模型是否可用 // 对话模型是否可用
if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) { if (CommonStatusEnum.ENABLE.getStatus().equals(chatModal.getStatus())) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED); throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
} }
} }
private AiChatModalDO validateChatModalExists(Long id) { public AiChatModalDO validateExists(Long id) {
AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id); AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id);
if (aiChatModalDO == null) { if (aiChatModalDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_EXIST); throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_EXIST);
@ -132,23 +113,6 @@ public class AiChatModalServiceImpl implements AiChatModalService {
return aiChatModalDO; return aiChatModalDO;
} }
private void validateModal(String platform, String modal) {
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(platform);
try {
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
QianWenChatModal.valueOfModel(modal);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
XingHuoChatModel.valueOfModel(modal);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
YiYanChatModel.valueOfModel(modal);
} else {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_SUPPORTED_MODAL, platform);
}
} catch (IllegalArgumentException e) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, e.getMessage());
}
}
private void validatePlatform(String platform) { private void validatePlatform(String platform) {
try { try {
AiPlatformEnum.valueOfPlatform(platform); AiPlatformEnum.valueOfPlatform(platform);
@ -163,20 +127,4 @@ public class AiChatModalServiceImpl implements AiChatModalService {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, constraintViolation.getMessage()); throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, constraintViolation.getMessage());
} }
} }
private static AiChatModalConfigVO convertConfig(AiChatModalAddReq req) {
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getPlatform());
AiChatModalConfigVO resVo = null;
if (AiPlatformEnum.CHAT_PLATFORM_LIST.contains(platformEnum)) {
resVo = JsonUtils.parseObject(JsonUtils.toJsonString(req.getConfig()), AiChatModalChatConfigVO.class);
} else if (AiPlatformEnum.OPEN_AI_DALL == platformEnum) {
resVo = JsonUtils.parseObject(JsonUtils.toJsonString(req.getConfig()), AiChatModalDallConfigVO.class);
}
if (resVo == null) {
throw new IllegalArgumentException("ai模型中config不能转换! json: " + req.getConfig());
}
resVo.setType(req.getModal());
resVo.setPlatform(req.getPlatform());
return resVo;
}
} }

View File

@ -72,10 +72,10 @@ public class AiChatServiceImpl implements AiChatService {
// 校验角色是否公开 // 校验角色是否公开
aiChatRoleService.validateIsPublic(aiChatRoleDO); aiChatRoleService.validateIsPublic(aiChatRoleDO);
// 获取 client 类型 // 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal()); AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModel());
// 保存 chat message // 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(), chatModal.getModel(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
String content = null; String content = null;
try { try {
@ -96,7 +96,7 @@ public class AiChatServiceImpl implements AiChatService {
} finally { } finally {
// 保存 chat message // 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(), chatModal.getModel(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
} }
return new AiChatMessageRespVO().setContent(content); return new AiChatMessageRespVO().setContent(content);
@ -150,11 +150,11 @@ public class AiChatServiceImpl implements AiChatService {
// 保存 chat message // 保存 chat message
// 保存 chat message // 保存 chat message
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(), chatModal.getModel(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
// 获取 client 类型 // 获取 client 类型
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal()); AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModel());
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt); Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
@ -183,7 +183,7 @@ public class AiChatServiceImpl implements AiChatService {
sseEmitter.complete(); sseEmitter.complete();
// 保存 chat message // 保存 chat message
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
chatModal.getModal(), chatModal.getId(), req.getContent(), chatModal.getModel(), chatModal.getId(), req.getContent(),
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
} }