🐛修复数据权限,不支持隐式内连接的问题

This commit is contained in:
YunaiV 2022-09-10 21:37:16 +08:00
parent 61b0624a59
commit 38e88b02f5
3 changed files with 420 additions and 122 deletions

View File

@ -18,7 +18,6 @@ import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression; import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression; import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete; import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.*; import net.sf.jsqlparser.statement.select.*;
@ -37,7 +36,7 @@ import java.util.concurrent.ConcurrentHashMap;
/** /**
* 数据权限拦截器通过 {@link DataPermissionRule} 数据权限规则重写 SQL 的方式来实现 * 数据权限拦截器通过 {@link DataPermissionRule} 数据权限规则重写 SQL 的方式来实现
* 主要的 SQL 重写方法可见 {@link #builderExpression(Expression, Table)} 方法 * 主要的 SQL 重写方法可见 {@link #builderExpression(Expression, List)} 方法
* *
* 整体的代码实现上参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现 * 整体的代码实现上参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现
* 所以每次 MyBatis Plus 升级时需要 Review 下其具体的实现是否有变更 * 所以每次 MyBatis Plus 升级时需要 Review 下其具体的实现是否有变更
@ -53,8 +52,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
private final MappedStatementCache mappedStatementCache = new MappedStatementCache(); private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
@Override // SELECT 场景 @Override // SELECT 场景
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
// 获得 Mapper 对应的数据权限的规则 // 获得 Mapper 对应的数据权限的规则
List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId()); List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写则跳过 if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写则跳过
@ -68,12 +66,14 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
// 处理 SQL // 处理 SQL
mpBs.sql(parserSingle(mpBs.sql(), null)); mpBs.sql(parserSingle(mpBs.sql(), null));
} finally { } finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms); addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear(); ContextHolder.clear();
} }
} }
@Override // 只处理 UPDATE / DELETE 场景不处理 INSERT 场景 @Override // 只处理 UPDATE / DELETE 场景不处理 INSERT 场景因为 INSERT 不需要数据权限)
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) { public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh); PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement(); MappedStatement ms = mpSh.mappedStatement();
@ -92,7 +92,9 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
// 处理 SQL // 处理 SQL
mpBs.sql(parserMulti(mpBs.sql(), null)); mpBs.sql(parserMulti(mpBs.sql(), null));
} finally { } finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms); addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear(); ContextHolder.clear();
} }
} }
@ -107,24 +109,6 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
} }
} }
protected void processSelectBody(SelectBody selectBody) {
if (selectBody == null) {
return;
}
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
processSelectBody(withItem.getSubSelect().getSelectBody());
} else {
SetOperationList operationList = (SetOperationList) selectBody;
List<SelectBody> selectBodys = operationList.getSelects();
if (CollectionUtils.isNotEmpty(selectBodys)) {
selectBodys.forEach(this::processSelectBody);
}
}
}
/** /**
* update 语句处理 * update 语句处理
*/ */
@ -142,28 +126,77 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable())); delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
} }
// ========== TenantLineInnerInterceptor 一致的逻辑 ==========
protected void processSelectBody(SelectBody selectBody) {
if (selectBody == null) {
return;
}
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
processSelectBody(withItem.getSubSelect().getSelectBody());
} else {
SetOperationList operationList = (SetOperationList) selectBody;
List<SelectBody> selectBodyList = operationList.getSelects();
if (CollectionUtils.isNotEmpty(selectBodyList)) {
selectBodyList.forEach(this::processSelectBody);
}
}
}
/** /**
* 处理 PlainSelect * 处理 PlainSelect
*/ */
protected void processPlainSelect(PlainSelect plainSelect) { protected void processPlainSelect(PlainSelect plainSelect) {
FromItem fromItem = plainSelect.getFromItem();
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
plainSelect.setWhere(builderExpression(where, fromTable));
} else {
processFromItem(fromItem);
}
//#3087 github //#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems(); List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) { if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem); selectItems.forEach(this::processSelectItem);
} }
// 处理 where 中的子查询
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
// 处理 fromItem
FromItem fromItem = plainSelect.getFromItem();
List<Table> list = processFromItem(fromItem);
List<Table> mainTables = new ArrayList<>(list);
// 处理 join
List<Join> joins = plainSelect.getJoins(); List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) { if (CollectionUtils.isNotEmpty(joins)) {
processJoins(joins); mainTables = processJoins(mainTables, joins);
} }
// 当有 mainTable 进行 where 条件追加
if (CollectionUtils.isNotEmpty(mainTables)) {
plainSelect.setWhere(builderExpression(where, mainTables));
}
}
private List<Table> processFromItem(FromItem fromItem) {
// 处理括号括起来的表达式
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
List<Table> mainTables = new ArrayList<>();
// join 时的处理逻辑
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
mainTables.add(fromTable);
} else if (fromItem instanceof SubJoin) {
// SubJoin 类型则还需要添加上 where 条件
List<Table> tables = processSubJoin((SubJoin) fromItem);
mainTables.addAll(tables);
} else {
// 处理下 fromItem
processOtherFromItem(fromItem);
}
return mainTables;
} }
/** /**
@ -191,7 +224,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
return; return;
} }
if (where instanceof FromItem) { if (where instanceof FromItem) {
processFromItem((FromItem) where); processOtherFromItem((FromItem) where);
return; return;
} }
if (where.toString().indexOf("SELECT") > 0) { if (where.toString().indexOf("SELECT") > 0) {
@ -204,9 +237,9 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
} else if (where instanceof InExpression) { } else if (where instanceof InExpression) {
// in // in
InExpression expression = (InExpression) where; InExpression expression = (InExpression) where;
ItemsList itemsList = expression.getRightItemsList(); Expression inExpression = expression.getRightExpression();
if (itemsList instanceof SubSelect) { if (inExpression instanceof SubSelect) {
processSelectBody(((SubSelect) itemsList).getSelectBody()); processSelectBody(((SubSelect) inExpression).getSelectBody());
} }
} else if (where instanceof ExistsExpression) { } else if (where instanceof ExistsExpression) {
// exists // exists
@ -239,7 +272,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
* <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p> * <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
* <p> fixed gitee pulls/141</p> * <p> fixed gitee pulls/141</p>
* *
* @param function 函数 * @param function
*/ */
protected void processFunction(Function function) { protected void processFunction(Function function) {
ExpressionList parameters = function.getParameters(); ExpressionList parameters = function.getParameters();
@ -257,22 +290,19 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
/** /**
* 处理子查询等 * 处理子查询等
*/ */
protected void processFromItem(FromItem fromItem) { protected void processOtherFromItem(FromItem fromItem) {
if (fromItem instanceof SubJoin) { // 去除括号
SubJoin subJoin = (SubJoin) fromItem; while (fromItem instanceof ParenthesisFromItem) {
if (subJoin.getJoinList() != null) { fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
processJoins(subJoin.getJoinList()); }
}
if (subJoin.getLeft() != null) { if (fromItem instanceof SubSelect) {
processFromItem(subJoin.getLeft());
}
} else if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem; SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) { if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody()); processSelectBody(subSelect.getSelectBody());
} }
} else if (fromItem instanceof ValuesList) { } else if (fromItem instanceof ValuesList) {
logger.debug("Perform a subquery, if you do not give us feedback"); logger.debug("Perform a subQuery, if you do not give us feedback");
} else if (fromItem instanceof LateralSubSelect) { } else if (fromItem instanceof LateralSubSelect) {
LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem; LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
if (lateralSubSelect.getSubSelect() != null) { if (lateralSubSelect.getSubSelect() != null) {
@ -284,75 +314,176 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
} }
} }
/**
* 处理 sub join
*
* @param subJoin subJoin
* @return Table subJoin 中的主表
*/
private List<Table> processSubJoin(SubJoin subJoin) {
List<Table> mainTables = new ArrayList<>();
if (subJoin.getJoinList() != null) {
List<Table> list = processFromItem(subJoin.getLeft());
mainTables.addAll(list);
mainTables = processJoins(mainTables, subJoin.getJoinList());
}
return mainTables;
}
/** /**
* 处理 joins * 处理 joins
* *
* @param joins join 集合 * @param mainTables 可以为 null
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/ */
private void processJoins(List<Join> joins) { private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
// join 表达式中最终的主表
Table mainTable = null;
// 当前 join 的左表
Table leftTable = null;
if (mainTables == null) {
mainTables = new ArrayList<>();
} else if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}
//对于 on 表达式写在最后的 join需要记录下前面多个 on 的表名 //对于 on 表达式写在最后的 join需要记录下前面多个 on 的表名
Deque<Table> tables = new LinkedList<>(); Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) { for (Join join : joins) {
// 处理 on 表达式 // 处理 on 表达式
FromItem fromItem = join.getRightItem(); FromItem joinItem = join.getRightItem();
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem; // 获取当前 join 的表subJoint 可以看作是一张表
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
} else if (joinItem instanceof SubJoin) {
joinTables = processSubJoin((SubJoin) joinItem);
}
if (joinTables != null) {
// 如果是隐式内连接
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}
// 当前表是否忽略
Table joinTable = joinTables.get(0);
List<Table> onTables = null;
// 如果不要忽略且是右连接则记录下当前表
if (join.isRight()) {
mainTable = joinTable;
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isLeft()) {
onTables = Collections.singletonList(joinTable);
} else if (join.isInner()) {
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
}
mainTables = new ArrayList<>();
if (mainTable != null) {
mainTables.add(mainTable);
}
// 获取 join 尾缀的 on 表达式列表 // 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions(); Collection<Expression> originOnExpressions = join.getOnExpressions();
// 正常 join on 表达式只有一个立刻处理 // 正常 join on 表达式只有一个立刻处理
if (originOnExpressions.size() == 1) { if (originOnExpressions.size() == 1 && onTables != null) {
processJoin(join); List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
join.setOnExpressions(onExpressions);
leftTable = joinTable;
continue; continue;
} }
tables.push(fromTable); // 表名压栈忽略的表压入 null以便后续不处理
onTableDeque.push(onTables);
// 尾缀多个 on 表达式的时候统一处理 // 尾缀多个 on 表达式的时候统一处理
if (originOnExpressions.size() > 1) { if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>(); Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) { for (Expression originOnExpression : originOnExpressions) {
Table currentTable = tables.poll(); List<Table> currentTableList = onTableDeque.poll();
onExpressions.add(builderExpression(originOnExpression, currentTable)); if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTableList));
}
} }
join.setOnExpressions(onExpressions); join.setOnExpressions(onExpressions);
} }
leftTable = joinTable;
} else { } else {
// 处理右边连接的子表达式 processOtherFromItem(joinItem);
processFromItem(fromItem); leftTable = null;
} }
} }
return mainTables;
} }
// ========== TenantLineInnerInterceptor 存在差异的逻辑关键实现权限条件的拼接 ==========
/** /**
* 处理联接语句 * 处理条件
*
* @param currentExpression 当前 where 条件
* @param table 单个表
*/ */
protected void processJoin(Join join) { protected Expression builderExpression(Expression currentExpression, Table table) {
if (join.getRightItem() instanceof Table) { return this.builderExpression(currentExpression, Collections.singletonList(table));
Table fromTable = (Table) join.getRightItem();
Expression originOnExpression = CollUtil.getFirst(join.getOnExpressions());
originOnExpression = builderExpression(originOnExpression, fromTable);
join.setOnExpressions(CollUtil.newArrayList(originOnExpression));
}
} }
/** /**
* 处理条件 * 处理条件
*
* @param currentExpression 当前 where 条件
* @param tables 多个表
*/ */
protected Expression builderExpression(Expression currentExpression, Table table) { protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
// 获得 Table 对应的数据权限条件 // 没有表需要处理直接返回
Expression equalsTo = buildDataPermissionExpression(table); if (CollectionUtils.isEmpty(tables)) {
if (equalsTo == null) { // 如果没条件则返回 currentExpression 默认
return currentExpression; return currentExpression;
} }
// 表达式为空则直接返回 equalsTo // 第一步获得 Table 对应的数据权限条件
Expression dataPermissionExpression = null;
for (Table table : tables) {
// 构建每个表的权限 Expression 条件
Expression expression = buildDataPermissionExpression(table);
if (expression == null) {
continue;
}
// 合并到 dataPermissionExpression
dataPermissionExpression = dataPermissionExpression == null ? expression
: new AndExpression(dataPermissionExpression, expression);
}
// 第二步合并多个 Expression 条件
if (dataPermissionExpression == null) {
return currentExpression;
}
if (currentExpression == null) { if (currentExpression == null) {
return equalsTo; return dataPermissionExpression;
} }
// 如果表达式为 Or则需要 (currentExpression) AND equalsTo // 如果表达式为 Or则需要 (currentExpression) AND dataPermissionExpression
if (currentExpression instanceof OrExpression) { if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), equalsTo); return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression);
} }
// 如果表达式为 And则直接返回 currentExpression AND equalsTo // 如果表达式为 And则直接返回 where AND dataPermissionExpression
return new AndExpression(currentExpression, equalsTo); return new AndExpression(currentExpression, dataPermissionExpression);
} }
/** /**

View File

@ -46,7 +46,7 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
@Override @Override
public Set<String> getTableNames() { public Set<String> getTableNames() {
return asSet("entity", "entity1", "entity2", "t1", "t2", // 支持 MyBatis Plus 的单元测试 return asSet("entity", "entity1", "entity2", "entity3", "t1", "t2", "sys_dict_item", // 支持 MyBatis Plus 的单元测试
"t_user", "t_role"); // 满足自己的单元测试 "t_user", "t_role"); // 满足自己的单元测试
} }
@ -84,30 +84,30 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
@Test @Test
void delete() { void delete() {
assertSql("delete from entity where id = ?", assertSql("delete from entity where id = ?",
"DELETE FROM entity WHERE id = ? AND tenant_id = 1"); "DELETE FROM entity WHERE id = ? AND entity.tenant_id = 1");
} }
@Test @Test
void update() { void update() {
assertSql("update entity set name = ? where id = ?", assertSql("update entity set name = ? where id = ?",
"UPDATE entity SET name = ? WHERE id = ? AND tenant_id = 1"); "UPDATE entity SET name = ? WHERE id = ? AND entity.tenant_id = 1");
} }
@Test @Test
void selectSingle() { void selectSingle() {
// 单表 // 单表
assertSql("select * from entity where id = ?", assertSql("select * from entity where id = ?",
"SELECT * FROM entity WHERE id = ? AND tenant_id = 1"); "SELECT * FROM entity WHERE id = ? AND entity.tenant_id = 1");
assertSql("select * from entity where id = ? or name = ?", assertSql("select * from entity where id = ? or name = ?",
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1"); "SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1");
assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)", assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1"); "SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1");
/* not */ /* not */
assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)", assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
"SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1"); "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND entity.tenant_id = 1");
} }
@Test @Test
@ -167,10 +167,12 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)", assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1"); "SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* <= */ /* <= */
assertSql("SELECT * FROM entity e WHERE e.id <= (select e1.id from entity1 e1 where e1.id = ?)", assertSql("SELECT * FROM entity e WHERE e.id <= (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id <= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1"); "SELECT * FROM entity e WHERE e.id <= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* <> */ /* <> */
assertSql("SELECT * FROM entity e WHERE e.id <> (select e1.id from entity1 e1 where e1.id = ?)", assertSql("SELECT * FROM entity e WHERE e.id <> (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id <> (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1"); "SELECT * FROM entity e WHERE e.id <> (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
@ -204,6 +206,14 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"SELECT * FROM entity e " + "SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
} }
@Test @Test
@ -212,17 +222,125 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
assertSql("SELECT * FROM entity e " + assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id", "right join entity1 e1 on e1.id = e.id",
"SELECT * FROM entity e " + "SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"WHERE e.tenant_id = 1"); "WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM with_as_1 e " +
"right join entity1 e1 on e1.id = e.id",
"SELECT * FROM with_as_1 e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " + assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " + "right join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?", "WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " + "SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); "WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id ",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
} }
@Test
void selectMixJoin() {
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"inner join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"INNER JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 AND e2.tenant_id = 1");
}
@Test
void selectJoinSubSelect() {
assertSql("select * from (select * from entity) e1 " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM (SELECT * FROM entity WHERE entity.tenant_id = 1) e1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1");
assertSql("select * from entity1 e1 " +
"left join (select * from entity2) e2 " +
"on e1.id = e2.id",
"SELECT * FROM entity1 e1 " +
"LEFT JOIN (SELECT * FROM entity2 WHERE entity2.tenant_id = 1) e2 " +
"ON e1.id = e2.id " +
"WHERE e1.tenant_id = 1");
}
@Test
void selectSubJoin() {
assertSql("select * FROM " +
"(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)",
"SELECT * FROM " +
"(entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
"WHERE e2.tenant_id = 1");
assertSql("select * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id)",
"SELECT * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"WHERE e1.tenant_id = 1");
assertSql("select * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id) " +
"right join entity3 e3 on e1.id = e3.id",
"SELECT * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"RIGHT JOIN entity3 e3 ON e1.id = e3.id AND e1.tenant_id = 1 " +
"WHERE e3.tenant_id = 1");
assertSql("select * FROM entity e " +
"LEFT JOIN (entity1 e1 right join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN (entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
"ON e.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
assertSql("select * FROM entity e " +
"LEFT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"ON e.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
assertSql("select * FROM entity e " +
"RIGHT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"RIGHT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"ON e.id = e2.id AND e.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
}
@Test @Test
void selectLeftJoinMultipleTrailingOn() { void selectLeftJoinMultipleTrailingOn() {
// 多个 on 尾缀的 // 多个 on 尾缀的
@ -256,51 +374,97 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"inner join entity1 e1 on e1.id = e.id " + "inner join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?", "WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " + "SELECT * FROM entity e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); "WHERE e.id = ? OR e.name = ?");
assertSql("SELECT * FROM entity e " + assertSql("SELECT * FROM entity e " +
"inner join entity1 e1 on e1.id = e.id " + "inner join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)", "WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM entity e " + "SELECT * FROM entity e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); "WHERE (e.id = ? OR e.name = ?)");
// 隐式内连接
assertSql("SELECT * FROM entity,entity1 " +
"WHERE entity.id = entity1.id",
"SELECT * FROM entity, entity1 " +
"WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
// 隐式内连接
assertSql("SELECT * FROM entity a, with_as_entity1 b " +
"WHERE a.id = b.id",
"SELECT * FROM entity a, with_as_entity1 b " +
"WHERE a.id = b.id AND a.tenant_id = 1");
assertSql("SELECT * FROM with_as_entity a, with_as_entity1 b " +
"WHERE a.id = b.id",
"SELECT * FROM with_as_entity a, with_as_entity1 b " +
"WHERE a.id = b.id");
// SubJoin with 隐式内连接
assertSql("SELECT * FROM (entity,entity1) " +
"WHERE entity.id = entity1.id",
"SELECT * FROM (entity, entity1) " +
"WHERE entity.id = entity1.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
assertSql("SELECT * FROM ((entity,entity1),entity2) " +
"WHERE entity.id = entity1.id and entity.id = entity2.id",
"SELECT * FROM ((entity, entity1), entity2) " +
"WHERE entity.id = entity1.id AND entity.id = entity2.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1");
assertSql("SELECT * FROM (entity,(entity1,entity2)) " +
"WHERE entity.id = entity1.id and entity.id = entity2.id",
"SELECT * FROM (entity, (entity1, entity2)) " +
"WHERE entity.id = entity1.id AND entity.id = entity2.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1");
// 沙雕的括号写法
assertSql("SELECT * FROM (((entity,entity1))) " +
"WHERE entity.id = entity1.id",
"SELECT * FROM (((entity, entity1))) " +
"WHERE entity.id = entity1.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
// 垃圾 inner join todo
// assertSql("SELECT * FROM entity,entity1 " +
// "WHERE entity.id = entity1.id",
// "SELECT * FROM entity e " +
// "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
// "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
} }
@Test @Test
void selectWithAs() { void selectWithAs() {
assertSql("with with_as_A as (select * from entity) select * from with_as_A", assertSql("with with_as_A as (select * from entity) select * from with_as_A",
"WITH with_as_A AS (SELECT * FROM entity WHERE tenant_id = 1) SELECT * FROM with_as_A"); "WITH with_as_A AS (SELECT * FROM entity WHERE entity.tenant_id = 1) SELECT * FROM with_as_A");
}
@Test
void selectIgnoreTable() {
assertSql(" SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)",
"SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id AND item.tenant_id = 1 WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)");
} }
private void assertSql(String sql, String targetSql) { private void assertSql(String sql, String targetSql) {
assertEquals(targetSql, interceptor.parserSingle(sql, null)); assertEquals(targetSql, interceptor.parserSingle(sql, null));
} }
// ========== 额外的测试 ========== // ========== 额外的测试 ==========
@Test @Test
public void testSelectSingle() { public void testSelectSingle() {
// 单表 // 单表
assertSql("select * from t_user where id = ?", assertSql("select * from t_user where id = ?",
"SELECT * FROM t_user WHERE id = ? AND tenant_id = 1 AND dept_id IN (10, 20)"); "SELECT * FROM t_user WHERE id = ? AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
assertSql("select * from t_user where id = ? or name = ?", assertSql("select * from t_user where id = ? or name = ?",
"SELECT * FROM t_user WHERE (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)"); "SELECT * FROM t_user WHERE (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
assertSql("SELECT * FROM t_user WHERE (id = ? OR name = ?)", assertSql("SELECT * FROM t_user WHERE (id = ? OR name = ?)",
"SELECT * FROM t_user WHERE (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)"); "SELECT * FROM t_user WHERE (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
/* not */ /* not */
assertSql("SELECT * FROM t_user WHERE not (id = ? OR name = ?)", assertSql("SELECT * FROM t_user WHERE not (id = ? OR name = ?)",
"SELECT * FROM t_user WHERE NOT (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)"); "SELECT * FROM t_user WHERE NOT (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
} }
@Test @Test
@ -329,16 +493,16 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"right join t_role e1 on e1.id = e.id " + "right join t_role e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?", "WHERE e.id = ? OR e.name = ?",
"SELECT * FROM t_user e " + "SELECT * FROM t_user e " +
"RIGHT JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "RIGHT JOIN t_role e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)"); "WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
// 条件 e.id = ? OR e.name = ? 带括号 // 条件 e.id = ? OR e.name = ? 带括号
assertSql("SELECT * FROM t_user e " + assertSql("SELECT * FROM t_user e " +
"right join t_role e1 on e1.id = e.id " + "right join t_role e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)", "WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM t_user e " + "SELECT * FROM t_user e " +
"RIGHT JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "RIGHT JOIN t_role e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)"); "WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
} }
@Test @Test
@ -348,23 +512,22 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"inner join entity1 e1 on e1.id = e.id " + "inner join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?", "WHERE e.id = ? OR e.name = ?",
"SELECT * FROM t_user e " + "SELECT * FROM t_user e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)"); "WHERE e.id = ? OR e.name = ?");
// 条件 e.id = ? OR e.name = ? 带括号 // 条件 e.id = ? OR e.name = ? 带括号
assertSql("SELECT * FROM t_user e " + assertSql("SELECT * FROM t_user e " +
"inner join t_role e1 on e1.id = e.id " + "inner join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)", "WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM t_user e " + "SELECT * FROM t_user e " +
"INNER JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)"); "WHERE (e.id = ? OR e.name = ?)");
// 垃圾 inner join todo // 没有 On inner join
// assertSql("SELECT * FROM entity,entity1 " + assertSql("SELECT * FROM entity,entity1 " +
// "WHERE entity.id = entity1.id", "WHERE entity.id = entity1.id",
// "SELECT * FROM entity e " + "SELECT * FROM entity, entity1 " +
// "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
// "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
} }
} }

View File

@ -4,6 +4,7 @@ import cn.hutool.core.collection.CollectionUtil;
import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.SortingField; import cn.iocoder.yudao.framework.common.pojo.SortingField;
import com.baomidou.mybatisplus.core.metadata.OrderItem; import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
@ -78,7 +79,10 @@ public class MyBatisUtils {
* @return Column 对象 * @return Column 对象
*/ */
public static Column buildColumn(String tableName, Alias tableAlias, String column) { public static Column buildColumn(String tableName, Alias tableAlias, String column) {
return new Column(tableAlias != null ? tableAlias.getName() + "." + column : column); if (tableAlias != null) {
tableName = tableAlias.getName();
}
return new Column(tableName + StringPool.DOT + column);
} }
} }