From 38e88b02f52e8b4fa143a74f64d7c6408c89d387 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sat, 10 Sep 2022 21:37:16 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=E4=BF=AE=E5=A4=8D=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=9D=83=E9=99=90=EF=BC=8C=E4=B8=8D=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E9=9A=90=E5=BC=8F=E5=86=85=E8=BF=9E=E6=8E=A5=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../db/DataPermissionDatabaseInterceptor.java | 291 +++++++++++++----- ...ataPermissionDatabaseInterceptorTest2.java | 245 ++++++++++++--- .../mybatis/core/util/MyBatisUtils.java | 6 +- 3 files changed, 420 insertions(+), 122 deletions(-) diff --git a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptor.java b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptor.java index a9a4d24d4..5fc4e55d0 100644 --- a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptor.java +++ b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptor.java @@ -18,7 +18,6 @@ import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.relational.ExistsExpression; import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.expression.operators.relational.InExpression; -import net.sf.jsqlparser.expression.operators.relational.ItemsList; import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.delete.Delete; import net.sf.jsqlparser.statement.select.*; @@ -37,7 +36,7 @@ import java.util.concurrent.ConcurrentHashMap; /** * 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现 - * 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, Table)} 方法 + * 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, List)} 方法 * * 整体的代码实现上,参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现。 * 所以每次 MyBatis Plus 升级时,需要 Review 下其具体的实现是否有变更! @@ -53,8 +52,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme private final MappedStatementCache mappedStatementCache = new MappedStatementCache(); @Override // SELECT 场景 - public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, - RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) { + public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) { // 获得 Mapper 对应的数据权限的规则 List rules = ruleFactory.getDataPermissionRule(ms.getId()); if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过 @@ -68,12 +66,14 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme // 处理 SQL mpBs.sql(parserSingle(mpBs.sql(), null)); } finally { + // 添加是否需要重写的缓存 addMappedStatementCache(ms); + // 清空上下文 ContextHolder.clear(); } } - @Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景 + @Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景(因为 INSERT 不需要数据权限) public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) { PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh); MappedStatement ms = mpSh.mappedStatement(); @@ -92,7 +92,9 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme // 处理 SQL mpBs.sql(parserMulti(mpBs.sql(), null)); } finally { + // 添加是否需要重写的缓存 addMappedStatementCache(ms); + // 清空上下文 ContextHolder.clear(); } } @@ -107,24 +109,6 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme } } - protected void processSelectBody(SelectBody selectBody) { - if (selectBody == null) { - return; - } - if (selectBody instanceof PlainSelect) { - processPlainSelect((PlainSelect) selectBody); - } else if (selectBody instanceof WithItem) { - WithItem withItem = (WithItem) selectBody; - processSelectBody(withItem.getSubSelect().getSelectBody()); - } else { - SetOperationList operationList = (SetOperationList) selectBody; - List selectBodys = operationList.getSelects(); - if (CollectionUtils.isNotEmpty(selectBodys)) { - selectBodys.forEach(this::processSelectBody); - } - } - } - /** * update 语句处理 */ @@ -142,28 +126,77 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable())); } + // ========== 和 TenantLineInnerInterceptor 一致的逻辑 ========== + + protected void processSelectBody(SelectBody selectBody) { + if (selectBody == null) { + return; + } + if (selectBody instanceof PlainSelect) { + processPlainSelect((PlainSelect) selectBody); + } else if (selectBody instanceof WithItem) { + WithItem withItem = (WithItem) selectBody; + processSelectBody(withItem.getSubSelect().getSelectBody()); + } else { + SetOperationList operationList = (SetOperationList) selectBody; + List selectBodyList = operationList.getSelects(); + if (CollectionUtils.isNotEmpty(selectBodyList)) { + selectBodyList.forEach(this::processSelectBody); + } + } + } + /** * 处理 PlainSelect */ protected void processPlainSelect(PlainSelect plainSelect) { - FromItem fromItem = plainSelect.getFromItem(); - Expression where = plainSelect.getWhere(); - processWhereSubSelect(where); - if (fromItem instanceof Table) { - Table fromTable = (Table) fromItem; - plainSelect.setWhere(builderExpression(where, fromTable)); - } else { - processFromItem(fromItem); - } //#3087 github List selectItems = plainSelect.getSelectItems(); if (CollectionUtils.isNotEmpty(selectItems)) { selectItems.forEach(this::processSelectItem); } + + // 处理 where 中的子查询 + Expression where = plainSelect.getWhere(); + processWhereSubSelect(where); + + // 处理 fromItem + FromItem fromItem = plainSelect.getFromItem(); + List list = processFromItem(fromItem); + List
mainTables = new ArrayList<>(list); + + // 处理 join List joins = plainSelect.getJoins(); if (CollectionUtils.isNotEmpty(joins)) { - processJoins(joins); + mainTables = processJoins(mainTables, joins); } + + // 当有 mainTable 时,进行 where 条件追加 + if (CollectionUtils.isNotEmpty(mainTables)) { + plainSelect.setWhere(builderExpression(where, mainTables)); + } + } + + private List
processFromItem(FromItem fromItem) { + // 处理括号括起来的表达式 + while (fromItem instanceof ParenthesisFromItem) { + fromItem = ((ParenthesisFromItem) fromItem).getFromItem(); + } + + List
mainTables = new ArrayList<>(); + // 无 join 时的处理逻辑 + if (fromItem instanceof Table) { + Table fromTable = (Table) fromItem; + mainTables.add(fromTable); + } else if (fromItem instanceof SubJoin) { + // SubJoin 类型则还需要添加上 where 条件 + List
tables = processSubJoin((SubJoin) fromItem); + mainTables.addAll(tables); + } else { + // 处理下 fromItem + processOtherFromItem(fromItem); + } + return mainTables; } /** @@ -191,7 +224,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme return; } if (where instanceof FromItem) { - processFromItem((FromItem) where); + processOtherFromItem((FromItem) where); return; } if (where.toString().indexOf("SELECT") > 0) { @@ -204,9 +237,9 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme } else if (where instanceof InExpression) { // in InExpression expression = (InExpression) where; - ItemsList itemsList = expression.getRightItemsList(); - if (itemsList instanceof SubSelect) { - processSelectBody(((SubSelect) itemsList).getSelectBody()); + Expression inExpression = expression.getRightExpression(); + if (inExpression instanceof SubSelect) { + processSelectBody(((SubSelect) inExpression).getSelectBody()); } } else if (where instanceof ExistsExpression) { // exists @@ -239,7 +272,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme *

支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)

*

fixed gitee pulls/141

* - * @param function 函数 + * @param function */ protected void processFunction(Function function) { ExpressionList parameters = function.getParameters(); @@ -257,22 +290,19 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme /** * 处理子查询等 */ - protected void processFromItem(FromItem fromItem) { - if (fromItem instanceof SubJoin) { - SubJoin subJoin = (SubJoin) fromItem; - if (subJoin.getJoinList() != null) { - processJoins(subJoin.getJoinList()); - } - if (subJoin.getLeft() != null) { - processFromItem(subJoin.getLeft()); - } - } else if (fromItem instanceof SubSelect) { + protected void processOtherFromItem(FromItem fromItem) { + // 去除括号 + while (fromItem instanceof ParenthesisFromItem) { + fromItem = ((ParenthesisFromItem) fromItem).getFromItem(); + } + + if (fromItem instanceof SubSelect) { SubSelect subSelect = (SubSelect) fromItem; if (subSelect.getSelectBody() != null) { processSelectBody(subSelect.getSelectBody()); } } else if (fromItem instanceof ValuesList) { - logger.debug("Perform a subquery, if you do not give us feedback"); + logger.debug("Perform a subQuery, if you do not give us feedback"); } else if (fromItem instanceof LateralSubSelect) { LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem; if (lateralSubSelect.getSubSelect() != null) { @@ -284,75 +314,176 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme } } + /** + * 处理 sub join + * + * @param subJoin subJoin + * @return Table subJoin 中的主表 + */ + private List
processSubJoin(SubJoin subJoin) { + List
mainTables = new ArrayList<>(); + if (subJoin.getJoinList() != null) { + List
list = processFromItem(subJoin.getLeft()); + mainTables.addAll(list); + mainTables = processJoins(mainTables, subJoin.getJoinList()); + } + return mainTables; + } + /** * 处理 joins * - * @param joins join 集合 + * @param mainTables 可以为 null + * @param joins join 集合 + * @return List
右连接查询的 Table 列表 */ - private void processJoins(List joins) { + private List
processJoins(List
mainTables, List joins) { + // join 表达式中最终的主表 + Table mainTable = null; + // 当前 join 的左表 + Table leftTable = null; + + if (mainTables == null) { + mainTables = new ArrayList<>(); + } else if (mainTables.size() == 1) { + mainTable = mainTables.get(0); + leftTable = mainTable; + } + //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名 - Deque
tables = new LinkedList<>(); + Deque> onTableDeque = new LinkedList<>(); for (Join join : joins) { // 处理 on 表达式 - FromItem fromItem = join.getRightItem(); - if (fromItem instanceof Table) { - Table fromTable = (Table) fromItem; + FromItem joinItem = join.getRightItem(); + + // 获取当前 join 的表,subJoint 可以看作是一张表 + List
joinTables = null; + if (joinItem instanceof Table) { + joinTables = new ArrayList<>(); + joinTables.add((Table) joinItem); + } else if (joinItem instanceof SubJoin) { + joinTables = processSubJoin((SubJoin) joinItem); + } + + if (joinTables != null) { + + // 如果是隐式内连接 + if (join.isSimple()) { + mainTables.addAll(joinTables); + continue; + } + + // 当前表是否忽略 + Table joinTable = joinTables.get(0); + + List
onTables = null; + // 如果不要忽略,且是右连接,则记录下当前表 + if (join.isRight()) { + mainTable = joinTable; + if (leftTable != null) { + onTables = Collections.singletonList(leftTable); + } + } else if (join.isLeft()) { + onTables = Collections.singletonList(joinTable); + } else if (join.isInner()) { + if (mainTable == null) { + onTables = Collections.singletonList(joinTable); + } else { + onTables = Arrays.asList(mainTable, joinTable); + } + mainTable = null; + } + + mainTables = new ArrayList<>(); + if (mainTable != null) { + mainTables.add(mainTable); + } + // 获取 join 尾缀的 on 表达式列表 Collection originOnExpressions = join.getOnExpressions(); // 正常 join on 表达式只有一个,立刻处理 - if (originOnExpressions.size() == 1) { - processJoin(join); + if (originOnExpressions.size() == 1 && onTables != null) { + List onExpressions = new LinkedList<>(); + onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables)); + join.setOnExpressions(onExpressions); + leftTable = joinTable; continue; } - tables.push(fromTable); + // 表名压栈,忽略的表压入 null,以便后续不处理 + onTableDeque.push(onTables); // 尾缀多个 on 表达式的时候统一处理 if (originOnExpressions.size() > 1) { Collection onExpressions = new LinkedList<>(); for (Expression originOnExpression : originOnExpressions) { - Table currentTable = tables.poll(); - onExpressions.add(builderExpression(originOnExpression, currentTable)); + List
currentTableList = onTableDeque.poll(); + if (CollectionUtils.isEmpty(currentTableList)) { + onExpressions.add(originOnExpression); + } else { + onExpressions.add(builderExpression(originOnExpression, currentTableList)); + } } join.setOnExpressions(onExpressions); } + leftTable = joinTable; } else { - // 处理右边连接的子表达式 - processFromItem(fromItem); + processOtherFromItem(joinItem); + leftTable = null; } } + + return mainTables; } + // ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ========== + /** - * 处理联接语句 + * 处理条件 + * + * @param currentExpression 当前 where 条件 + * @param table 单个表 */ - protected void processJoin(Join join) { - if (join.getRightItem() instanceof Table) { - Table fromTable = (Table) join.getRightItem(); - Expression originOnExpression = CollUtil.getFirst(join.getOnExpressions()); - originOnExpression = builderExpression(originOnExpression, fromTable); - join.setOnExpressions(CollUtil.newArrayList(originOnExpression)); - } + protected Expression builderExpression(Expression currentExpression, Table table) { + return this.builderExpression(currentExpression, Collections.singletonList(table)); } /** * 处理条件 + * + * @param currentExpression 当前 where 条件 + * @param tables 多个表 */ - protected Expression builderExpression(Expression currentExpression, Table table) { - // 获得 Table 对应的数据权限条件 - Expression equalsTo = buildDataPermissionExpression(table); - if (equalsTo == null) { // 如果没条件,则返回 currentExpression 默认 + protected Expression builderExpression(Expression currentExpression, List
tables) { + // 没有表需要处理直接返回 + if (CollectionUtils.isEmpty(tables)) { return currentExpression; } - // 表达式为空,则直接返回 equalsTo + // 第一步,获得 Table 对应的数据权限条件 + Expression dataPermissionExpression = null; + for (Table table : tables) { + // 构建每个表的权限 Expression 条件 + Expression expression = buildDataPermissionExpression(table); + if (expression == null) { + continue; + } + // 合并到 dataPermissionExpression 中 + dataPermissionExpression = dataPermissionExpression == null ? expression + : new AndExpression(dataPermissionExpression, expression); + } + + // 第二步,合并多个 Expression 条件 + if (dataPermissionExpression == null) { + return currentExpression; + } if (currentExpression == null) { - return equalsTo; + return dataPermissionExpression; } - // 如果表达式为 Or,则需要 (currentExpression) AND equalsTo + // ① 如果表达式为 Or,则需要 (currentExpression) AND dataPermissionExpression if (currentExpression instanceof OrExpression) { - return new AndExpression(new Parenthesis(currentExpression), equalsTo); + return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression); } - // 如果表达式为 And,则直接返回 currentExpression AND equalsTo - return new AndExpression(currentExpression, equalsTo); + // ② 如果表达式为 And,则直接返回 where AND dataPermissionExpression + return new AndExpression(currentExpression, dataPermissionExpression); } /** diff --git a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest2.java b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest2.java index 8c0772f1a..b8cad13cf 100644 --- a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest2.java +++ b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest2.java @@ -46,7 +46,7 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest @Override public Set getTableNames() { - return asSet("entity", "entity1", "entity2", "t1", "t2", // 支持 MyBatis Plus 的单元测试 + return asSet("entity", "entity1", "entity2", "entity3", "t1", "t2", "sys_dict_item", // 支持 MyBatis Plus 的单元测试 "t_user", "t_role"); // 满足自己的单元测试 } @@ -84,30 +84,30 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest @Test void delete() { assertSql("delete from entity where id = ?", - "DELETE FROM entity WHERE id = ? AND tenant_id = 1"); + "DELETE FROM entity WHERE id = ? AND entity.tenant_id = 1"); } @Test void update() { assertSql("update entity set name = ? where id = ?", - "UPDATE entity SET name = ? WHERE id = ? AND tenant_id = 1"); + "UPDATE entity SET name = ? WHERE id = ? AND entity.tenant_id = 1"); } @Test void selectSingle() { // 单表 assertSql("select * from entity where id = ?", - "SELECT * FROM entity WHERE id = ? AND tenant_id = 1"); + "SELECT * FROM entity WHERE id = ? AND entity.tenant_id = 1"); assertSql("select * from entity where id = ? or name = ?", - "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1"); + "SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1"); assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)", - "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1"); + "SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1"); /* not */ assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)", - "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1"); + "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND entity.tenant_id = 1"); } @Test @@ -167,10 +167,12 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)", "SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1"); + /* <= */ assertSql("SELECT * FROM entity e WHERE e.id <= (select e1.id from entity1 e1 where e1.id = ?)", "SELECT * FROM entity e WHERE e.id <= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1"); + /* <> */ assertSql("SELECT * FROM entity e WHERE e.id <> (select e1.id from entity1 e1 where e1.id = ?)", "SELECT * FROM entity e WHERE e.id <> (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1"); @@ -204,6 +206,14 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest "SELECT * FROM entity e " + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + + assertSql("SELECT * FROM entity e " + + "left join entity1 e1 on e1.id = e.id " + + "left join entity2 e2 on e1.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " + + "WHERE e.tenant_id = 1"); } @Test @@ -212,17 +222,125 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest assertSql("SELECT * FROM entity e " + "right join entity1 e1 on e1.id = e.id", "SELECT * FROM entity e " + - "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE e.tenant_id = 1"); + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " + + "WHERE e1.tenant_id = 1"); + + assertSql("SELECT * FROM with_as_1 e " + + "right join entity1 e1 on e1.id = e.id", + "SELECT * FROM with_as_1 e " + + "RIGHT JOIN entity1 e1 ON e1.id = e.id " + + "WHERE e1.tenant_id = 1"); assertSql("SELECT * FROM entity e " + "right join entity1 e1 on e1.id = e.id " + "WHERE e.id = ? OR e.name = ?", "SELECT * FROM entity e " + - "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " + + "WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1"); + + assertSql("SELECT * FROM entity e " + + "right join entity1 e1 on e1.id = e.id " + + "right join entity2 e2 on e1.id = e2.id ", + "SELECT * FROM entity e " + + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " + + "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " + + "WHERE e2.tenant_id = 1"); } + @Test + void selectMixJoin() { + assertSql("SELECT * FROM entity e " + + "right join entity1 e1 on e1.id = e.id " + + "left join entity2 e2 on e1.id = e2.id", + "SELECT * FROM entity e " + + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " + + "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " + + "WHERE e1.tenant_id = 1"); + + assertSql("SELECT * FROM entity e " + + "left join entity1 e1 on e1.id = e.id " + + "right join entity2 e2 on e1.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " + + "WHERE e2.tenant_id = 1"); + + assertSql("SELECT * FROM entity e " + + "left join entity1 e1 on e1.id = e.id " + + "inner join entity2 e2 on e1.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "INNER JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 AND e2.tenant_id = 1"); + } + + + @Test + void selectJoinSubSelect() { + assertSql("select * from (select * from entity) e1 " + + "left join entity2 e2 on e1.id = e2.id", + "SELECT * FROM (SELECT * FROM entity WHERE entity.tenant_id = 1) e1 " + + "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1"); + + assertSql("select * from entity1 e1 " + + "left join (select * from entity2) e2 " + + "on e1.id = e2.id", + "SELECT * FROM entity1 e1 " + + "LEFT JOIN (SELECT * FROM entity2 WHERE entity2.tenant_id = 1) e2 " + + "ON e1.id = e2.id " + + "WHERE e1.tenant_id = 1"); + } + + @Test + void selectSubJoin() { + + assertSql("select * FROM " + + "(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)", + "SELECT * FROM " + + "(entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " + + "WHERE e2.tenant_id = 1"); + + assertSql("select * FROM " + + "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id)", + "SELECT * FROM " + + "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " + + "WHERE e1.tenant_id = 1"); + + + assertSql("select * FROM " + + "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id) " + + "right join entity3 e3 on e1.id = e3.id", + "SELECT * FROM " + + "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " + + "RIGHT JOIN entity3 e3 ON e1.id = e3.id AND e1.tenant_id = 1 " + + "WHERE e3.tenant_id = 1"); + + + assertSql("select * FROM entity e " + + "LEFT JOIN (entity1 e1 right join entity2 e2 ON e1.id = e2.id) " + + "on e.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN (entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " + + "ON e.id = e2.id AND e2.tenant_id = 1 " + + "WHERE e.tenant_id = 1"); + + assertSql("select * FROM entity e " + + "LEFT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " + + "on e.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " + + "ON e.id = e2.id AND e1.tenant_id = 1 " + + "WHERE e.tenant_id = 1"); + + assertSql("select * FROM entity e " + + "RIGHT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " + + "on e.id = e2.id", + "SELECT * FROM entity e " + + "RIGHT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " + + "ON e.id = e2.id AND e.tenant_id = 1 " + + "WHERE e1.tenant_id = 1"); + } + + @Test void selectLeftJoinMultipleTrailingOn() { // 多个 on 尾缀的 @@ -256,51 +374,97 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest "inner join entity1 e1 on e1.id = e.id " + "WHERE e.id = ? OR e.name = ?", "SELECT * FROM entity e " + - "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " + + "WHERE e.id = ? OR e.name = ?"); assertSql("SELECT * FROM entity e " + "inner join entity1 e1 on e1.id = e.id " + "WHERE (e.id = ? OR e.name = ?)", "SELECT * FROM entity e " + - "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " + + "WHERE (e.id = ? OR e.name = ?)"); + + // 隐式内连接 + assertSql("SELECT * FROM entity,entity1 " + + "WHERE entity.id = entity1.id", + "SELECT * FROM entity, entity1 " + + "WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1"); + + // 隐式内连接 + assertSql("SELECT * FROM entity a, with_as_entity1 b " + + "WHERE a.id = b.id", + "SELECT * FROM entity a, with_as_entity1 b " + + "WHERE a.id = b.id AND a.tenant_id = 1"); + + assertSql("SELECT * FROM with_as_entity a, with_as_entity1 b " + + "WHERE a.id = b.id", + "SELECT * FROM with_as_entity a, with_as_entity1 b " + + "WHERE a.id = b.id"); + + // SubJoin with 隐式内连接 + assertSql("SELECT * FROM (entity,entity1) " + + "WHERE entity.id = entity1.id", + "SELECT * FROM (entity, entity1) " + + "WHERE entity.id = entity1.id " + + "AND entity.tenant_id = 1 AND entity1.tenant_id = 1"); + + assertSql("SELECT * FROM ((entity,entity1),entity2) " + + "WHERE entity.id = entity1.id and entity.id = entity2.id", + "SELECT * FROM ((entity, entity1), entity2) " + + "WHERE entity.id = entity1.id AND entity.id = entity2.id " + + "AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1"); + + assertSql("SELECT * FROM (entity,(entity1,entity2)) " + + "WHERE entity.id = entity1.id and entity.id = entity2.id", + "SELECT * FROM (entity, (entity1, entity2)) " + + "WHERE entity.id = entity1.id AND entity.id = entity2.id " + + "AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1"); + + // 沙雕的括号写法 + assertSql("SELECT * FROM (((entity,entity1))) " + + "WHERE entity.id = entity1.id", + "SELECT * FROM (((entity, entity1))) " + + "WHERE entity.id = entity1.id " + + "AND entity.tenant_id = 1 AND entity1.tenant_id = 1"); - // 垃圾 inner join todo -// assertSql("SELECT * FROM entity,entity1 " + -// "WHERE entity.id = entity1.id", -// "SELECT * FROM entity e " + -// "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + -// "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); } + @Test void selectWithAs() { assertSql("with with_as_A as (select * from entity) select * from with_as_A", - "WITH with_as_A AS (SELECT * FROM entity WHERE tenant_id = 1) SELECT * FROM with_as_A"); + "WITH with_as_A AS (SELECT * FROM entity WHERE entity.tenant_id = 1) SELECT * FROM with_as_A"); + } + + + @Test + void selectIgnoreTable() { + assertSql(" SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)", + "SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id AND item.tenant_id = 1 WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)"); } private void assertSql(String sql, String targetSql) { assertEquals(targetSql, interceptor.parserSingle(sql, null)); } + // ========== 额外的测试 ========== @Test public void testSelectSingle() { // 单表 assertSql("select * from t_user where id = ?", - "SELECT * FROM t_user WHERE id = ? AND tenant_id = 1 AND dept_id IN (10, 20)"); + "SELECT * FROM t_user WHERE id = ? AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)"); assertSql("select * from t_user where id = ? or name = ?", - "SELECT * FROM t_user WHERE (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)"); + "SELECT * FROM t_user WHERE (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)"); assertSql("SELECT * FROM t_user WHERE (id = ? OR name = ?)", - "SELECT * FROM t_user WHERE (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)"); + "SELECT * FROM t_user WHERE (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)"); /* not */ assertSql("SELECT * FROM t_user WHERE not (id = ? OR name = ?)", - "SELECT * FROM t_user WHERE NOT (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)"); + "SELECT * FROM t_user WHERE NOT (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)"); } @Test @@ -329,16 +493,16 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest "right join t_role e1 on e1.id = e.id " + "WHERE e.id = ? OR e.name = ?", "SELECT * FROM t_user e " + - "RIGHT JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)"); + "RIGHT JOIN t_role e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) " + + "WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1"); // 条件 e.id = ? OR e.name = ? 带括号 assertSql("SELECT * FROM t_user e " + "right join t_role e1 on e1.id = e.id " + "WHERE (e.id = ? OR e.name = ?)", "SELECT * FROM t_user e " + - "RIGHT JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)"); + "RIGHT JOIN t_role e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) " + + "WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1"); } @Test @@ -348,23 +512,22 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest "inner join entity1 e1 on e1.id = e.id " + "WHERE e.id = ? OR e.name = ?", "SELECT * FROM t_user e " + - "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)"); + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) AND e1.tenant_id = 1 " + + "WHERE e.id = ? OR e.name = ?"); // 条件 e.id = ? OR e.name = ? 带括号 assertSql("SELECT * FROM t_user e " + - "inner join t_role e1 on e1.id = e.id " + + "inner join entity1 e1 on e1.id = e.id " + "WHERE (e.id = ? OR e.name = ?)", "SELECT * FROM t_user e " + - "INNER JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)"); + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) AND e1.tenant_id = 1 " + + "WHERE (e.id = ? OR e.name = ?)"); - // 垃圾 inner join todo -// assertSql("SELECT * FROM entity,entity1 " + -// "WHERE entity.id = entity1.id", -// "SELECT * FROM entity e " + -// "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + -// "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + // 没有 On 的 inner join + assertSql("SELECT * FROM entity,entity1 " + + "WHERE entity.id = entity1.id", + "SELECT * FROM entity, entity1 " + + "WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1"); } } diff --git a/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java b/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java index 16d3b29b1..0b1b01b08 100644 --- a/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java +++ b/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java @@ -4,6 +4,7 @@ import cn.hutool.core.collection.CollectionUtil; import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.SortingField; import com.baomidou.mybatisplus.core.metadata.OrderItem; +import com.baomidou.mybatisplus.core.toolkit.StringPool; import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; @@ -78,7 +79,10 @@ public class MyBatisUtils { * @return Column 对象 */ public static Column buildColumn(String tableName, Alias tableAlias, String column) { - return new Column(tableAlias != null ? tableAlias.getName() + "." + column : column); + if (tableAlias != null) { + tableName = tableAlias.getName(); + } + return new Column(tableName + StringPool.DOT + column); } }