package cn.com.ty.lift.common.sql;
import com.baomidou.mybatisplus.core.parser.SqlParserHelper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SubSelect;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * MyBatis interceptor that rejects (or warns about) SQL that violates the project's SQL rules.
 *
 * <p>Development guidelines alone are not reliably followed, and SQL is the single most important
 * factor in system performance, so badly written statements are caught here before they are
 * prepared.
 *
 * <p>Rules checked:
 * <ol>
 *   <li>An index must be used: the leading column of the WHERE clause, and the columns of the
 *       leading comparison of each JOIN ... ON condition, must be the first column of an index
 *       (leftmost-prefix rule). Requiring an index means that
 *       <ul>
 *         <li>a dynamic-SQL bug that drops the WHERE condition cannot silently update tens of
 *             thousands of rows, and</li>
 *         <li>a statement that uses an index generally performs acceptably.</li>
 *       </ul></li>
 *   <li>SQL should run against a single table where possible. Queries containing LEFT JOIN are
 *       supposed to be explicitly allowed by a comment in the SQL; if such a query cannot be split
 *       into single-table statements, discuss it with the team lead first
 *       (see https://gaoxianglong.github.io/shark). Single-table SQL has simple conditions that are
 *       easy to understand and maintain, scales well (it prepares for sharding) and caches well.
 *       NOTE(review): the "allow via comment" mechanism is not implemented in this class.</li>
 *   <li>No database functions applied to columns in the WHERE clause.</li>
 *   <li>The WHERE clause must not be empty.</li>
 *   <li>No {@code !=} in the WHERE clause.</li>
 *   <li>No {@code NOT} in the WHERE clause.</li>
 *   <li>No {@code OR} in the WHERE clause.</li>
 *   <li>No sub-queries in the WHERE clause.</li>
 * </ol>
 *
 * <p>A missing WHERE clause always aborts the statement with {@link SqlAnalysisException}. Every
 * other violation aborts the statement only when the {@code failFast} plugin property is
 * {@code true}; otherwise it is logged as a warning and the statement proceeds.
 *
 * <p>INSERT statements, statements marked {@code @SqlParser(filter = true)} and compound selects
 * (UNION etc.) are not checked at all; statements whose target is not a plain table (for example a
 * SELECT from a sub-query) are only checked for a missing WHERE clause.
 *
 * @author willenfoo
 * @since 2018-03-22
 */
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class IllegalSqlInterceptor implements Interceptor {

    /** {@code DatabaseMetaData#getIndexInfo} result column: TABLE_CAT. */
    private static final int INDEX_DB_NAME_COLUMN = 1;
    /** {@code DatabaseMetaData#getIndexInfo} result column: TABLE_NAME. */
    private static final int INDEX_TABLE_NAME_COLUMN = 3;
    /** {@code DatabaseMetaData#getIndexInfo} result column: ORDINAL_POSITION (column position within the index). */
    private static final int INDEX_ORDINAL_POSITION_COLUMN = 8;
    /** {@code DatabaseMetaData#getIndexInfo} result column: COLUMN_NAME. */
    private static final int INDEX_COLUMN_NAME_COLUMN = 9;

    /**
     * MD5 digests of SQL strings that already went through validation, so each distinct statement
     * is analysed only once. {@link #plugin(Object)} wraps every StatementHandler with this same
     * instance, so the set is shared across sessions and must be concurrent.
     */
    private final Set<String> cacheValidResult = ConcurrentHashMap.newKeySet();

    /** Leading index columns, keyed by {@link #indexKey(String, String, String)}. */
    private final Map<String, IndexInfo> indexInfoCacheMap = new ConcurrentHashMap<>();

    /** Tables (keyed by {@code indexKey(db, table, "")}) whose index metadata has been loaded. */
    private final Set<String> loadedIndexTables = ConcurrentHashMap.newKeySet();

    /** When true, rule violations throw {@link SqlAnalysisException}; otherwise they are only logged. */
    private boolean failFast = false;

    /**
     * Reports a rule violation: throws when {@code failFast} is enabled, otherwise logs a warning.
     *
     * @param message description of the violation
     * @throws SqlAnalysisException if {@code failFast} is {@code true}
     */
    private void reportViolation(String message) {
        if (failFast) {
            throw new SqlAnalysisException(message);
        }
        log.warn("{}", message);
    }

    /**
     * Checks a single WHERE/ON expression node for forbidden constructs and reports the first one
     * found. Only this node (and its direct operands) is inspected; children are not traversed.
     *
     * @param expression the node to check; may be null
     * @return true if a violation was found and reported, false otherwise
     */
    private boolean validExpression(Expression expression) {
        String violation = describeViolation(expression);
        if (violation == null) {
            return false;
        }
        reportViolation(violation);
        return true;
    }

    /**
     * Describes the forbidden construct in {@code expression}: OR, {@code !=}/{@code <>}, a negated
     * binary expression, a database function as left operand, or a sub-query as right operand /
     * inside {@code IN (...)}.
     *
     * @param expression the node to check; may be null
     * @return the violation message, or null if the node is acceptable
     */
    private static String describeViolation(Expression expression) {
        // OrExpression and NotEqualsTo are BinaryExpressions too, so they must be tested first.
        if (expression instanceof OrExpression || expression instanceof NotEqualsTo) {
            return forbiddenOperator((BinaryExpression) expression);
        }
        if (expression instanceof BinaryExpression) {
            BinaryExpression binary = (BinaryExpression) expression;
            if (binary.isNot()) {
                return forbiddenOperator(binary);
            }
            Expression leftExpression = binary.getLeftExpression();
            if (leftExpression instanceof Function) {
                return String.format("Should not use db function in WHERE condition. illegal sql: %s", leftExpression);
            }
            Expression rightExpression = binary.getRightExpression();
            if (rightExpression instanceof SubSelect) {
                return String.format("Should not use subSelect in WHERE condition. illegal sql: %s", rightExpression);
            }
            return null;
        }
        if (expression instanceof InExpression) {
            ItemsList rightItemsList = ((InExpression) expression).getRightItemsList();
            if (rightItemsList instanceof SubSelect) {
                return String.format("Should not use subSelect in WHERE condition. illegal sql: %s", rightItemsList);
            }
        }
        return null;
    }

    /** Builds the "forbidden operator" message for {@code binary}. */
    private static String forbiddenOperator(BinaryExpression binary) {
        return String.format("Should not use [%s] in WHERE condition. illegal sql: %s", binary.getStringExpression(), binary);
    }

    /**
     * Validates the ON condition of every join against the main table and the joined table.
     * Joins whose right side is not a plain table (e.g. a sub-query) cannot be checked against
     * index metadata and are skipped.
     *
     * @param joins the statement's joins; may be null or empty
     * @param table the statement's main table
     * @param connection connection used to read index metadata
     */
    private void validJoins(List<Join> joins, Table table, Connection connection) {
        if (!SqlUtils.notEmpty(joins)) {
            return;
        }
        for (Join join : joins) {
            if (join.getRightItem() instanceof Table) {
                validWhere(join.getOnExpression(), table, (Table) join.getRightItem(), connection);
            }
        }
    }

    /**
     * Reports a violation unless {@code columnName} is the leading column of some index on
     * {@code table}.
     *
     * @param table table the column belongs to; its schema qualifier, if any, is used as db name
     * @param columnName column referenced by the SQL
     * @param connection connection used to read index metadata
     */
    private void validUseIndex(Table table, String columnName, Connection connection) {
        String dbName = table.getSchemaName();
        String tableName = table.getName();
        // Fallback in case the parser kept a "db.table" qualifier inside the name itself.
        int dot = tableName.lastIndexOf('.');
        if (dbName == null && dot >= 0) {
            dbName = tableName.substring(0, dot);
            tableName = tableName.substring(dot + 1);
        }
        String cacheKey = indexKey(dbName, tableName, columnName);
        if (!getIndexInfos(cacheKey, dbName, tableName, connection)) {
            reportViolation("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
        }
    }

    /**
     * Validates a WHERE clause: reports forbidden constructs and requires the first column to be
     * indexed.
     *
     * @param expression the WHERE expression
     * @param table the statement's main table
     * @param connection connection used to read index metadata
     */
    private void validWhere(Expression expression, Table table, Connection connection) {
        validWhere(expression, table, null, connection);
    }

    /**
     * Validates a WHERE or JOIN ... ON expression. The node and its left operand are checked for
     * forbidden constructs; the left-most comparison's column(s) must be indexed. Left-nested
     * binary expressions (e.g. {@code a = 1 AND b = 2}) are descended into on the left only; the
     * right operand is checked for forbidden constructs but not for index usage.
     *
     * @param expression the expression; may be null (e.g. a join without ON)
     * @param table the statement's main table
     * @param joinTable the joined table when validating an ON condition, otherwise null
     * @param connection connection used to read index metadata
     */
    private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
        if (validExpression(expression) || !(expression instanceof BinaryExpression)) {
            return;
        }
        BinaryExpression binary = (BinaryExpression) expression;
        Expression leftExpression = binary.getLeftExpression();
        if (validExpression(leftExpression)) {
            return;
        }
        Expression rightExpression = binary.getRightExpression();
        if (leftExpression instanceof Column) {
            String leftColumn = ((Column) leftExpression).getColumnName();
            if (joinTable != null && rightExpression instanceof Column) {
                Column rightColumn = (Column) rightExpression;
                // Decide which side belongs to the main table by the right column's qualifier.
                if (isQualifiedBy(rightColumn, table)) {
                    validUseIndex(table, rightColumn.getColumnName(), connection);
                    validUseIndex(joinTable, leftColumn, connection);
                } else {
                    validUseIndex(joinTable, rightColumn.getColumnName(), connection);
                    validUseIndex(table, leftColumn, connection);
                }
            } else {
                validUseIndex(table, leftColumn, connection);
            }
        } else if (leftExpression instanceof BinaryExpression) {
            validWhere(leftExpression, table, joinTable, connection);
        }
        validExpression(rightExpression);
    }

    /**
     * Returns whether {@code column}'s qualifier names {@code table}, by alias when the table has
     * one and by table name otherwise.
     */
    private static boolean isQualifiedBy(Column column, Table table) {
        Table qualifier = column.getTable();
        String qualifierName = qualifier == null ? null : qualifier.getName();
        String reference = table.getAlias() != null ? table.getAlias().getName() : table.getName();
        return Objects.equals(qualifierName, reference);
    }

    /**
     * Loads (once per table) the index metadata of a table into the cache.
     *
     * @param dbName database name, or null for the connection default
     * @param tableName table name
     * @param conn connection used to read metadata
     * @return always true (no specific column is checked)
     */
    public boolean getIndexInfos(String dbName, String tableName, Connection conn) {
        return getIndexInfos(null, dbName, tableName, conn);
    }

    /**
     * Loads (once per table) the index metadata of a table and checks whether {@code key} names a
     * leading index column of it.
     *
     * @param key cache key in the form {@code db::table::column} (compared case-insensitively), or
     *            null to only load the metadata
     * @param dbName database name, or null for the connection default
     * @param tableName table name
     * @param conn connection used to read metadata
     * @return true if {@code key} is null, if {@code key} is a leading index column, or if the
     *         metadata could not be read (the column cannot be judged, so it is not blocked);
     *         false otherwise
     */
    public boolean getIndexInfos(String key, String dbName, String tableName, Connection conn) {
        String db = unquote(dbName);
        String table = unquote(tableName);
        String tableKey = indexKey(db, table, "");
        if (!loadedIndexTables.contains(tableKey)) {
            try {
                loadIndexInfos(db, table, conn);
                // Mark as loaded only on success so a transient failure is retried next time.
                loadedIndexTables.add(tableKey);
            } catch (SQLException e) {
                log.error("get index info from connection fail.", e);
                return true;
            }
        }
        return key == null || indexInfoCacheMap.containsKey(key.toLowerCase(Locale.ROOT));
    }

    /**
     * Reads {@code DatabaseMetaData#getIndexInfo} for one table and caches every column that is
     * the first column (ORDINAL_POSITION 1) of an index, since only those satisfy the
     * leftmost-prefix rule on their own. Entries are keyed by the requested db/table names so that
     * lookups built from the SQL text match them.
     */
    private void loadIndexInfos(String dbName, String tableName, Connection conn) throws SQLException {
        DatabaseMetaData metadata = conn.getMetaData();
        // NOTE(review): dbName is passed as both catalog and schema, as before; this fits MySQL,
        // where the database is the catalog — confirm for other databases.
        try (ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true)) {
            while (resultSet.next()) {
                String columnName = resultSet.getString(INDEX_COLUMN_NAME_COLUMN);
                if (columnName != null && "1".equals(resultSet.getString(INDEX_ORDINAL_POSITION_COLUMN))) {
                    IndexInfo indexInfo = new IndexInfo();
                    indexInfo.setDbName(resultSet.getString(INDEX_DB_NAME_COLUMN));
                    indexInfo.setTableName(resultSet.getString(INDEX_TABLE_NAME_COLUMN));
                    indexInfo.setColumnName(columnName);
                    indexInfoCacheMap.put(indexKey(dbName, tableName, columnName), indexInfo);
                }
            }
        }
    }

    /** Builds the case-insensitive {@code db::table::column} cache key, ignoring identifier quotes. */
    private static String indexKey(String dbName, String tableName, String columnName) {
        return String.format("%s::%s::%s", unquote(dbName), unquote(tableName), unquote(columnName))
                .toLowerCase(Locale.ROOT);
    }

    /** Strips one pair of surrounding identifier quotes ({@code `x`}, {@code "x"} or {@code [x]}). */
    private static String unquote(String identifier) {
        if (identifier == null || identifier.length() < 2) {
            return identifier;
        }
        char first = identifier.charAt(0);
        char last = identifier.charAt(identifier.length() - 1);
        if ((first == '`' && last == '`') || (first == '"' && last == '"') || (first == '[' && last == ']')) {
            return identifier.substring(1, identifier.length() - 1);
        }
        return identifier;
    }

    /**
     * Validates the SQL about to be prepared, then lets the invocation proceed.
     *
     * @param invocation the intercepted {@code StatementHandler#prepare} call
     * @return the result of the intercepted call
     * @throws SqlAnalysisException if the statement has no WHERE clause, or if another rule is
     *                              violated while {@code failFast} is enabled
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = SqlUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // INSERT statements and @SqlParser(filter = true) statements are not validated.
        MappedStatement mappedStatement = SqlUtils.getMappedStatement(metaObject);
        if (SqlCommandType.INSERT.equals(mappedStatement.getSqlCommandType()) || SqlParserHelper.getSqlParserInfo(metaObject)) {
            return invocation.proceed();
        }
        BoundSql boundSql = SqlUtils.getBoundSql(metaObject);
        String originalSql = boundSql.getSql();
        log.debug("检查SQL ID: {}", mappedStatement.getId());
        log.debug("检查SQL是否合规,SQL: \n{}", originalSql);
        String md5Base64 = SqlUtils.md5Base64(originalSql);
        if (cacheValidResult.contains(md5Base64)) {
            return invocation.proceed();
        }
        Statement statement = CCJSqlParserUtil.parse(originalSql);
        Expression where = null;
        Table table = null;
        List<Join> joins = null;
        if (statement instanceof Select) {
            Object selectBody = ((Select) statement).getSelectBody();
            if (!(selectBody instanceof PlainSelect)) {
                // Compound selects (UNION etc.) are not analysed by this interceptor.
                return invocation.proceed();
            }
            PlainSelect plainSelect = (PlainSelect) selectBody;
            where = plainSelect.getWhere();
            // Only a plain table can be checked against index metadata; sub-query FROMs are skipped below.
            if (plainSelect.getFromItem() instanceof Table) {
                table = (Table) plainSelect.getFromItem();
            }
            joins = plainSelect.getJoins();
        } else if (statement instanceof Update) {
            Update update = (Update) statement;
            where = update.getWhere();
            table = update.getTables().get(0);
            joins = update.getJoins();
        } else if (statement instanceof Delete) {
            Delete delete = (Delete) statement;
            where = delete.getWhere();
            table = delete.getTable();
            joins = delete.getJoins();
        }
        // A missing WHERE clause is always rejected, regardless of failFast.
        if (SqlUtils.isNull(where)) {
            throw new SqlAnalysisException("非法SQL,必须要有where条件");
        }
        if (SqlUtils.isNull(table)) {
            return invocation.proceed();
        }
        Connection connection = (Connection) invocation.getArgs()[0];
        validWhere(where, table, connection);
        validJoins(joins, table, connection);
        // Remember the statement so it is not analysed again.
        cacheValidResult.add(md5Base64);
        return invocation.proceed();
    }

    /**
     * Wraps StatementHandler targets with this interceptor; other targets are returned unchanged.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Reads plugin properties. {@code failFast} (when a valid boolean string) controls whether rule
     * violations other than a missing WHERE clause throw instead of only being logged.
     */
    @Override
    public void setProperties(Properties prop) {
        String failFast = prop.getProperty("failFast");
        if (SqlUtils.isBoolean(failFast)) {
            this.failFast = Boolean.parseBoolean(failFast);
        }
    }

    /**
     * Metadata row recorded for each leading index column. Static so it holds no reference to the
     * enclosing interceptor.
     */
    @Data
    private static class IndexInfo {
        /** Catalog (database) name reported by the driver. */
        private String dbName;
        /** Table name reported by the driver. */
        private String tableName;
        /** Leading column of the index. */
        private String columnName;
    }
}