diff --git a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2AccessTokenDO.java b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2AccessTokenDO.java index 8cecb8f1a..5c2340e51 100644 --- a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2AccessTokenDO.java +++ b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2AccessTokenDO.java @@ -9,7 +9,6 @@ import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler; import lombok.Data; import lombok.EqualsAndHashCode; -import lombok.experimental.Accessors; import java.util.Date; import java.util.List; @@ -22,11 +21,10 @@ import java.util.List; * * @author 芋道源码 */ -@TableName("system_oauth2_access_token") +@TableName(value = "system_oauth2_access_token", autoResultMap = true) @KeySequence("system_oauth2_access_token_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。 @Data @EqualsAndHashCode(callSuper = true) -@Accessors(chain = true) public class OAuth2AccessTokenDO extends TenantBaseDO { /** diff --git a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2RefreshTokenDO.java b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2RefreshTokenDO.java index a933e1c52..deb7d18aa 100644 --- a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2RefreshTokenDO.java +++ b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2RefreshTokenDO.java @@ -18,7 +18,7 @@ import java.util.List; * * @author 芋道源码 */ -@TableName("system_oauth2_refresh_token") +@TableName(value = "system_oauth2_refresh_token", autoResultMap = true) // 由于 Oracle 的 SEQ 的名字长度有限制,所以就先用 system_oauth2_access_token_seq 吧,反正也没啥问题 @KeySequence("system_oauth2_access_token_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。 @Data diff --git a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/mysql/oauth2/OAuth2AccessTokenMapper.java b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/mysql/oauth2/OAuth2AccessTokenMapper.java index 3c91b4e67..fc250d56a 100644 --- a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/mysql/oauth2/OAuth2AccessTokenMapper.java +++ b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/mysql/oauth2/OAuth2AccessTokenMapper.java @@ -25,7 +25,7 @@ public interface OAuth2AccessTokenMapper extends BaseMapperX() .eqIfPresent(OAuth2AccessTokenDO::getUserId, reqVO.getUserId()) .eqIfPresent(OAuth2AccessTokenDO::getUserType, reqVO.getUserType()) - .eqIfPresent(OAuth2AccessTokenDO::getClientId, reqVO.getClientId()) + .likeIfPresent(OAuth2AccessTokenDO::getClientId, reqVO.getClientId()) .gt(OAuth2AccessTokenDO::getExpiresTime, new Date()) .orderByDesc(OAuth2AccessTokenDO::getId)); } diff --git a/yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2TokenServiceImplTest.java b/yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2TokenServiceImplTest.java new file mode 100644 index 000000000..090ab89ec --- /dev/null +++ b/yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2TokenServiceImplTest.java @@ -0,0 +1,289 @@ +package cn.iocoder.yudao.module.system.service.oauth2; + +import cn.hutool.core.util.RandomUtil; +import cn.iocoder.yudao.framework.common.enums.UserTypeEnum; +import cn.iocoder.yudao.framework.common.exception.ErrorCode; +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.framework.common.util.date.DateUtils; +import cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder; +import cn.iocoder.yudao.framework.test.core.ut.BaseDbAndRedisUnitTest; +import cn.iocoder.yudao.module.system.controller.admin.oauth2.vo.token.OAuth2AccessTokenPageReqVO; +import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2AccessTokenDO; +import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ClientDO; +import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2RefreshTokenDO; +import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2AccessTokenMapper; +import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2RefreshTokenMapper; +import cn.iocoder.yudao.module.system.dal.redis.oauth2.OAuth2AccessTokenRedisDAO; +import org.assertj.core.util.Lists; +import org.junit.jupiter.api.Test; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.context.annotation.Import; + +import javax.annotation.Resource; +import java.time.Duration; +import java.util.Date; +import java.util.List; + +import static cn.iocoder.yudao.framework.common.util.date.DateUtils.addTime; +import static cn.iocoder.yudao.framework.common.util.object.ObjectUtils.cloneIgnoreId; +import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertPojoEquals; +import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertServiceException; +import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.*; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +/** + * {@link OAuth2TokenServiceImpl} 的单元测试类 + * + * @author 芋道源码 + */ +@Import({OAuth2TokenServiceImpl.class, OAuth2AccessTokenRedisDAO.class}) +public class OAuth2TokenServiceImplTest extends BaseDbAndRedisUnitTest { + + @Resource + private OAuth2TokenServiceImpl oauth2TokenService; + + @Resource + private OAuth2AccessTokenMapper oauth2AccessTokenMapper; + @Resource + private OAuth2RefreshTokenMapper oauth2RefreshTokenMapper; + + @Resource + private OAuth2AccessTokenRedisDAO oauth2AccessTokenRedisDAO; + + @MockBean + private OAuth2ClientService oauth2ClientService; + + @Test + public void testCreateAccessToken() { + TenantContextHolder.setTenantId(0L); + // 准备参数 + Long userId = randomLongId(); + Integer userType = RandomUtil.randomEle(UserTypeEnum.values()).getValue(); + String clientId = randomString(); + List scopes = Lists.newArrayList("read", "write"); + // mock 方法 + OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId) + .setAccessTokenValiditySeconds(30).setRefreshTokenValiditySeconds(60); + when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO); + + // 调用 + OAuth2AccessTokenDO accessTokenDO = oauth2TokenService.createAccessToken(userId, userType, clientId, scopes); + // 断言访问令牌 + OAuth2AccessTokenDO dbAccessTokenDO = oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken()); + assertPojoEquals(accessTokenDO, dbAccessTokenDO, "createTime", "updateTime", "deleted"); + assertEquals(userId, accessTokenDO.getUserId()); + assertEquals(userType, accessTokenDO.getUserType()); + assertEquals(clientId, accessTokenDO.getClientId()); + assertEquals(scopes, accessTokenDO.getScopes()); + assertFalse(DateUtils.isExpired(accessTokenDO.getExpiresTime())); + // 断言访问令牌的缓存 + OAuth2AccessTokenDO redisAccessTokenDO = oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken()); + assertPojoEquals(accessTokenDO, redisAccessTokenDO, "createTime", "updateTime", "deleted"); + // 断言刷新令牌 + OAuth2RefreshTokenDO refreshTokenDO = oauth2RefreshTokenMapper.selectList().get(0); + assertPojoEquals(accessTokenDO, refreshTokenDO, "id", "expiresTime", "createTime", "updateTime", "deleted"); + assertFalse(DateUtils.isExpired(refreshTokenDO.getExpiresTime())); + } + + @Test + public void testRefreshAccessToken_null() { + // 准备参数 + String refreshToken = randomString(); + String clientId = randomString(); + // mock 方法 + + // 调用,并断言 + assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId), + new ErrorCode(400, "无效的刷新令牌")); + } + + @Test + public void testRefreshAccessToken_clientIdError() { + // 准备参数 + String refreshToken = randomString(); + String clientId = randomString(); + // mock 方法 + OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId); + when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO); + // mock 数据(访问令牌) + OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class) + .setRefreshToken(refreshToken).setClientId("error"); + oauth2RefreshTokenMapper.insert(refreshTokenDO); + + // 调用,并断言 + assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId), + new ErrorCode(400, "刷新令牌的客户端编号不正确")); + } + + @Test + public void testRefreshAccessToken_expired() { + // 准备参数 + String refreshToken = randomString(); + String clientId = randomString(); + // mock 方法 + OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId); + when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO); + // mock 数据(访问令牌) + OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class) + .setRefreshToken(refreshToken).setClientId(clientId) + .setExpiresTime(addTime(Duration.ofDays(-1))); + oauth2RefreshTokenMapper.insert(refreshTokenDO); + + // 调用,并断言 + assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId), + new ErrorCode(401, "刷新令牌已过期")); + } + + @Test + public void testRefreshAccessToken_success() { + TenantContextHolder.setTenantId(0L); + // 准备参数 + String refreshToken = randomString(); + String clientId = randomString(); + // mock 方法 + OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId) + .setAccessTokenValiditySeconds(30); + when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO); + // mock 数据(访问令牌) + OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class) + .setRefreshToken(refreshToken).setClientId(clientId) + .setExpiresTime(addTime(Duration.ofDays(1))); + oauth2RefreshTokenMapper.insert(refreshTokenDO); + // mock 数据(访问令牌) + OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class).setRefreshToken(refreshToken); + oauth2AccessTokenMapper.insert(accessTokenDO); + oauth2AccessTokenRedisDAO.set(accessTokenDO); + + // 调用 + OAuth2AccessTokenDO newAccessTokenDO = oauth2TokenService.refreshAccessToken(refreshToken, clientId); + // 断言,老的访问令牌被删除 + assertNull(oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken())); + assertNull(oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken())); + // 断言,新的访问令牌 + OAuth2AccessTokenDO dbAccessTokenDO = oauth2AccessTokenMapper.selectByAccessToken(newAccessTokenDO.getAccessToken()); + assertPojoEquals(newAccessTokenDO, dbAccessTokenDO, "createTime", "updateTime", "deleted"); + assertPojoEquals(newAccessTokenDO, refreshTokenDO, "id", "expiresTime", "createTime", "updateTime", "deleted", + "creator", "updater"); + assertFalse(DateUtils.isExpired(newAccessTokenDO.getExpiresTime())); + // 断言,新的访问令牌的缓存 + OAuth2AccessTokenDO redisAccessTokenDO = oauth2AccessTokenRedisDAO.get(newAccessTokenDO.getAccessToken()); + assertPojoEquals(newAccessTokenDO, redisAccessTokenDO, "createTime", "updateTime", "deleted"); + } + + @Test + public void testGetAccessToken() { + // mock 数据(访问令牌) + OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class) + .setExpiresTime(addTime(Duration.ofDays(1))); + oauth2AccessTokenMapper.insert(accessTokenDO); + // 准备参数 + String accessToken = accessTokenDO.getAccessToken(); + + // 调用 + OAuth2AccessTokenDO result = oauth2TokenService.getAccessToken(accessToken); + // 断言 + assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted", + "creator", "updater"); + assertPojoEquals(accessTokenDO, oauth2AccessTokenRedisDAO.get(accessToken), "createTime", "updateTime", "deleted", + "creator", "updater"); + } + + @Test + public void testCheckAccessToken_null() { + // 调研,并断言 + assertServiceException(() -> oauth2TokenService.checkAccessToken(randomString()), + new ErrorCode(401, "访问令牌不存在")); + } + + @Test + public void testCheckAccessToken_expired() { + // mock 数据(访问令牌) + OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class) + .setExpiresTime(addTime(Duration.ofDays(-1))); + oauth2AccessTokenMapper.insert(accessTokenDO); + // 准备参数 + String accessToken = accessTokenDO.getAccessToken(); + + // 调研,并断言 + assertServiceException(() -> oauth2TokenService.checkAccessToken(accessToken), + new ErrorCode(401, "访问令牌已过期")); + } + + @Test + public void testCheckAccessToken_success() { + // mock 数据(访问令牌) + OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class) + .setExpiresTime(addTime(Duration.ofDays(1))); + oauth2AccessTokenMapper.insert(accessTokenDO); + // 准备参数 + String accessToken = accessTokenDO.getAccessToken(); + + // 调研,并断言 + OAuth2AccessTokenDO result = oauth2TokenService.getAccessToken(accessToken); + // 断言 + assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted", + "creator", "updater"); + } + + @Test + public void testRemoveAccessToken_null() { + // 调用,并断言 + assertNull(oauth2TokenService.removeAccessToken(randomString())); + } + + @Test + public void testRemoveAccessToken_success() { + // mock 数据(访问令牌) + OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class) + .setExpiresTime(addTime(Duration.ofDays(1))); + oauth2AccessTokenMapper.insert(accessTokenDO); + // mock 数据(刷新令牌) + OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class) + .setRefreshToken(accessTokenDO.getRefreshToken()); + oauth2RefreshTokenMapper.insert(refreshTokenDO); + // 调用 + OAuth2AccessTokenDO result = oauth2TokenService.removeAccessToken(accessTokenDO.getAccessToken()); + assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted", + "creator", "updater"); + // 断言数据 + assertNull(oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken())); + assertNull(oauth2RefreshTokenMapper.selectByRefreshToken(accessTokenDO.getRefreshToken())); + assertNull(oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken())); + } + + + @Test + public void testGetAccessTokenPage() { + // mock 数据 + OAuth2AccessTokenDO dbAccessToken = randomPojo(OAuth2AccessTokenDO.class, o -> { // 等会查询到 + o.setUserId(10L); + o.setUserType(1); + o.setClientId("test_client"); + o.setExpiresTime(DateUtils.addTime(Duration.ofDays(1))); + }); + oauth2AccessTokenMapper.insert(dbAccessToken); + // 测试 userId 不匹配 + oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setUserId(20L))); + // 测试 userType 不匹配 + oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setUserType(2))); + // 测试 userType 不匹配 + oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setClientId("it_client"))); + // 测试 expireTime 不匹配 + oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setExpiresTime(new Date()))); + // 准备参数 + OAuth2AccessTokenPageReqVO reqVO = new OAuth2AccessTokenPageReqVO(); + reqVO.setUserId(10L); + reqVO.setUserType(1); + reqVO.setClientId("test"); + + // 调用 + PageResult pageResult = oauth2TokenService.getAccessTokenPage(reqVO); + // 断言 + assertEquals(1, pageResult.getTotal()); + assertEquals(1, pageResult.getList().size()); + assertPojoEquals(dbAccessToken, pageResult.getList().get(0)); + } + +} diff --git a/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql b/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql index 217e7e28e..8f882fa7a 100644 --- a/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql +++ b/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql @@ -22,3 +22,5 @@ DELETE FROM "system_tenant_package"; DELETE FROM "system_sensitive_word"; DELETE FROM "system_oauth2_client"; DELETE FROM "system_oauth2_approve"; +DELETE FROM "system_oauth2_access_token"; +DELETE FROM "system_oauth2_refresh_token"; diff --git a/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql b/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql index 5bf473db0..374b5104a 100644 --- a/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql +++ b/yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql @@ -511,3 +511,39 @@ CREATE TABLE IF NOT EXISTS "system_oauth2_approve" ( "deleted" bit NOT NULL DEFAULT FALSE, PRIMARY KEY ("id") ) COMMENT 'OAuth2 批准表'; + +CREATE TABLE IF NOT EXISTS "system_oauth2_access_token" ( + "id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, + "user_id" bigint NOT NULL, + "user_type" tinyint NOT NULL, + "access_token" varchar NOT NULL, + "refresh_token" varchar NOT NULL, + "client_id" varchar NOT NULL, + "scopes" varchar NOT NULL, + "approved" bit NOT NULL DEFAULT FALSE, + "expires_time" datetime NOT NULL, + "creator" varchar DEFAULT '', + "create_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updater" varchar DEFAULT '', + "update_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + "deleted" bit NOT NULL DEFAULT FALSE, + "tenant_id" bigint NOT NULL, + PRIMARY KEY ("id") +) COMMENT 'OAuth2 访问令牌'; + +CREATE TABLE IF NOT EXISTS "system_oauth2_refresh_token" ( + "id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, + "user_id" bigint NOT NULL, + "user_type" tinyint NOT NULL, + "refresh_token" varchar NOT NULL, + "client_id" varchar NOT NULL, + "scopes" varchar NOT NULL, + "approved" bit NOT NULL DEFAULT FALSE, + "expires_time" datetime NOT NULL, + "creator" varchar DEFAULT '', + "create_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updater" varchar DEFAULT '', + "update_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + "deleted" bit NOT NULL DEFAULT FALSE, + PRIMARY KEY ("id") +) COMMENT 'OAuth2 刷新令牌';