123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386 |
- package cn.com.ty.lift.common.sql;
- import com.baomidou.mybatisplus.core.parser.SqlParserHelper;
- import lombok.Data;
- import lombok.extern.slf4j.Slf4j;
- import net.sf.jsqlparser.expression.BinaryExpression;
- import net.sf.jsqlparser.expression.Expression;
- import net.sf.jsqlparser.expression.Function;
- import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
- import net.sf.jsqlparser.expression.operators.relational.InExpression;
- import net.sf.jsqlparser.expression.operators.relational.ItemsList;
- import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
- import net.sf.jsqlparser.parser.CCJSqlParserUtil;
- import net.sf.jsqlparser.schema.Column;
- import net.sf.jsqlparser.schema.Table;
- import net.sf.jsqlparser.statement.Statement;
- import net.sf.jsqlparser.statement.delete.Delete;
- import net.sf.jsqlparser.statement.select.Join;
- import net.sf.jsqlparser.statement.select.PlainSelect;
- import net.sf.jsqlparser.statement.select.Select;
- import net.sf.jsqlparser.statement.select.SubSelect;
- import net.sf.jsqlparser.statement.update.Update;
- import org.apache.ibatis.executor.statement.StatementHandler;
- import org.apache.ibatis.mapping.BoundSql;
- import org.apache.ibatis.mapping.MappedStatement;
- import org.apache.ibatis.mapping.SqlCommandType;
- import org.apache.ibatis.plugin.*;
- import org.apache.ibatis.reflection.MetaObject;
- import org.apache.ibatis.reflection.SystemMetaObject;
- import java.sql.Connection;
- import java.sql.DatabaseMetaData;
- import java.sql.ResultSet;
- import java.sql.SQLException;
- import java.util.*;
- import java.util.concurrent.ConcurrentHashMap;
- /**
- * 由于开发人员水平参差不齐,即使订了开发规范很多人也不遵守
- * <p>SQL是影响系统性能最重要的因素,所以拦截掉垃圾SQL语句</p>
- * <br>
- * <p>拦截SQL类型的场景</p>
- * <p>1.必须使用到索引,包含left join连接字段,符合索引最左原则</p>
- * <p>必须使用索引好处,</p>
- * <p>1.1 如果因为动态SQL,bug导致update的where条件没有带上,全表更新上万条数据</p>
- * <p>1.2 如果检查到使用了索引,SQL性能基本不会太差</p>
- * <br>
- * <p>2.SQL尽量单表执行,有查询left join的语句,必须在注释里面允许该SQL运行,否则会被拦截,有left join的语句,如果不能拆成单表执行的SQL,请leader商量在做</p>
- * <p>https://gaoxianglong.github.io/shark</p>
- * <p>SQL尽量单表执行的好处</p>
- * <p>2.1 查询条件简单、易于开理解和维护;</p>
- * <p>2.2 扩展性极强;(可为分库分表做准备)</p>
- * <p>2.3 缓存利用率高;</p>
- * <p>2.在字段上使用函数</p>
- * <br>
- * <p>3.where条件为空</p>
- * <p>4.where条件使用了 !=</p>
- * <p>5.where条件使用了 not 关键字</p>
- * <p>6.where条件使用了 or 关键字</p>
- * <p>7.where条件使用了 使用子查询</p>
- *
- * @author willenfoo
- * @since 2018-03-22
- */
- @Slf4j
- @Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
- public class IllegalSqlInterceptor implements Interceptor {
- /**
- * 缓存验证结果,提高性能
- */
- private final Set<String> cacheValidResult = new HashSet<>();
- /**
- * 缓存表的索引信息
- */
- private final Map<String, IndexInfo> indexInfoCacheMap = new ConcurrentHashMap<>();
- private final int indexValidColumnIndex = 8;
- private final int indexDBNameColumnIndex = 1;
- private final int indexTableNameColumnIndex = 3;
- private final int indexColumnNameColumnIndex = 9;
- private boolean failFast = false;
- private ThreadLocal<String> message = new ThreadLocal<>();
- private void set(String message) {
- this.message.set(message);
- }
- /**
- * 验证expression对象是不是 or、not等等
- *
- * @param expression ignore
- */
- private boolean validExpression(Expression expression) {
- //where条件使用了 or 关键字
- if (expression instanceof OrExpression) {
- OrExpression or = (OrExpression) expression;
- set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", or.getStringExpression(), or.toString()));
- return true;
- } else if (expression instanceof NotEqualsTo) {
- NotEqualsTo notEqTo = (NotEqualsTo) expression;
- set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", notEqTo.getStringExpression(), notEqTo.toString()));
- return true;
- } else if (expression instanceof BinaryExpression) {
- BinaryExpression binary = (BinaryExpression) expression;
- if (binary.isNot()) {
- set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", binary.getStringExpression(), binary.toString()));
- return true;
- }
- Expression leftExpression = binary.getLeftExpression();
- if (leftExpression instanceof Function) {
- Function function = (Function) leftExpression;
- set(String.format("Should not use db function in WHERE condition. illegal sql: %s", function.toString()));
- return true;
- }
- Expression rightExpression = binary.getRightExpression();
- if (rightExpression instanceof SubSelect) {
- SubSelect subSelect = (SubSelect) rightExpression;
- set(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", subSelect.toString()));
- return true;
- }
- } else if (expression instanceof InExpression) {
- InExpression in = (InExpression) expression;
- ItemsList rightItemsList = in.getRightItemsList();
- if (rightItemsList instanceof SubSelect) {
- SubSelect subSelect = (SubSelect) rightItemsList;
- set(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", subSelect.toString()));
- return true;
- }
- }
- return false;
- }
- /**
- * 如果SQL用了 left Join,验证是否有or、not等等,并且验证是否使用了索引
- *
- * @param joins ignore
- * @param table ignore
- * @param connection ignore
- */
- private void validJoins(List<Join> joins, Table table, Connection connection) {
- //允许执行join,验证join是否使用索引等等
- if (SqlUtils.notEmpty(joins)) {
- for (Join join : joins) {
- Table rightTable = (Table) join.getRightItem();
- Expression expression = join.getOnExpression();
- validWhere(expression, table, rightTable, connection);
- }
- }
- }
- /**
- * 检查是否使用索引
- *
- * @param table ignore
- * @param columnName ignore
- * @param connection ignore
- */
- private void validUseIndex(Table table, String columnName, Connection connection) {
- String tableInfo = table.getName();
- //表存在的索引
- String dbName = null;
- String tableName;
- String[] tableArray = tableInfo.split("\\.");
- if (tableArray.length == 1) {
- tableName = tableArray[0];
- } else {
- dbName = tableArray[0];
- tableName = tableArray[1];
- }
- String cacheKey = String.format("%s::%s::%s", dbName, tableName, columnName);
- boolean useIndexFlag = getIndexInfos(cacheKey, dbName,tableName, connection);
- if (!useIndexFlag) {
- throw new SqlAnalysisException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
- }
- }
- /**
- * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
- *
- * @param expression ignore
- * @param table ignore
- * @param connection ignore
- */
- private void validWhere(Expression expression, Table table, Connection connection) {
- validWhere(expression, table, null, connection);
- }
- /**
- * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
- *
- * @param expression ignore
- * @param table ignore
- * @param joinTable ignore
- * @param connection ignore
- */
- private boolean validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
- boolean valid = validExpression(expression);
- if(valid){
- return valid;
- }
- if (expression instanceof BinaryExpression) {
- //获得左边表达式
- Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
- boolean left = validExpression(leftExpression);
- if(left){
- return left;
- }
- //如果左边表达式为Column对象,则直接获得列名
- if (leftExpression instanceof Column) {
- Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
- if (joinTable != null && rightExpression instanceof Column) {
- if (Objects.equals(((Column) rightExpression).getTable().getName(), table.getAlias().getName())) {
- validUseIndex(table, ((Column) rightExpression).getColumnName(), connection);
- validUseIndex(joinTable, ((Column) leftExpression).getColumnName(), connection);
- } else {
- validUseIndex(joinTable, ((Column) rightExpression).getColumnName(), connection);
- validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
- }
- } else {
- //获得列名
- validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
- }
- }
- //如果BinaryExpression,进行迭代
- else if (leftExpression instanceof BinaryExpression) {
- validWhere(leftExpression, table, joinTable, connection);
- }
- //获得右边表达式,并分解
- Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
- validExpression(rightExpression);
- }
- return true;
- }
- /**
- * 得到表的索引信息
- *
- * @param dbName ignore
- * @param tableName ignore
- * @param conn ignore
- * @return ignore
- */
- public boolean getIndexInfos(String dbName, String tableName, Connection conn) {
- return getIndexInfos(null, dbName, tableName, conn);
- }
- /**
- * 得到表的索引信息
- *
- * @param key ignore
- * @param dbName ignore
- * @param tableName ignore
- * @param conn ignore
- * @return ignore
- */
- public boolean getIndexInfos(String key, String dbName, String tableName, Connection conn) {
- //if the indexInfoCacheMap is empty, must get index info from the connection.
- if(indexInfoCacheMap.isEmpty()){
- try {
- DatabaseMetaData metadata = conn.getMetaData();
- ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
- while (resultSet.next()) {
- //索引中的列序列号等于1,才有效
- log.info("resultSet: {}", resultSet);
- if ("1".equals(resultSet.getString(indexValidColumnIndex))) {
- IndexInfo indexInfo = new IndexInfo();
- indexInfo.setDbName(resultSet.getString(indexDBNameColumnIndex));
- indexInfo.setTableName(resultSet.getString(indexTableNameColumnIndex));
- indexInfo.setColumnName(resultSet.getString(indexColumnNameColumnIndex));
- //index
- indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
- }
- }
- } catch (SQLException e) {
- log.error("get index info from connection fail.");
- }
- }
- try {
- DatabaseMetaData metadata = conn.getMetaData();
- ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
- while (resultSet.next()) {
- //索引中的列序列号等于1,才有效
- if (Objects.equals(resultSet.getString(8), "1")) {
- IndexInfo indexInfo = new IndexInfo();
- indexInfo.setDbName(resultSet.getString(1));
- indexInfo.setTableName(resultSet.getString(3));
- indexInfo.setColumnName(resultSet.getString(9));
- //index
- indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
- }
- }
- } catch (SQLException e) {
- log.error("get index info from connection fail.");
- }
- return true;
- }
- @Override
- public Object intercept(Invocation invocation) throws Throwable {
- StatementHandler statementHandler = SqlUtils.realTarget(invocation.getTarget());
- MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
- // 如果是insert操作, 或者 @SqlParser(filter = true) 跳过该方法解析 , 不进行验证
- MappedStatement mappedStatement = SqlUtils.getMappedStatement(metaObject);
- if (SqlCommandType.INSERT.equals(mappedStatement.getSqlCommandType()) || SqlParserHelper.getSqlParserInfo(metaObject)) {
- return invocation.proceed();
- }
- BoundSql boundSql = SqlUtils.getBoundSql(metaObject);
- String originalSql = boundSql.getSql();
- log.info("检查SQL ID: {}", mappedStatement.getId());
- log.info("检查SQL是否合规,SQL: \n{}", originalSql);
- String md5Base64 = SqlUtils.md5Base64(originalSql);
- log.info("md5Base64: {}", md5Base64);
- if (cacheValidResult.contains(md5Base64)) {
- return invocation.proceed();
- }
- Statement statement = CCJSqlParserUtil.parse(originalSql);
- Expression where = null;
- Table table = null;
- List<Join> joins = null;
- if (statement instanceof Select) {
- PlainSelect plainSelect = (PlainSelect) ((Select) statement).getSelectBody();
- where = plainSelect.getWhere();
- //table = (Table) plainSelect.getFromItem();
- table = plainSelect.getForUpdateTable();
- joins = plainSelect.getJoins();
- } else if (statement instanceof Update) {
- Update update = (Update) statement;
- where = update.getWhere();
- table = update.getTables().get(0);
- joins = update.getJoins();
- } else if (statement instanceof Delete) {
- Delete delete = (Delete) statement;
- where = delete.getWhere();
- table = delete.getTable();
- joins = delete.getJoins();
- }
- //where条件不能为空
- if (SqlUtils.isNull(where)) {
- throw new SqlAnalysisException("非法SQL,必须要有where条件");
- }
- if(SqlUtils.isNull(table)){
- return invocation.proceed();
- }
- Connection connection = (Connection) invocation.getArgs()[0];
- validWhere(where, table, connection);
- validJoins(joins, table, connection);
- //缓存验证结果
- cacheValidResult.add(md5Base64);
- return invocation.proceed();
- }
- @Override
- public Object plugin(Object target) {
- if (target instanceof StatementHandler) {
- return Plugin.wrap(target, this);
- }
- return target;
- }
- @Override
- public void setProperties(Properties prop) {
- String failFast = prop.getProperty("failFast");
- if (SqlUtils.isBoolean(failFast)) {
- this.failFast = Boolean.valueOf(failFast);
- }
- }
- /**
- * the information of the index.
- */
- @Data
- private class IndexInfo {
- //the name of the db.
- private String dbName;
- //the name of the table.
- private String tableName;
- //the name of the column.
- private String columnName;
- public String cacheKey(){
- return String.format("%s::%s::%s",getDbName(), getTableName(), getColumnName());
- }
- }
- }
|