【优化】AI:调整 image 相关的包结构

This commit is contained in:
YunaiV 2024-05-24 20:56:21 +08:00
parent 17afb14c88
commit 4fddec5f02
21 changed files with 52 additions and 68 deletions

View File

@ -89,7 +89,7 @@ public class AiChatMessageController {
@DeleteMapping("/delete")
@Parameter(name = "id", required = true, description = "消息编号", example = "1024")
public CommonResult<Boolean> deleteChatMessage(@RequestParam("id") Long id) {
chatMessageService.deleteMessage(id, getLoginUserId());
chatMessageService.deleteChatMessage(id, getLoginUserId());
return success(true);
}

View File

@ -3,7 +3,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.image;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.service.AiImageService;
import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor;

View File

@ -18,7 +18,6 @@ import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.function.Function;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@ -77,7 +76,7 @@ public class AiChatModelController {
@Operation(summary = "获得聊天模型列表")
@Parameter(name = "status", description = "状态", required = true, example = "1")
public CommonResult<List<AiChatModelRespVO>> getChatModelSimpleList(@RequestParam("status") Integer status) {
List<AiChatModelDO> list = chatModelService.getChatModelList(status);
List<AiChatModelDO> list = chatModelService.getChatModelListByStatus(status);
return success(convertList(list, model -> new AiChatModelRespVO().setId(model.getId())
.setName(model.getName()).setModel(model.getModel())));
}

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.dal.mysql;
package cn.iocoder.yudao.module.ai.dal.mysql.image;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
@ -26,4 +26,5 @@ public interface AiImageMapper extends BaseMapperX<AiImageDO> {
default void updateByMjNonce(Long mjNonceId, AiImageDO aiImageDO) {
this.update(aiImageDO, new LambdaQueryWrapperX<AiImageDO>().eq(AiImageDO::getMjNonceId, mjNonceId));
}
}

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.dal.mysql;
package cn.iocoder.yudao.module.ai.dal.mysql.model;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
@ -26,17 +26,6 @@ public interface AiChatModelMapper extends BaseMapperX<AiChatModelDO> {
.orderByAsc("sort"));
}
// TODO 芋艿不需要哈
/**
* 查询 - 根据 ids
*
* @param modalIds
* @return
*/
default List<AiChatModelDO> selectByIds(Collection<Long> modalIds) {
return this.selectList(new LambdaQueryWrapperX<AiChatModelDO>().eq(AiChatModelDO::getId, modalIds));
}
default PageResult<AiChatModelDO> selectPage(AiChatModelPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiChatModelDO>()
.likeIfPresent(AiChatModelDO::getName, reqVO.getName())
@ -50,4 +39,5 @@ public interface AiChatModelMapper extends BaseMapperX<AiChatModelDO> {
.eq(AiChatModelDO::getStatus, status)
.orderByAsc(AiChatModelDO::getSort));
}
}

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.dal.mysql;
package cn.iocoder.yudao.module.ai.dal.mysql.model;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;

View File

@ -45,7 +45,7 @@ public interface AiChatMessageService {
* @param id 消息编号
* @param userId 用户编号
*/
void deleteMessage(Long id, Long userId);
void deleteChatMessage(Long id, Long userId);
/**
* 删除指定会话的消息

View File

@ -40,6 +40,8 @@ import java.time.LocalDateTime;
import java.util.*;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
@ -138,7 +140,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 3.3 流式返回
// 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
// TODO 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -149,14 +151,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
o -> o.setUserAvatar(user.getAvatar()));
AiChatMessageSendRespVO.Message receive = BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class,
o -> o.setRoleAvatar(role != null ? role.getAvatar() : null)).setContent(newContent);
return CommonResult.success(new AiChatMessageSendRespVO().setSend(send).setReceive(receive));
return success(new AiChatMessageSendRespVO().setSend(send).setReceive(receive));
}).doOnComplete(() -> {
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString()));
}).doOnError(throwable -> {
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()));
}).onErrorResume( error -> {
return Flux.just(CommonResult.error(ErrorCodeConstants.AI_CHAT_STREAM_ERROR));
}).onErrorResume(error -> {
return Flux.just(error(ErrorCodeConstants.AI_CHAT_STREAM_ERROR));
});
}
@ -235,7 +237,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
@Override
public void deleteMessage(Long id, Long userId) {
public void deleteChatMessage(Long id, Long userId) {
// 1. 校验消息存在
AiChatMessageDO message = chatMessageMapper.selectById(id);
if (message == null || ObjUtil.notEqual(message.getUserId(), userId)) {

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.service;
package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.service.impl;
package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil;
@ -21,9 +21,8 @@ import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.service.AiImageService;
import jakarta.annotation.PostConstruct;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.service.midjourneyHandler;
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
@ -10,7 +10,7 @@ import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO;
import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.dal.vo;
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.dal.vo;
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler.vo;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import lombok.Data;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.dal.vo;
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler.vo;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
import lombok.Data;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.dal.vo;
package cn.iocoder.yudao.module.ai.service.image.midjourneyHandler.vo;
import lombok.Data;
import lombok.experimental.Accessors;

View File

@ -6,6 +6,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatMode
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import jakarta.validation.Valid;
import java.util.Collection;
import java.util.List;
import java.util.Set;
@ -79,13 +80,13 @@ public interface AiChatModelService {
* @param status 状态
* @return 聊天模型列表
*/
List<AiChatModelDO> getChatModelList(Integer status);
List<AiChatModelDO> getChatModelListByStatus(Integer status);
/**
* - 根据多个 ids 获取
* 得聊天模型列表
*
* @param modalIds
* @return
* @param ids 编号数组
* @return 模型列表
*/
List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
List<AiChatModelDO> getChatModelList(Collection<Long> ids);
}

View File

@ -7,15 +7,14 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModelMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatModelMapper;
import jakarta.annotation.Resource;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*;
@ -103,13 +102,13 @@ public class AiChatModelServiceImpl implements AiChatModelService {
}
@Override
public List<AiChatModelDO> getChatModelList(Integer status) {
public List<AiChatModelDO> getChatModelListByStatus(Integer status) {
return chatModelMapper.selectList(status);
}
@Override
public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
return chatModelMapper.selectByIds(modalIds);
public List<AiChatModelDO> getChatModelList(Collection<Long> ids) {
return chatModelMapper.selectBatchIds(ids);
}
}

View File

@ -10,7 +10,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleP
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleSaveMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatRoleMapper;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

View File

@ -20,6 +20,7 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import com.google.cloud.vertexai.VertexAI;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.chat.StreamingChatClient;
@ -31,6 +32,8 @@ import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatClient;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
import java.util.List;
@ -57,8 +60,8 @@ public class AiClientFactoryImpl implements AiClientFactory {
return buildXingHuoChatClient(apiKey);
case QIAN_WEN:
return buildQianWenChatClient(apiKey);
// case GEMIR:
// return buildGoogleGemir(apiKey);
case GEMIR:
return buildGoogleGemir(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@ -165,24 +168,13 @@ public class AiClientFactoryImpl implements AiClientFactory {
QianWenApi qianWenApi = new QianWenApi(key, QianWenChatModal.QWEN_72B_CHAT);
return new QianWenChatClient(qianWenApi);
}
//
// private static VertexAiGeminiChatClient buildGoogleGemir(String key) {
// List<String> keys = StrUtil.split(key, '|');
// Assert.equals(keys.size(), 2, "VertexAiGeminiChatClient 的密钥需要 (projectId|location) 格式");
//// VertexAiGeminiConnectionProperties connectionProperties = new VertexAiGeminiConnectionProperties();
//// connectionProperties.setApiKey("AIzaSyBpe376HTA8uPKJN_OJTh7MEO3v6LMqfXU");
////
//// GoogleCredentials credentials = GoogleCredentials.fromStream(connectionProperties.getCredentialsUri().getInputStream());
// // todo @芋艿 google gemini 没找到对于初始化 client 方式文档中说是用过 GoogleCredentials 来初始化凭证
// // api-key: AIzaSyBpe376HTA8uPKJN_OJTh7MEO3v6LMqfXU
// VertexAI vertexApi = new VertexAI(
// "skilled-snow-409401",
// "us-central1"
// );
// return new VertexAiGeminiChatClient(vertexApi,
// VertexAiGeminiChatOptions.builder()
// .withTemperature(0.4F)
// .withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
// .build());
// }
private static VertexAiGeminiChatClient buildGoogleGemir(String key) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "VertexAiGeminiChatClient 的密钥需要 (projectId|location) 格式");
VertexAI vertexApi = new VertexAI(keys.get(0), keys.get(1));
return new VertexAiGeminiChatClient(vertexApi);
}
}

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi;
import cn.hutool.core.util.NumberUtil;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import org.springframework.ai.chat.*;