mirror of
https://gitee.com/huangge1199_admin/vue-pro.git
synced 2025-01-31 09:30:05 +08:00
Merge remote-tracking branch 'origin/master-jdk21-ai' into master-jdk21-ai
# Conflicts: # yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/AiMusicController.http
This commit is contained in:
commit
4ef55fa921
@ -1,15 +1,13 @@
|
|||||||
|
### 发送消息(段式)
|
||||||
### chat call
|
POST {{baseUrl}}/ai/chat/message/send
|
||||||
POST {{baseUrl}}/admin-api/ai/chat/message/send
|
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
Authorization: {{token}}
|
Authorization: {{token}}
|
||||||
|
|
||||||
{
|
{
|
||||||
"conversationId": "1781604279872581649",
|
"conversationId": "1781604279872581724",
|
||||||
"content": "你是 OpenAI 么?"
|
"content": "你是 OpenAI 么?"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
### 发送消息(流式)
|
### 发送消息(流式)
|
||||||
POST {{baseUrl}}/ai/chat/message/send-stream
|
POST {{baseUrl}}/ai/chat/message/send-stream
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
@ -20,11 +18,10 @@ Authorization: {{token}}
|
|||||||
"content": "1+1=?"
|
"content": "1+1=?"
|
||||||
}
|
}
|
||||||
|
|
||||||
### message list
|
### 获得指定对话的消息列表
|
||||||
GET {{baseUrl}}/admin-api/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649
|
GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649
|
||||||
Authorization: {{token}}
|
Authorization: {{token}}
|
||||||
|
|
||||||
|
### 删除消息
|
||||||
### message delete
|
DELETE {{baseUrl}}/ai/chat/message/delete?id=50
|
||||||
DELETE {{baseUrl}}/admin-api/ai/chat/message/delete?id=50
|
Authorization: {{token}}
|
||||||
Authorization: {{token}}
|
|
@ -21,10 +21,10 @@ import io.swagger.v3.oas.annotations.Parameter;
|
|||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import jakarta.annotation.security.PermitAll;
|
import jakarta.annotation.security.PermitAll;
|
||||||
|
import jakarta.validation.Valid;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.security.access.prepost.PreAuthorize;
|
import org.springframework.security.access.prepost.PreAuthorize;
|
||||||
import org.springframework.validation.annotation.Validated;
|
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
@ -51,14 +51,14 @@ public class AiChatMessageController {
|
|||||||
|
|
||||||
@Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢")
|
@Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢")
|
||||||
@PostMapping("/send")
|
@PostMapping("/send")
|
||||||
public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
public CommonResult<AiChatMessageSendRespVO> sendMessage(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
||||||
return success(chatMessageService.sendMessage(sendReqVO));
|
return success(chatMessageService.sendMessage(sendReqVO, getLoginUserId()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
|
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
|
||||||
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
@PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
|
@PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
|
||||||
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
||||||
return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
|
return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,14 +31,6 @@ public class AiChatMessageSendRespVO {
|
|||||||
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
|
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
private LocalDateTime createTime;
|
private LocalDateTime createTime;
|
||||||
|
|
||||||
// ========== 扩展字段 ==========
|
|
||||||
|
|
||||||
@Schema(description = "用户头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/1.png")
|
|
||||||
private String userAvatar;
|
|
||||||
|
|
||||||
@Schema(description = "角色头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://iocoder.cn/2.png")
|
|
||||||
private String roleAvatar;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,7 @@ public class AiImageController {
|
|||||||
|
|
||||||
@Operation(summary = "生成图片")
|
@Operation(summary = "生成图片")
|
||||||
@PostMapping("/draw")
|
@PostMapping("/draw")
|
||||||
public CommonResult<Long> drawImage(@Validated @RequestBody AiImageDrawReqVO drawReqVO) {
|
public CommonResult<Long> drawImage(@Valid @RequestBody AiImageDrawReqVO drawReqVO) {
|
||||||
return success(imageService.drawImage(getLoginUserId(), drawReqVO));
|
return success(imageService.drawImage(getLoginUserId(), drawReqVO));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ public class AiImageController {
|
|||||||
|
|
||||||
@Operation(summary = "【Midjourney】生成图片")
|
@Operation(summary = "【Midjourney】生成图片")
|
||||||
@PostMapping("/midjourney/imagine")
|
@PostMapping("/midjourney/imagine")
|
||||||
public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiMidjourneyImagineReqVO reqVO) {
|
public CommonResult<Long> midjourneyImagine(@Valid @RequestBody AiMidjourneyImagineReqVO reqVO) {
|
||||||
Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO);
|
Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO);
|
||||||
return success(imageId);
|
return success(imageId);
|
||||||
}
|
}
|
||||||
@ -92,14 +92,14 @@ public class AiImageController {
|
|||||||
@Operation(summary = "【Midjourney】通知图片进展", description = "由 Midjourney Proxy 回调")
|
@Operation(summary = "【Midjourney】通知图片进展", description = "由 Midjourney Proxy 回调")
|
||||||
@PostMapping("/midjourney/notify") // 必须是 POST 方法,否则会报错
|
@PostMapping("/midjourney/notify") // 必须是 POST 方法,否则会报错
|
||||||
@PermitAll
|
@PermitAll
|
||||||
public CommonResult<Boolean> midjourneyNotify(@Validated @RequestBody MidjourneyApi.Notify notify) {
|
public CommonResult<Boolean> midjourneyNotify(@Valid @RequestBody MidjourneyApi.Notify notify) {
|
||||||
imageService.midjourneyNotify(notify);
|
imageService.midjourneyNotify(notify);
|
||||||
return success(true);
|
return success(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Operation(summary = "【Midjourney】Action 操作(二次生成图片)", description = "例如说:放大、缩小、U1、U2 等")
|
@Operation(summary = "【Midjourney】Action 操作(二次生成图片)", description = "例如说:放大、缩小、U1、U2 等")
|
||||||
@PostMapping("/midjourney/action")
|
@PostMapping("/midjourney/action")
|
||||||
public CommonResult<Long> midjourneyAction(@Validated @RequestBody AiMidjourneyActionReqVO reqVO) {
|
public CommonResult<Long> midjourneyAction(@Valid @RequestBody AiMidjourneyActionReqVO reqVO) {
|
||||||
Long imageId = imageService.midjourneyAction(getLoginUserId(), reqVO);
|
Long imageId = imageService.midjourneyAction(getLoginUserId(), reqVO);
|
||||||
return success(imageId);
|
return success(imageId);
|
||||||
}
|
}
|
||||||
|
@ -6,8 +6,8 @@ Authorization: {{token}}
|
|||||||
{
|
{
|
||||||
"platform": "Suno",
|
"platform": "Suno",
|
||||||
"generateMode": 2,
|
"generateMode": 2,
|
||||||
"prompt": "周末啦!",
|
"prompt": "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。",
|
||||||
"model": "chirp-v3-5",
|
"model": "chirp-v3.5",
|
||||||
"tags": ["Happy"],
|
"tags": ["Happy"],
|
||||||
"title": "Happy Song"
|
"title": "Happy Song"
|
||||||
}
|
}
|
||||||
@ -20,7 +20,7 @@ Authorization: {{token}}
|
|||||||
{
|
{
|
||||||
"platform": "Suno",
|
"platform": "Suno",
|
||||||
"generateMode": 1,
|
"generateMode": 1,
|
||||||
"model": "chirp-v3-5",
|
"model": "chirp-v3.5",
|
||||||
"prompt": "今天是星球六,结果是个下雨天,希望心情很美丽",
|
"prompt": "happy music",
|
||||||
"makeInstrumental": false
|
"makeInstrumental": false
|
||||||
}
|
}
|
@ -38,6 +38,10 @@ public class AiMusicController {
|
|||||||
@PostMapping("/generate")
|
@PostMapping("/generate")
|
||||||
@Operation(summary = "音乐生成")
|
@Operation(summary = "音乐生成")
|
||||||
public CommonResult<List<Long>> generateMusic(@RequestBody @Valid AiSunoGenerateReqVO reqVO) {
|
public CommonResult<List<Long>> generateMusic(@RequestBody @Valid AiSunoGenerateReqVO reqVO) {
|
||||||
|
// if (true) {
|
||||||
|
// musicService.syncMusic();
|
||||||
|
// return null;
|
||||||
|
// }
|
||||||
return success(musicService.generateMusic(getLoginUserId(), reqVO));
|
return success(musicService.generateMusic(getLoginUserId(), reqVO));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,8 +2,9 @@ package cn.iocoder.yudao.module.ai.service.chat;
|
|||||||
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||||
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
@ -22,9 +23,10 @@ public interface AiChatMessageService {
|
|||||||
* 发送消息
|
* 发送消息
|
||||||
*
|
*
|
||||||
* @param sendReqVO 发送信息
|
* @param sendReqVO 发送信息
|
||||||
|
* @param userId 用户编号
|
||||||
* @return 发送结果
|
* @return 发送结果
|
||||||
*/
|
*/
|
||||||
AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO sendReqVO);
|
AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 发送消息
|
* 发送消息
|
||||||
|
@ -10,22 +10,20 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
|||||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
||||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
|
||||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.chat.messages.*;
|
import org.springframework.ai.chat.messages.*;
|
||||||
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.ai.chat.model.ChatResponse;
|
import org.springframework.ai.chat.model.ChatResponse;
|
||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
@ -64,47 +62,37 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
@Resource
|
@Resource
|
||||||
private AiChatModelService chatModalService;
|
private AiChatModelService chatModalService;
|
||||||
@Resource
|
@Resource
|
||||||
private AiChatRoleService chatRoleService;
|
|
||||||
@Resource
|
|
||||||
private AiApiKeyService apiKeyService;
|
private AiApiKeyService apiKeyService;
|
||||||
|
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO req) {
|
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
||||||
return null; // TODO 芋艿:一起改
|
// 1.1 校验对话存在
|
||||||
// Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
|
||||||
// // 查询对话
|
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
|
||||||
// AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
|
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
|
||||||
// // 获取对话模型
|
}
|
||||||
// AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
|
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
|
||||||
// // 获取角色信息
|
// 1.2 校验模型
|
||||||
// AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
|
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
|
||||||
// // 获取 client 类型
|
ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
||||||
// AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
|
||||||
// // 保存 chat message
|
// 2. 插入 user 发送消息
|
||||||
// createChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
||||||
// chatModel.getModel(), chatModel.getId(), req.getContent());
|
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
|
||||||
// String content = null;
|
|
||||||
// int tokens = 0;
|
// 3.1 插入 assistant 接收消息
|
||||||
// try {
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
||||||
// // 创建 chat 需要的 Prompt
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
||||||
// Prompt prompt = new Prompt(req.getContent());
|
|
||||||
// // TODO @芋艿 @范 看要不要支持这些
|
// 3.2 创建 chat 需要的 Prompt
|
||||||
//// req.setTopK(req.getTopK());
|
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
||||||
//// req.setTopP(req.getTopP());
|
ChatResponse chatResponse = chatClient.call(prompt);
|
||||||
//// req.setTemperature(req.getTemperature());
|
|
||||||
// // 发送 call 调用
|
// 3.3 段式返回
|
||||||
// ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
|
String newContent = chatResponse.getResult().getOutput().getContent();
|
||||||
// ChatResponse call = chatClient.call(prompt);
|
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
|
||||||
// content = call.getResult().getOutput().getContent();
|
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
|
||||||
// // 更新 conversation
|
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
|
||||||
// } catch (Exception e) {
|
|
||||||
// content = ExceptionUtil.getMessage(e);
|
|
||||||
// } finally {
|
|
||||||
// // 保存 chat message
|
|
||||||
// createChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
||||||
// chatModel.getModel(), chatModel.getId(), content);
|
|
||||||
// }
|
|
||||||
// return new AiChatMessageRespVO().setContent(content);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -112,14 +100,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
// 1.1 校验对话存在
|
// 1.1 校验对话存在
|
||||||
AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
|
AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
|
||||||
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
|
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
|
||||||
throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿:异常情况的对接;
|
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
|
||||||
}
|
}
|
||||||
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
|
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
|
||||||
// 1.2 校验模型
|
// 1.2 校验模型
|
||||||
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
|
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
|
||||||
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
|
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
||||||
// 1.3 获取用户头像、角色头像
|
|
||||||
AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null;
|
|
||||||
|
|
||||||
// 2. 插入 user 发送消息
|
// 2. 插入 user 发送消息
|
||||||
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
||||||
@ -149,9 +135,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
// TODO @芋艿:失败的情况下,要不要删除消息
|
// TODO @芋艿:失败的情况下,要不要删除消息
|
||||||
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
|
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
|
||||||
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()));
|
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()));
|
||||||
}).onErrorResume(error -> {
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
|
||||||
return Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR));
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
||||||
|
@ -8,7 +8,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR
|
|||||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||||
import jakarta.validation.Valid;
|
import jakarta.validation.Valid;
|
||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@ -76,12 +76,12 @@ public interface AiApiKeyService {
|
|||||||
// ========== 与 spring-ai 集成 ==========
|
// ========== 与 spring-ai 集成 ==========
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获得 StreamingChatClient 对象
|
* 获得 ChatModel 对象
|
||||||
*
|
*
|
||||||
* @param id 编号
|
* @param id 编号
|
||||||
* @return StreamingChatClient 对象
|
* @return ChatModel 对象
|
||||||
*/
|
*/
|
||||||
StreamingChatModel getStreamingChatClient(Long id);
|
ChatModel getChatClient(Long id);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获得 ImageClient 对象
|
* 获得 ImageClient 对象
|
||||||
|
@ -12,7 +12,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
|
|||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
@ -98,10 +98,10 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
|
|||||||
// ========== 与 spring-ai 集成 ==========
|
// ========== 与 spring-ai 集成 ==========
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public StreamingChatModel getStreamingChatClient(Long id) {
|
public ChatModel getChatClient(Long id) {
|
||||||
AiApiKeyDO apiKey = validateApiKey(id);
|
AiApiKeyDO apiKey = validateApiKey(id);
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
||||||
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
|
return clientFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -21,6 +21,7 @@ import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
|||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@ -49,6 +50,7 @@ public class AiMusicServiceImpl implements AiMusicService {
|
|||||||
private FileApi fileApi;
|
private FileApi fileApi;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
|
public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
|
||||||
// 1. 调用 Suno 生成音乐
|
// 1. 调用 Suno 生成音乐
|
||||||
SunoApi sunoApi = apiKeyService.getSunoApi();
|
SunoApi sunoApi = apiKeyService.getSunoApi();
|
||||||
@ -164,14 +166,9 @@ public class AiMusicServiceImpl implements AiMusicService {
|
|||||||
*/
|
*/
|
||||||
private List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicList) {
|
private List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicList) {
|
||||||
return convertList(musicList, musicData -> {
|
return convertList(musicList, musicData -> {
|
||||||
Integer status;
|
Integer status = Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.SUCCESS.getStatus()
|
||||||
if (Objects.equals("complete", musicData.status())) {
|
: Objects.equals("error", musicData.status()) ? AiMusicStatusEnum.FAIL.getStatus()
|
||||||
status = AiMusicStatusEnum.SUCCESS.getStatus();
|
: AiMusicStatusEnum.IN_PROGRESS.getStatus();
|
||||||
} else if (Objects.equals("error", musicData.status())) {
|
|
||||||
status = AiMusicStatusEnum.FAIL.getStatus();
|
|
||||||
} else {
|
|
||||||
status = AiMusicStatusEnum.IN_PROGRESS.getStatus();
|
|
||||||
}
|
|
||||||
return new AiMusicDO()
|
return new AiMusicDO()
|
||||||
.setTaskId(musicData.id()).setModel(musicData.modelName())
|
.setTaskId(musicData.id()).setModel(musicData.modelName())
|
||||||
.setDescription(musicData.gptDescriptionPrompt())
|
.setDescription(musicData.gptDescriptionPrompt())
|
||||||
|
@ -53,7 +53,7 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||||||
// 1.1 校验模型
|
// 1.1 校验模型
|
||||||
// TODO @xin:可以约定大于配置先,查询某个名字。例如说,写作助手!然后写作助手,上面是有个 model 的,可以使用它。
|
// TODO @xin:可以约定大于配置先,查询某个名字。例如说,写作助手!然后写作助手,上面是有个 model 的,可以使用它。
|
||||||
AiChatModelDO model = chatModalService.validateChatModel(14L);
|
AiChatModelDO model = chatModalService.validateChatModel(14L);
|
||||||
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
|
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.core.factory;
|
|||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -14,26 +14,26 @@ import org.springframework.ai.image.ImageModel;
|
|||||||
public interface AiClientFactory {
|
public interface AiClientFactory {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 基于指定配置,获得 StreamingChatClient 对象
|
* 基于指定配置,获得 ChatModel 对象
|
||||||
*
|
*
|
||||||
* 如果不存在,则进行创建
|
* 如果不存在,则进行创建
|
||||||
*
|
*
|
||||||
* @param platform 平台
|
* @param platform 平台
|
||||||
* @param apiKey API KEY
|
* @param apiKey API KEY
|
||||||
* @param url API URL
|
* @param url API URL
|
||||||
* @return StreamingChatClient 对象
|
* @return ChatModel 对象
|
||||||
*/
|
*/
|
||||||
StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url);
|
ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 基于默认配置,获得 StreamingChatClient 对象
|
* 基于默认配置,获得 ChatModel 对象
|
||||||
*
|
*
|
||||||
* 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
|
* 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
|
||||||
*
|
*
|
||||||
* @param platform 平台
|
* @param platform 平台
|
||||||
* @return StreamingChatClient 对象
|
* @return ChatModel 对象
|
||||||
*/
|
*/
|
||||||
StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform);
|
ChatModel getDefaultChatClient(AiPlatformEnum platform);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 基于默认配置,获得 ImageClient 对象
|
* 基于默认配置,获得 ImageClient 对象
|
||||||
|
@ -23,7 +23,7 @@ import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
|||||||
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
||||||
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
|
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
|
||||||
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
|
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
|
||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.image.ImageModel;
|
||||||
import org.springframework.ai.ollama.OllamaChatModel;
|
import org.springframework.ai.ollama.OllamaChatModel;
|
||||||
import org.springframework.ai.ollama.api.OllamaApi;
|
import org.springframework.ai.ollama.api.OllamaApi;
|
||||||
@ -50,9 +50,9 @@ import java.util.List;
|
|||||||
public class AiClientFactoryImpl implements AiClientFactory {
|
public class AiClientFactoryImpl implements AiClientFactory {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) {
|
public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) {
|
||||||
String cacheKey = buildClientCacheKey(StreamingChatModel.class, platform, apiKey, url);
|
String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
|
||||||
return Singleton.get(cacheKey, (Func0<StreamingChatModel>) () -> {
|
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
|
||||||
//noinspection EnhancedSwitchMigration
|
//noinspection EnhancedSwitchMigration
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
case OPENAI:
|
case OPENAI:
|
||||||
@ -74,7 +74,7 @@ public class AiClientFactoryImpl implements AiClientFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform) {
|
public ChatModel getDefaultChatClient(AiPlatformEnum platform) {
|
||||||
//noinspection EnhancedSwitchMigration
|
//noinspection EnhancedSwitchMigration
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
case OPENAI:
|
case OPENAI:
|
||||||
|
Loading…
Reference in New Issue
Block a user