package cn.com.ty.lift.common.sql; import com.baomidou.mybatisplus.core.parser.SqlParserHelper; import lombok.Data; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.relational.InExpression; import net.sf.jsqlparser.expression.operators.relational.ItemsList; import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.Statement; import net.sf.jsqlparser.statement.delete.Delete; import net.sf.jsqlparser.statement.select.Join; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SubSelect; import net.sf.jsqlparser.statement.update.Update; import org.apache.ibatis.executor.statement.StatementHandler; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.SqlCommandType; import org.apache.ibatis.plugin.*; import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.SystemMetaObject; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.ResultSet; import java.sql.SQLException; import java.util.*; import java.util.concurrent.ConcurrentHashMap; /** * 由于开发人员水平参差不齐,即使订了开发规范很多人也不遵守 *

SQL是影响系统性能最重要的因素,所以拦截掉垃圾SQL语句

*
*

拦截SQL类型的场景

*

1.必须使用到索引,包含left join连接字段,符合索引最左原则

*

必须使用索引的好处:

*

1.1 如果动态SQL的bug导致update的where条件没有带上,会全表更新上万条数据;强制使用索引可以拦截这类语句

*

1.2 如果检查到使用了索引,SQL性能基本不会太差

*
*

2.SQL尽量单表执行。带有left join的查询语句,必须在注释里面允许该SQL运行,否则会被拦截;如果带有left join的语句不能拆成单表执行的SQL,请与leader商量后再做

*

https://gaoxianglong.github.io/shark

*

SQL尽量单表执行的好处

*

2.1 查询条件简单、易于理解和维护;

*

2.2 扩展性极强;(可为分库分表做准备)

*

2.3 缓存利用率高;

*

3.where条件中在字段上使用函数

*
*

4.where条件为空

*

5.where条件使用了 !=

*

6.where条件使用了 not 关键字

*

7.where条件使用了 or 关键字

*

8.where条件使用了子查询

*
 * <p>Summary (English): before MyBatis prepares a non-INSERT statement, the SQL is parsed with
 * JSqlParser. If no WHERE clause is found (this includes statements that are not a SELECT,
 * UPDATE or DELETE), a {@link SqlAnalysisException} is always thrown.
 *
 * <p>For UPDATE, DELETE and SELECT-from-a-table statements, the following rules are enforced
 * according to the {@code failFast} plugin property:
 * <ul>
 *   <li>no OR, !=, NOT, database function on the left-hand side, or sub-query in the checked
 *       conditions;</li>
 *   <li>the leftmost WHERE column, and the columns of the leftmost comparison in each JOIN ON, must
 *       be the leading column of an index.</li>
 * </ul>
 * With {@code failFast=true} the first violation throws {@link SqlAnalysisException}. Otherwise
 * (the default) violations are logged as warnings and the statement still runs.
 *
 * <p>The following are skipped: statements marked {@code @SqlParser(filter = true)}, UNION-style
 * selects, and selects whose FROM item is not a plain table. SQL text that has been checked once
 * is remembered by its {@code SqlUtils.md5Base64} digest and is not checked again.
 *
 * @author willenfoo
 * @since 2018-03-22
 */
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class IllegalSqlInterceptor implements Interceptor {

    /** 1-based position of ORDINAL_POSITION (column sequence within the index) in {@link DatabaseMetaData#getIndexInfo}. */
    private static final int INDEX_ORDINAL_POSITION_COLUMN = 8;

    /** 1-based position of COLUMN_NAME in the result set of {@link DatabaseMetaData#getIndexInfo}. */
    private static final int INDEX_COLUMN_NAME_COLUMN = 9;

    /**
     * Digests of SQL texts that have already been checked. One interceptor instance serves every
     * thread that prepares statements, so the set must be safe for concurrent mutation.
     */
    private final Set<String> cacheValidResult = ConcurrentHashMap.newKeySet();

    /** Leading index columns, keyed by {@link #indexKey(String, String, String)}. */
    private final Map<String, IndexInfo> indexInfoCacheMap = new ConcurrentHashMap<>();

    /** Normalized {@code db::table} keys of tables whose index metadata has already been loaded. */
    private final Set<String> loadedTables = ConcurrentHashMap.newKeySet();

    /** {@code true}: rule violations throw; {@code false}: they are only logged. Set from the "failFast" property. */
    private boolean failFast = false;

    /**
     * Reports a rule violation according to {@link #failFast}.
     *
     * @param message description of the violation
     * @throws SqlAnalysisException when {@code failFast} is enabled
     */
    private void report(String message) {
        if (failFast) {
            throw new SqlAnalysisException(message);
        }
        log.warn(message);
    }

    /**
     * Checks one condition expression for forbidden constructs.
     *
     * <p>The forbidden constructs are: OR, !=, a negated binary expression, a function on the
     * left-hand side, a sub-query on the right-hand side, and {@code IN (sub-query)}.
     *
     * @param expression the expression to inspect; {@code null} is accepted and is never a violation
     * @return {@code true} if a violation was found (and reported)
     */
    private boolean validExpression(Expression expression) {
        if (expression instanceof OrExpression) {
            OrExpression or = (OrExpression) expression;
            report(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", or.getStringExpression(), or));
            return true;
        }
        // NotEqualsTo is itself a BinaryExpression, so it must be tested before the generic branch.
        if (expression instanceof NotEqualsTo) {
            NotEqualsTo notEqualsTo = (NotEqualsTo) expression;
            report(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", notEqualsTo.getStringExpression(), notEqualsTo));
            return true;
        }
        if (expression instanceof BinaryExpression) {
            BinaryExpression binary = (BinaryExpression) expression;
            if (binary.isNot()) {
                report(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", binary.getStringExpression(), binary));
                return true;
            }
            Expression left = binary.getLeftExpression();
            if (left instanceof Function) {
                report(String.format("Should not use db function in WHERE condition. illegal sql: %s", left));
                return true;
            }
            Expression right = binary.getRightExpression();
            if (right instanceof SubSelect) {
                report(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", right));
                return true;
            }
            return false;
        }
        if (expression instanceof InExpression) {
            ItemsList rightItems = ((InExpression) expression).getRightItemsList();
            if (rightItems instanceof SubSelect) {
                report(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", rightItems));
                return true;
            }
        }
        return false;
    }

    /**
     * Applies the WHERE rules to the ON clause of every join against a plain table.
     *
     * <p>Joins whose right-hand item is not a {@link Table} (for example a sub-query) cannot be
     * matched against index metadata, so they are skipped.
     *
     * @param joins      joins of the statement; may be {@code null} or empty
     * @param table      the statement's main table
     * @param connection connection used to read index metadata
     */
    private void validJoins(List<Join> joins, Table table, Connection connection) {
        if (!SqlUtils.notEmpty(joins)) {
            return;
        }
        for (Join join : joins) {
            if (join.getRightItem() instanceof Table) {
                validWhere(join.getOnExpression(), table, (Table) join.getRightItem(), connection);
            }
        }
    }

    /**
     * Reports a violation unless {@code columnName} is the leading column of an index on {@code table}.
     *
     * @param table      table the column belongs to, possibly qualified with a database/schema
     * @param columnName column used in the condition
     * @param connection connection used to read index metadata
     */
    private void validUseIndex(Table table, String columnName, Connection connection) {
        String dbName = stripQuotes(table.getSchemaName());
        String tableName = stripQuotes(table.getName());
        // Handle a "db.table" name that reaches us un-split.
        int dot = tableName.lastIndexOf('.');
        if (dot >= 0) {
            if (dbName == null) {
                dbName = tableName.substring(0, dot);
            }
            tableName = tableName.substring(dot + 1);
        }
        String cacheKey = indexKey(dbName, tableName, columnName);
        if (!getIndexInfos(cacheKey, dbName, tableName, connection)) {
            report("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
        }
    }

    /**
     * Validates a WHERE clause of a single-table statement.
     *
     * @param expression the WHERE expression
     * @param table      the statement's table
     * @param connection connection used to read index metadata
     */
    private void validWhere(Expression expression, Table table, Connection connection) {
        validWhere(expression, table, null, connection);
    }

    /**
     * Validates a WHERE or JOIN ON expression.
     *
     * <p>The expression is checked for forbidden constructs. Its leftmost column comparison must
     * hit the leading column of an index. For a column-to-column comparison inside a join, both
     * columns are checked against their own tables.
     *
     * @param expression the expression to check; may be {@code null}
     * @param table      the statement's main table
     * @param joinTable  the joined table when checking an ON clause, otherwise {@code null}
     * @param connection connection used to read index metadata
     */
    private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
        if (validExpression(expression) || !(expression instanceof BinaryExpression)) {
            return;
        }
        BinaryExpression binary = (BinaryExpression) expression;
        Expression left = binary.getLeftExpression();
        Expression right = binary.getRightExpression();
        if (validExpression(left)) {
            return;
        }
        if (left instanceof Column) {
            Column leftColumn = (Column) left;
            if (joinTable != null && right instanceof Column) {
                Column rightColumn = (Column) right;
                // Attribute each side of "a.x = b.y" to the table its qualifier names.
                if (belongsTo(rightColumn, table)) {
                    validUseIndex(table, rightColumn.getColumnName(), connection);
                    validUseIndex(joinTable, leftColumn.getColumnName(), connection);
                } else {
                    validUseIndex(joinTable, rightColumn.getColumnName(), connection);
                    validUseIndex(table, leftColumn.getColumnName(), connection);
                }
            } else {
                validUseIndex(table, leftColumn.getColumnName(), connection);
            }
        } else if (left instanceof BinaryExpression) {
            // e.g. "a = 1 AND b = 2": descend the left branch to reach the leftmost column.
            validWhere(left, table, joinTable, connection);
        }
        validExpression(right);
    }

    /**
     * Tells whether a column reference is qualified with the table's alias.
     *
     * <p>When the table has no alias, the table's name is used instead.
     *
     * @param column a column reference, possibly unqualified
     * @param table  the candidate owning table
     * @return {@code false} for unqualified columns
     */
    private static boolean belongsTo(Column column, Table table) {
        Table owner = column.getTable();
        if (owner == null || owner.getName() == null) {
            return false;
        }
        String expected = table.getAlias() != null ? table.getAlias().getName() : table.getName();
        return Objects.equals(owner.getName(), expected);
    }

    /**
     * Loads the leading index columns of a table into the cache.
     *
     * <p>Metadata is read at most once per table while reads succeed.
     *
     * @param dbName    database qualifier as written in the SQL, or {@code null} for the connection's current catalog
     * @param tableName table name as written in the SQL
     * @param conn      connection used to read {@link DatabaseMetaData}
     * @return always {@code true}; a failure to read metadata is logged
     */
    public boolean getIndexInfos(String dbName, String tableName, Connection conn) {
        return getIndexInfos(null, dbName, tableName, conn);
    }

    /**
     * Tells whether {@code key} names the leading column of an index.
     *
     * <p>The table's index metadata is loaded first if it is not cached yet.
     *
     * @param key       cache key {@code db::table::column}. It is compared case-insensitively, and
     *                  back-tick / double-quote quoting is ignored. With {@code null}, metadata is
     *                  only loaded.
     * @param dbName    database qualifier as written in the SQL, or {@code null} for the connection's current catalog
     * @param tableName table name as written in the SQL
     * @param conn      connection used to read {@link DatabaseMetaData}
     * @return {@code false} only when the table's metadata was read and the column is not a
     *         leading index column. It returns {@code true} if the column is indexed, if
     *         {@code key} is {@code null}, or if the metadata could not be read: the rule cannot
     *         be verified then, so the SQL is not blocked.
     */
    public boolean getIndexInfos(String key, String dbName, String tableName, Connection conn) {
        if (!loadIndexInfos(dbName, tableName, conn)) {
            return true;
        }
        return key == null || indexInfoCacheMap.containsKey(normalize(key));
    }

    /**
     * Reads the table's index metadata into {@link #indexInfoCacheMap} once.
     *
     * @return {@code false} if the metadata could not be read
     */
    private boolean loadIndexInfos(String dbName, String tableName, Connection conn) {
        String db = stripQuotes(dbName);
        String tbl = stripQuotes(tableName);
        String tableKey = normalize(db) + "::" + normalize(tbl);
        if (loadedTables.contains(tableKey)) {
            return true;
        }
        try {
            // Per the JDBC contract a null catalog does not narrow the search, so an unqualified
            // table is looked up in the connection's current catalog.
            String catalog = db != null ? db : conn.getCatalog();
            try (ResultSet resultSet = conn.getMetaData().getIndexInfo(catalog, db, tbl, false, true)) {
                while (resultSet.next()) {
                    // Only the first column of an index is usable on its own (leftmost-prefix rule).
                    if ("1".equals(resultSet.getString(INDEX_ORDINAL_POSITION_COLUMN))) {
                        IndexInfo indexInfo = new IndexInfo();
                        indexInfo.setDbName(db);
                        indexInfo.setTableName(tbl);
                        indexInfo.setColumnName(resultSet.getString(INDEX_COLUMN_NAME_COLUMN));
                        indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
                    }
                }
            }
        } catch (SQLException e) {
            log.error("get index info from connection fail.", e);
            return false;
        }
        loadedTables.add(tableKey);
        return true;
    }

    /** Builds the normalized {@code db::table::column} cache key; a {@code null} part becomes "null". */
    private static String indexKey(String dbName, String tableName, String columnName) {
        return normalize(dbName) + "::" + normalize(tableName) + "::" + normalize(columnName);
    }

    /** Lower-cases and unquotes an identifier (or a whole key) for comparison; {@code null} becomes "null". */
    private static String normalize(String identifier) {
        return identifier == null ? "null" : stripQuotes(identifier).toLowerCase(Locale.ROOT);
    }

    /** Removes MySQL back-tick and ANSI double-quote identifier quoting; {@code null} stays {@code null}. */
    private static String stripQuotes(String identifier) {
        return identifier == null ? null : identifier.replace("`", "").replace("\"", "");
    }

    /**
     * Validates the SQL about to be prepared, then lets the invocation proceed.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of {@link Invocation#proceed()}
     * @throws SqlAnalysisException if the statement has no WHERE clause, or if a rule is violated while {@code failFast} is on
     * @throws Throwable            anything thrown by parsing or by the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = SqlUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // INSERT statements and statements marked @SqlParser(filter = true) are not validated.
        MappedStatement mappedStatement = SqlUtils.getMappedStatement(metaObject);
        if (SqlCommandType.INSERT.equals(mappedStatement.getSqlCommandType())
                || SqlParserHelper.getSqlParserInfo(metaObject)) {
            return invocation.proceed();
        }
        BoundSql boundSql = SqlUtils.getBoundSql(metaObject);
        String originalSql = boundSql.getSql();
        log.info("检查SQL ID: {}", mappedStatement.getId());
        log.info("检查SQL是否合规,SQL: \n{}", originalSql);
        String md5Base64 = SqlUtils.md5Base64(originalSql);
        log.debug("md5Base64: {}", md5Base64);
        if (cacheValidResult.contains(md5Base64)) {
            return invocation.proceed();
        }

        Statement statement = CCJSqlParserUtil.parse(originalSql);
        Expression where = null;
        Table table = null;
        List<Join> joins = null;
        if (statement instanceof Select) {
            Object selectBody = ((Select) statement).getSelectBody();
            if (!(selectBody instanceof PlainSelect)) {
                // UNION and other set operations are not analysed.
                return invocation.proceed();
            }
            PlainSelect plainSelect = (PlainSelect) selectBody;
            where = plainSelect.getWhere();
            // The checked table is the FROM table; a sub-query in FROM cannot be matched against index metadata.
            if (plainSelect.getFromItem() instanceof Table) {
                table = (Table) plainSelect.getFromItem();
            }
            joins = plainSelect.getJoins();
        } else if (statement instanceof Update) {
            Update update = (Update) statement;
            where = update.getWhere();
            table = update.getTables().get(0);
            joins = update.getJoins();
        } else if (statement instanceof Delete) {
            Delete delete = (Delete) statement;
            where = delete.getWhere();
            table = delete.getTable();
            joins = delete.getJoins();
        }
        // A missing WHERE clause is always rejected, regardless of failFast.
        if (SqlUtils.isNull(where)) {
            throw new SqlAnalysisException("非法SQL,必须要有where条件");
        }
        if (SqlUtils.isNull(table)) {
            return invocation.proceed();
        }
        Connection connection = (Connection) invocation.getArgs()[0];
        validWhere(where, table, connection);
        validJoins(joins, table, connection);
        // Reached only when no violation threw. In warn-only mode this means each SQL text is reported once.
        cacheValidResult.add(md5Base64);
        return invocation.proceed();
    }

    /** Wraps only {@link StatementHandler} targets; everything else is returned unchanged. */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /** Reads the optional {@code failFast} property; unparseable values leave the default ({@code false}). */
    @Override
    public void setProperties(Properties prop) {
        String failFast = prop.getProperty("failFast");
        if (SqlUtils.isBoolean(failFast)) {
            this.failFast = Boolean.parseBoolean(failFast);
        }
    }

    /**
     * A leading index column.
     *
     * <p>The column is identified by the database qualifier and table name as written in the SQL,
     * so that lookups and stored entries share one key.
     */
    @Data
    private static class IndexInfo {
        /** Database qualifier from the SQL, or {@code null} for the connection's current catalog. */
        private String dbName;
        /** Table name (unquoted). */
        private String tableName;
        /** Column name as reported by {@link DatabaseMetaData#getIndexInfo}. */
        private String columnName;

        /** @return the normalized {@code db::table::column} key of this column */
        public String cacheKey() {
            return indexKey(dbName, tableName, columnName);
        }
    }
}