【增加】MidjourneyProxyClient 增加批量拉去任务信息

This commit is contained in:
cherishsince 2024-06-05 15:21:27 +08:00
parent f15ecda727
commit 000bcf1143

View File

@ -3,15 +3,20 @@ package cn.iocoder.yudao.module.ai.client;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
import cn.iocoder.yudao.module.ai.config.MidjourneyProperties;
import com.google.common.collect.ImmutableMap;
import jakarta.validation.constraints.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.*;
import org.springframework.stereotype.Component;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.client.RestTemplate;
import java.util.Collection;
import java.util.List;
// TODO @fan这个写到 starter-ai 里哈搞个 MidjourneyApi参考 https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java 的风格写哈
/**
* Midjourney Proxy 客户端
@ -25,12 +30,10 @@ public class MidjourneyProxyClient {
private static final String URI_IMAGINE = "/submit/imagine";
private static final String URI_ACTON = "/submit/action";
private static final String URI_LIST_BY_CONDITION = "/task/list-by-condition";
@Value("${ai.midjourney-proxy.url:http://127.0.0.1:8080/mj}")
private String url;
@Value("${ai.midjourney-proxy.key}")
private String key;
@Autowired
private MidjourneyProperties midjourneyProperties;
@Autowired
private RestTemplate restTemplate;
@ -59,14 +62,28 @@ public class MidjourneyProxyClient {
return JsonUtils.parseObject(response.getBody(), MidjourneySubmitRespVO.class);
}
/**
* 批量查询 task 任务
*
* @param taskIds
* @return
*/
public List<MidjourneyNotifyReqVO> listByCondition(Collection<String> taskIds) {
// 1发送 post 请求
ResponseEntity<String> res = post(URI_LIST_BY_CONDITION, ImmutableMap.of("ids", taskIds));
// 2转换 对象
return JsonUtils.parseArray(res.getBody(), MidjourneyNotifyReqVO.class);
}
private ResponseEntity<String> post(String uri, Object body) {
// 1创建 HttpHeaders 对象
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.set("Authorization", "Bearer ".concat(key));
headers.set("Authorization", "Bearer ".concat(midjourneyProperties.getKey()));
// 2创建 HttpEntity 对象 HttpHeaders 和请求体传递给它
HttpEntity<String> requestEntity = new HttpEntity<>(JsonUtils.toJsonString(body), headers);
// 3发送 post 请求
return restTemplate.exchange(url.concat(uri), HttpMethod.POST, requestEntity, String.class);
return restTemplate.exchange(midjourneyProperties.getUrl().concat(uri), HttpMethod.POST, requestEntity, String.class);
}
}