diff --git a/yudao-framework/yudao-spring-boot-starter-test/src/main/java/cn/iocoder/yudao/framework/test/core/util/RandomUtils.java b/yudao-framework/yudao-spring-boot-starter-test/src/main/java/cn/iocoder/yudao/framework/test/core/util/RandomUtils.java index aa2f18c76..935929e8d 100644 --- a/yudao-framework/yudao-spring-boot-starter-test/src/main/java/cn/iocoder/yudao/framework/test/core/util/RandomUtils.java +++ b/yudao-framework/yudao-spring-boot-starter-test/src/main/java/cn/iocoder/yudao/framework/test/core/util/RandomUtils.java @@ -2,13 +2,16 @@ package cn.iocoder.yudao.framework.test.core.util; import cn.hutool.core.util.ArrayUtil; import cn.hutool.core.util.RandomUtil; +import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; -import cn.iocoder.yudao.framework.common.util.collection.SetUtils; import uk.co.jemos.podam.api.PodamFactory; import uk.co.jemos.podam.api.PodamFactoryImpl; import java.lang.reflect.Type; -import java.util.*; +import java.util.Arrays; +import java.util.Date; +import java.util.List; +import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -22,7 +25,6 @@ public class RandomUtils { private static final int RANDOM_STRING_LENGTH = 10; - private static final Set TINYINT_FIELDS = SetUtils.asSet("type", "category"); private static final int TINYINT_MAX = 127; private static final int RANDOM_DATE_MAX = 30; @@ -41,9 +43,10 @@ public class RandomUtils { if (attributeMetadata.getAttributeName().equals("status")) { return RandomUtil.randomEle(CommonStatusEnum.values()).getStatus(); } - // 针对部分字段,使用 tinyint 范围 - if (TINYINT_FIELDS.contains(attributeMetadata.getAttributeName())) { - return RandomUtil.randomInt(1, TINYINT_MAX + 1); + // 如果是 type、status 结尾的字段,返回 tinyint 范围 + if (StrUtil.endWithAnyIgnoreCase(attributeMetadata.getAttributeName(), + "type", "status", "category")) { + return RandomUtil.randomInt(0, TINYINT_MAX + 1); } return RandomUtil.randomInt(); }); diff --git a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImpl.java b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImpl.java index b830278c8..9ceb0d3ca 100644 --- a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImpl.java +++ b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImpl.java @@ -6,7 +6,9 @@ import cn.iocoder.yudao.framework.common.util.date.DateUtils; import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ApproveDO; import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ClientDO; import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2ApproveMapper; +import com.google.common.annotations.VisibleForTesting; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import org.springframework.validation.annotation.Validated; import javax.annotation.Resource; @@ -35,6 +37,7 @@ public class OAuth2ApproveServiceImpl implements OAuth2ApproveService { private OAuth2ApproveMapper oauth2ApproveMapper; @Override + @Transactional public boolean checkForPreApproval(Long userId, Integer userType, String clientId, Collection requestedScopes) { // 第一步,基于 Client 的自动授权计算,如果 scopes 都在自动授权中,则返回 true 通过 OAuth2ClientDO clientDO = oauth2ClientService.validOAuthClientFromCache(clientId); @@ -49,14 +52,14 @@ public class OAuth2ApproveServiceImpl implements OAuth2ApproveService { } // 第二步,算上用户已经批准的授权。如果 scopes 都包含,则返回 true - List approveDOs = oauth2ApproveMapper.selectListByUserIdAndUserTypeAndClientId( - userId, userType, clientId); + List approveDOs = getApproveList(userId, userType, clientId); Set scopes = convertSet(approveDOs, OAuth2ApproveDO::getScope, - o -> o.getApproved() && !DateUtils.isExpired(o.getExpiresTime())); // 只保留未过期的 + OAuth2ApproveDO::getApproved); // 只保留未过期的 + 同意的 return CollUtil.containsAll(scopes, requestedScopes); } @Override + @Transactional public boolean updateAfterApproval(Long userId, Integer userType, String clientId, Map requestedScopes) { // 如果 requestedScopes 为空,说明没有要求,则返回 true 通过 if (CollUtil.isEmpty(requestedScopes)) { @@ -83,8 +86,9 @@ public class OAuth2ApproveServiceImpl implements OAuth2ApproveService { return approveDOs; } - private void saveApprove(Long userId, Integer userType, String clientId, - String scope, Boolean approved, Date expireTime) { + @VisibleForTesting + void saveApprove(Long userId, Integer userType, String clientId, + String scope, Boolean approved, Date expireTime) { // 先更新 OAuth2ApproveDO approveDO = new OAuth2ApproveDO().setUserId(userId).setUserType(userType) .setClientId(clientId).setScope(scope).setApproved(approved).setExpiresTime(expireTime); diff --git a/yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImplTest.java b/yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImplTest.java new file mode 100644 index 000000000..a9ea70ff4 --- /dev/null +++ b/yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImplTest.java @@ -0,0 +1,267 @@ +package cn.iocoder.yudao.module.system.service.oauth2; + +import cn.hutool.core.util.ObjectUtil; +import cn.iocoder.yudao.framework.common.enums.UserTypeEnum; +import cn.iocoder.yudao.framework.common.util.date.DateUtils; +import cn.iocoder.yudao.framework.test.core.ut.BaseDbUnitTest; +import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ApproveDO; +import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ClientDO; +import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2ApproveMapper; +import org.assertj.core.util.Lists; +import org.junit.jupiter.api.Test; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.context.annotation.Import; + +import javax.annotation.Resource; +import java.time.Duration; +import java.util.*; + +import static cn.hutool.core.util.RandomUtil.*; +import static cn.iocoder.yudao.framework.common.util.date.DateUtils.addTime; +import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertPojoEquals; +import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.randomString; +import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.*; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +/** + * {@link OAuth2ApproveServiceImpl} 的单元测试类 + * + * @author 芋道源码 + */ +@Import(OAuth2ApproveServiceImpl.class) +public class OAuth2ApproveServiceImplTest extends BaseDbUnitTest { + + @Resource + private OAuth2ApproveServiceImpl oauth2ApproveService; + + @Resource + private OAuth2ApproveMapper oauth2ApproveMapper; + + @MockBean + private OAuth2ClientService oauth2ClientService; + + @Test + public void checkForPreApproval_clientAutoApprove() { + // 准备参数 + Long userId = randomLongId(); + Integer userType = randomEle(UserTypeEnum.values()).getValue(); + String clientId = randomString(); + List requestedScopes = Lists.newArrayList("read"); + // mock 方法 + when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))) + .thenReturn(randomPojo(OAuth2ClientDO.class).setAutoApproveScopes(requestedScopes)); + + // 调用 + boolean success = oauth2ApproveService.checkForPreApproval(userId, userType, + clientId, requestedScopes); + // 断言 + assertTrue(success); + List result = oauth2ApproveMapper.selectList(); + assertEquals(1, result.size()); + assertEquals(userId, result.get(0).getUserId()); + assertEquals(userType, result.get(0).getUserType()); + assertEquals(clientId, result.get(0).getClientId()); + assertEquals("read", result.get(0).getScope()); + assertTrue(result.get(0).getApproved()); + assertFalse(DateUtils.isExpired(result.get(0).getExpiresTime())); + } + + @Test + public void checkForPreApproval_approve() { + // 准备参数 + Long userId = randomLongId(); + Integer userType = randomEle(UserTypeEnum.values()).getValue(); + String clientId = randomString(); + List requestedScopes = Lists.newArrayList("read"); + // mock 方法 + when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))) + .thenReturn(randomPojo(OAuth2ClientDO.class).setAutoApproveScopes(null)); + // mock 数据 + OAuth2ApproveDO approve = randomPojo(OAuth2ApproveDO.class).setUserId(userId) + .setUserType(userType).setClientId(clientId).setScope("read") + .setExpiresTime(addTime(Duration.ofDays(1))).setApproved(true); // 同意 + oauth2ApproveMapper.insert(approve); + + // 调用 + boolean success = oauth2ApproveService.checkForPreApproval(userId, userType, + clientId, requestedScopes); + // 断言 + assertTrue(success); + } + + @Test + public void checkForPreApproval_reject() { + // 准备参数 + Long userId = randomLongId(); + Integer userType = randomEle(UserTypeEnum.values()).getValue(); + String clientId = randomString(); + List requestedScopes = Lists.newArrayList("read"); + // mock 方法 + when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))) + .thenReturn(randomPojo(OAuth2ClientDO.class).setAutoApproveScopes(null)); + // mock 数据 + OAuth2ApproveDO approve = randomPojo(OAuth2ApproveDO.class).setUserId(userId) + .setUserType(userType).setClientId(clientId).setScope("read") + .setExpiresTime(addTime(Duration.ofDays(1))).setApproved(false); // 拒绝 + oauth2ApproveMapper.insert(approve); + + // 调用 + boolean success = oauth2ApproveService.checkForPreApproval(userId, userType, + clientId, requestedScopes); + // 断言 + assertFalse(success); + } + + @Test + public void testUpdateAfterApproval_none() { + // 准备参数 + Long userId = randomLongId(); + Integer userType = randomEle(UserTypeEnum.values()).getValue(); + String clientId = randomString(); + + // 调用 + boolean success = oauth2ApproveService.updateAfterApproval(userId, userType, clientId, + null); + // 断言 + assertTrue(success); + List result = oauth2ApproveMapper.selectList(); + assertEquals(0, result.size()); + } + + @Test + public void testUpdateAfterApproval_approved() { + // 准备参数 + Long userId = randomLongId(); + Integer userType = randomEle(UserTypeEnum.values()).getValue(); + String clientId = randomString(); + Map requestedScopes = new LinkedHashMap<>(); // 有序,方便判断 + requestedScopes.put("read", true); + requestedScopes.put("write", false); + // mock 方法 + + // 调用 + boolean success = oauth2ApproveService.updateAfterApproval(userId, userType, clientId, + requestedScopes); + // 断言 + assertTrue(success); + List result = oauth2ApproveMapper.selectList(); + assertEquals(2, result.size()); + // read + assertEquals(userId, result.get(0).getUserId()); + assertEquals(userType, result.get(0).getUserType()); + assertEquals(clientId, result.get(0).getClientId()); + assertEquals("read", result.get(0).getScope()); + assertTrue(result.get(0).getApproved()); + assertFalse(DateUtils.isExpired(result.get(0).getExpiresTime())); + // write + assertEquals(userId, result.get(1).getUserId()); + assertEquals(userType, result.get(1).getUserType()); + assertEquals(clientId, result.get(1).getClientId()); + assertEquals("write", result.get(1).getScope()); + assertFalse(result.get(1).getApproved()); + assertFalse(DateUtils.isExpired(result.get(1).getExpiresTime())); + } + + @Test + public void testUpdateAfterApproval_reject() { + // 准备参数 + Long userId = randomLongId(); + Integer userType = randomEle(UserTypeEnum.values()).getValue(); + String clientId = randomString(); + Map requestedScopes = new LinkedHashMap<>(); + requestedScopes.put("write", false); + // mock 方法 + + // 调用 + boolean success = oauth2ApproveService.updateAfterApproval(userId, userType, clientId, + requestedScopes); + // 断言 + assertFalse(success); + List result = oauth2ApproveMapper.selectList(); + assertEquals(1, result.size()); + // write + assertEquals(userId, result.get(0).getUserId()); + assertEquals(userType, result.get(0).getUserType()); + assertEquals(clientId, result.get(0).getClientId()); + assertEquals("write", result.get(0).getScope()); + assertFalse(result.get(0).getApproved()); + assertFalse(DateUtils.isExpired(result.get(0).getExpiresTime())); + } + + @Test + public void testGetApproveList() { + // 准备参数 + Long userId = 10L; + Integer userType = UserTypeEnum.ADMIN.getValue(); + String clientId = randomString(); + // mock 数据 + OAuth2ApproveDO approve = randomPojo(OAuth2ApproveDO.class).setUserId(userId) + .setUserType(userType).setClientId(clientId).setExpiresTime(addTime(Duration.ofDays(1L))); + oauth2ApproveMapper.insert(approve); // 未过期 + oauth2ApproveMapper.insert(ObjectUtil.clone(approve).setId(null) + .setExpiresTime(addTime(Duration.ofDays(-1L)))); // 已过期 + + // 调用 + List result = oauth2ApproveService.getApproveList(userId, userType, clientId); + // 断言 + assertEquals(1, result.size()); + assertPojoEquals(approve, result.get(0)); + } + + @Test + public void testSaveApprove_insert() { + // 准备参数 + Long userId = randomLongId(); + Integer userType = randomEle(UserTypeEnum.values()).getValue(); + String clientId = randomString(); + String scope = randomString(); + Boolean approved = randomBoolean(); + Date expireTime = randomDay(1, 30); + // mock 方法 + + // 调用 + oauth2ApproveService.saveApprove(userId, userType, clientId, + scope, approved, expireTime); + // 断言 + List result = oauth2ApproveMapper.selectList(); + assertEquals(1, result.size()); + assertEquals(userId, result.get(0).getUserId()); + assertEquals(userType, result.get(0).getUserType()); + assertEquals(clientId, result.get(0).getClientId()); + assertEquals(scope, result.get(0).getScope()); + assertEquals(approved, result.get(0).getApproved()); + assertEquals(expireTime, result.get(0).getExpiresTime()); + } + + @Test + public void testSaveApprove_update() { + // mock 数据 + OAuth2ApproveDO approve = randomPojo(OAuth2ApproveDO.class); + oauth2ApproveMapper.insert(approve); + // 准备参数 + Long userId = approve.getUserId(); + Integer userType = approve.getUserType(); + String clientId = approve.getClientId(); + String scope = approve.getScope(); + Boolean approved = randomBoolean(); + Date expireTime = randomDay(1, 30); + // mock 方法 + + // 调用 + oauth2ApproveService.saveApprove(userId, userType, clientId, + scope, approved, expireTime); + // 断言 + List result = oauth2ApproveMapper.selectList(); + assertEquals(1, result.size()); + assertEquals(approve.getId(), result.get(0).getId()); + assertEquals(userId, result.get(0).getUserId()); + assertEquals(userType, result.get(0).getUserType()); + assertEquals(clientId, result.get(0).getClientId()); + assertEquals(scope, result.get(0).getScope()); + assertEquals(approved, result.get(0).getApproved()); + assertEquals(expireTime, result.get(0).getExpiresTime()); + } + +} diff --git a/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql b/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql index 8eec9fefc..217e7e28e 100644 --- a/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql +++ b/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql @@ -21,3 +21,4 @@ DELETE FROM "system_tenant"; DELETE FROM "system_tenant_package"; DELETE FROM "system_sensitive_word"; DELETE FROM "system_oauth2_client"; +DELETE FROM "system_oauth2_approve"; diff --git a/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql b/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql index b9f9bca86..5bf473db0 100644 --- a/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql +++ b/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql @@ -495,3 +495,19 @@ CREATE TABLE IF NOT EXISTS "system_oauth2_client" ( "deleted" bit NOT NULL DEFAULT FALSE, PRIMARY KEY ("id") ) COMMENT 'OAuth2 客户端表'; + +CREATE TABLE IF NOT EXISTS "system_oauth2_approve" ( + "id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, + "user_id" bigint NOT NULL, + "user_type" tinyint NOT NULL, + "client_id" varchar NOT NULL, + "scope" varchar NOT NULL, + "approved" bit NOT NULL DEFAULT FALSE, + "expires_time" datetime NOT NULL, + "creator" varchar DEFAULT '', + "create_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updater" varchar DEFAULT '', + "update_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + "deleted" bit NOT NULL DEFAULT FALSE, + PRIMARY KEY ("id") +) COMMENT 'OAuth2 批准表';