|
@@ -40,12 +40,12 @@ import java.util.concurrent.ConcurrentHashMap;
|
|
|
* <p>SQL是影响系统性能最重要的因素,所以拦截掉垃圾SQL语句</p>
|
|
|
* <br>
|
|
|
* <p>拦截SQL类型的场景</p>
|
|
|
- * <p>1.必须使用到索引,包含left jion连接字段,符合索引最左原则</p>
|
|
|
+ * <p>1.必须使用到索引,包含left join连接字段,符合索引最左原则</p>
|
|
|
* <p>必须使用索引好处,</p>
|
|
|
* <p>1.1 如果因为动态SQL,bug导致update的where条件没有带上,全表更新上万条数据</p>
|
|
|
* <p>1.2 如果检查到使用了索引,SQL性能基本不会太差</p>
|
|
|
* <br>
|
|
|
- * <p>2.SQL尽量单表执行,有查询left jion的语句,必须在注释里面允许该SQL运行,否则会被拦截,有left jion的语句,如果不能拆成单表执行的SQL,请leader商量在做</p>
|
|
|
+ * <p>2.SQL尽量单表执行,有查询left join的语句,必须在注释里面允许该SQL运行,否则会被拦截,有left join的语句,如果不能拆成单表执行的SQL,请leader商量在做</p>
|
|
|
* <p>https://gaoxianglong.github.io/shark</p>
|
|
|
* <p>SQL尽量单表执行的好处</p>
|
|
|
* <p>2.1 查询条件简单、易于开理解和维护;</p>
|
|
@@ -73,45 +73,61 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
/**
|
|
|
* 缓存表的索引信息
|
|
|
*/
|
|
|
- private final Map<String, List<IndexInfo>> indexInfoMap = new ConcurrentHashMap<>();
|
|
|
+ private final Map<String, IndexInfo> indexInfoCacheMap = new ConcurrentHashMap<>();
|
|
|
|
|
|
+ private final int indexValidColumnIndex = 8;
|
|
|
+ private final int indexDBNameColumnIndex = 1;
|
|
|
+ private final int indexTableNameColumnIndex = 3;
|
|
|
+ private final int indexColumnNameColumnIndex = 9;
|
|
|
+ private boolean failFast = false;
|
|
|
+ private ThreadLocal<String> message = new ThreadLocal<>();
|
|
|
+
|
|
|
+ private void set(String message) {
|
|
|
+ this.message.set(message);
|
|
|
+ }
|
|
|
/**
|
|
|
* 验证expression对象是不是 or、not等等
|
|
|
*
|
|
|
* @param expression ignore
|
|
|
*/
|
|
|
- private void validExpression(Expression expression) {
|
|
|
+ private boolean validExpression(Expression expression) {
|
|
|
//where条件使用了 or 关键字
|
|
|
if (expression instanceof OrExpression) {
|
|
|
- OrExpression orExpression = (OrExpression) expression;
|
|
|
- throw new SqlAnalysisException("非法SQL,where条件中不能使用[or]关键字,错误or信息:" + orExpression.toString());
|
|
|
+ OrExpression or = (OrExpression) expression;
|
|
|
+ set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", or.getStringExpression(), or.toString()));
|
|
|
+ return true;
|
|
|
} else if (expression instanceof NotEqualsTo) {
|
|
|
- NotEqualsTo notEqualsTo = (NotEqualsTo) expression;
|
|
|
- throw new SqlAnalysisException("非法SQL,where条件中不能使用[!=]关键字,错误!=信息:" + notEqualsTo.toString());
|
|
|
+ NotEqualsTo notEqTo = (NotEqualsTo) expression;
|
|
|
+ set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", notEqTo.getStringExpression(), notEqTo.toString()));
|
|
|
+ return true;
|
|
|
} else if (expression instanceof BinaryExpression) {
|
|
|
- BinaryExpression binaryExpression = (BinaryExpression) expression;
|
|
|
- if (binaryExpression.isNot()) {
|
|
|
- throw new SqlAnalysisException("非法SQL,where条件中不能使用[not]关键字,错误not信息:" + binaryExpression.toString());
|
|
|
+ BinaryExpression binary = (BinaryExpression) expression;
|
|
|
+ if (binary.isNot()) {
|
|
|
+ set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", binary.getStringExpression(), binary.toString()));
|
|
|
+ return true;
|
|
|
}
|
|
|
- Expression leftExpression = binaryExpression.getLeftExpression();
|
|
|
+ Expression leftExpression = binary.getLeftExpression();
|
|
|
if (leftExpression instanceof Function) {
|
|
|
Function function = (Function) leftExpression;
|
|
|
- throw new SqlAnalysisException("非法SQL,where条件中不能使用数据库函数,错误函数信息:" + function.toString());
|
|
|
+ set(String.format("Should not use db function in WHERE condition. illegal sql: %s", function.toString()));
|
|
|
+ return true;
|
|
|
}
|
|
|
- Expression rightExpression = binaryExpression.getRightExpression();
|
|
|
+ Expression rightExpression = binary.getRightExpression();
|
|
|
if (rightExpression instanceof SubSelect) {
|
|
|
SubSelect subSelect = (SubSelect) rightExpression;
|
|
|
- throw new SqlAnalysisException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
|
|
|
+ set(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", subSelect.toString()));
|
|
|
+ return true;
|
|
|
}
|
|
|
} else if (expression instanceof InExpression) {
|
|
|
- InExpression inExpression = (InExpression) expression;
|
|
|
- ItemsList rightItemsList = inExpression.getRightItemsList();
|
|
|
+ InExpression in = (InExpression) expression;
|
|
|
+ ItemsList rightItemsList = in.getRightItemsList();
|
|
|
if (rightItemsList instanceof SubSelect) {
|
|
|
SubSelect subSelect = (SubSelect) rightItemsList;
|
|
|
- throw new SqlAnalysisException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
|
|
|
+ set(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", subSelect.toString()));
|
|
|
+ return true;
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
+ return false;
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -122,8 +138,8 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
* @param connection ignore
|
|
|
*/
|
|
|
private void validJoins(List<Join> joins, Table table, Connection connection) {
|
|
|
- //允许执行join,验证jion是否使用索引等等
|
|
|
- if (joins != null) {
|
|
|
+ //允许执行join,验证join是否使用索引等等
|
|
|
+ if (SqlUtils.notEmpty(joins)) {
|
|
|
for (Join join : joins) {
|
|
|
Table rightTable = (Table) join.getRightItem();
|
|
|
Expression expression = join.getOnExpression();
|
|
@@ -140,9 +156,6 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
* @param connection ignore
|
|
|
*/
|
|
|
private void validUseIndex(Table table, String columnName, Connection connection) {
|
|
|
- //是否使用索引
|
|
|
- boolean useIndexFlag = false;
|
|
|
-
|
|
|
String tableInfo = table.getName();
|
|
|
//表存在的索引
|
|
|
String dbName = null;
|
|
@@ -154,13 +167,9 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
dbName = tableArray[0];
|
|
|
tableName = tableArray[1];
|
|
|
}
|
|
|
- List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
|
|
|
- for (IndexInfo indexInfo : indexInfos) {
|
|
|
- if (Objects.equals(columnName, indexInfo.getColumnName())) {
|
|
|
- useIndexFlag = true;
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
+ String cacheKey = String.format("%s::%s::%s", dbName, tableName, columnName);
|
|
|
+ boolean useIndexFlag = getIndexInfos(cacheKey, dbName,tableName, connection);
|
|
|
+
|
|
|
if (!useIndexFlag) {
|
|
|
throw new SqlAnalysisException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
|
|
|
}
|
|
@@ -185,13 +194,18 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
* @param joinTable ignore
|
|
|
* @param connection ignore
|
|
|
*/
|
|
|
- private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
|
|
|
- validExpression(expression);
|
|
|
+ private boolean validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
|
|
|
+ boolean valid = validExpression(expression);
|
|
|
+ if(valid){
|
|
|
+ return valid;
|
|
|
+ }
|
|
|
if (expression instanceof BinaryExpression) {
|
|
|
//获得左边表达式
|
|
|
Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
|
|
|
- validExpression(leftExpression);
|
|
|
-
|
|
|
+ boolean left = validExpression(leftExpression);
|
|
|
+ if(left){
|
|
|
+ return left;
|
|
|
+ }
|
|
|
//如果左边表达式为Column对象,则直接获得列名
|
|
|
if (leftExpression instanceof Column) {
|
|
|
Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
|
|
@@ -217,6 +231,7 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
|
|
|
validExpression(rightExpression);
|
|
|
}
|
|
|
+ return true;
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -227,7 +242,7 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
* @param conn ignore
|
|
|
* @return ignore
|
|
|
*/
|
|
|
- public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
|
|
|
+ public boolean getIndexInfos(String dbName, String tableName, Connection conn) {
|
|
|
return getIndexInfos(null, dbName, tableName, conn);
|
|
|
}
|
|
|
|
|
@@ -240,35 +255,46 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
* @param conn ignore
|
|
|
* @return ignore
|
|
|
*/
|
|
|
- public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
|
|
|
- List<IndexInfo> indexInfos = null;
|
|
|
- if (SqlUtils.notEmpty(key)) {
|
|
|
- indexInfos = indexInfoMap.get(key);
|
|
|
- }
|
|
|
- if (SqlUtils.isEmpty(indexInfos)) {
|
|
|
- ResultSet rs;
|
|
|
+ public boolean getIndexInfos(String key, String dbName, String tableName, Connection conn) {
|
|
|
+ //if the indexInfoCacheMap is empty, must get index info from the connection.
|
|
|
+ if(indexInfoCacheMap.isEmpty()){
|
|
|
try {
|
|
|
DatabaseMetaData metadata = conn.getMetaData();
|
|
|
- rs = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
|
|
|
- indexInfos = new ArrayList<>();
|
|
|
- while (rs.next()) {
|
|
|
+ ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
|
|
|
+ while (resultSet.next()) {
|
|
|
//索引中的列序列号等于1,才有效
|
|
|
- if (Objects.equals(rs.getString(8), "1")) {
|
|
|
+ log.info("resultSet: {}", resultSet);
|
|
|
+ if ("1".equals(resultSet.getString(indexValidColumnIndex))) {
|
|
|
IndexInfo indexInfo = new IndexInfo();
|
|
|
- indexInfo.setDbName(rs.getString(1));
|
|
|
- indexInfo.setTableName(rs.getString(3));
|
|
|
- indexInfo.setColumnName(rs.getString(9));
|
|
|
- indexInfos.add(indexInfo);
|
|
|
+ indexInfo.setDbName(resultSet.getString(indexDBNameColumnIndex));
|
|
|
+ indexInfo.setTableName(resultSet.getString(indexTableNameColumnIndex));
|
|
|
+ indexInfo.setColumnName(resultSet.getString(indexColumnNameColumnIndex));
|
|
|
+ //index
|
|
|
+ indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
|
|
|
}
|
|
|
}
|
|
|
- if (SqlUtils.notEmpty(key)) {
|
|
|
- indexInfoMap.put(key, indexInfos);
|
|
|
- }
|
|
|
} catch (SQLException e) {
|
|
|
- e.printStackTrace();
|
|
|
+ log.error("get index info from connection fail.");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ try {
|
|
|
+ DatabaseMetaData metadata = conn.getMetaData();
|
|
|
+ ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
|
|
|
+ while (resultSet.next()) {
|
|
|
+ //索引中的列序列号等于1,才有效
|
|
|
+ if (Objects.equals(resultSet.getString(8), "1")) {
|
|
|
+ IndexInfo indexInfo = new IndexInfo();
|
|
|
+ indexInfo.setDbName(resultSet.getString(1));
|
|
|
+ indexInfo.setTableName(resultSet.getString(3));
|
|
|
+ indexInfo.setColumnName(resultSet.getString(9));
|
|
|
+ //index
|
|
|
+ indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
|
|
|
+ }
|
|
|
}
|
|
|
+ } catch (SQLException e) {
|
|
|
+ log.error("get index info from connection fail.");
|
|
|
}
|
|
|
- return indexInfos;
|
|
|
+ return true;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -282,10 +308,11 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
}
|
|
|
BoundSql boundSql = SqlUtils.getBoundSql(metaObject);
|
|
|
String originalSql = boundSql.getSql();
|
|
|
- log.info("检查SQL是否合规,SQL: {}", originalSql);
|
|
|
+ log.info("检查SQL ID: {}", mappedStatement.getId());
|
|
|
+ log.info("检查SQL是否合规,SQL: \n{}", originalSql);
|
|
|
String md5Base64 = SqlUtils.md5Base64(originalSql);
|
|
|
+ log.info("md5Base64: {}", md5Base64);
|
|
|
if (cacheValidResult.contains(md5Base64)) {
|
|
|
- log.info("该SQL已验证,无需再次验证, SQL: {}", originalSql);
|
|
|
return invocation.proceed();
|
|
|
}
|
|
|
Statement statement = CCJSqlParserUtil.parse(originalSql);
|
|
@@ -334,7 +361,10 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
|
|
|
@Override
|
|
|
public void setProperties(Properties prop) {
|
|
|
-
|
|
|
+ String failFast = prop.getProperty("failFast");
|
|
|
+ if (SqlUtils.isBoolean(failFast)) {
|
|
|
+ this.failFast = Boolean.valueOf(failFast);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -342,17 +372,15 @@ public class IllegalSqlInterceptor implements Interceptor {
|
|
|
*/
|
|
|
@Data
|
|
|
private class IndexInfo {
|
|
|
- /**
|
|
|
- * the name of the db.
|
|
|
- */
|
|
|
+ //the name of the db.
|
|
|
private String dbName;
|
|
|
- /**
|
|
|
- * the name of the table.
|
|
|
- */
|
|
|
+ //the name of the table.
|
|
|
private String tableName;
|
|
|
- /**
|
|
|
- * the name of the column.
|
|
|
- */
|
|
|
+ //the name of the column.
|
|
|
private String columnName;
|
|
|
+
|
|
|
+ public String cacheKey(){
|
|
|
+ return String.format("%s::%s::%s",getDbName(), getTableName(), getColumnName());
|
|
|
+ }
|
|
|
}
|
|
|
}
|