IllegalSqlInterceptor.java 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. package cn.com.ty.lift.common.sql;
  2. import com.baomidou.mybatisplus.core.parser.SqlParserHelper;
  3. import lombok.Data;
  4. import lombok.extern.slf4j.Slf4j;
  5. import net.sf.jsqlparser.expression.BinaryExpression;
  6. import net.sf.jsqlparser.expression.Expression;
  7. import net.sf.jsqlparser.expression.Function;
  8. import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
  9. import net.sf.jsqlparser.expression.operators.relational.InExpression;
  10. import net.sf.jsqlparser.expression.operators.relational.ItemsList;
  11. import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
  12. import net.sf.jsqlparser.parser.CCJSqlParserUtil;
  13. import net.sf.jsqlparser.schema.Column;
  14. import net.sf.jsqlparser.schema.Table;
  15. import net.sf.jsqlparser.statement.Statement;
  16. import net.sf.jsqlparser.statement.delete.Delete;
  17. import net.sf.jsqlparser.statement.select.Join;
  18. import net.sf.jsqlparser.statement.select.PlainSelect;
  19. import net.sf.jsqlparser.statement.select.Select;
  20. import net.sf.jsqlparser.statement.select.SubSelect;
  21. import net.sf.jsqlparser.statement.update.Update;
  22. import org.apache.ibatis.executor.statement.StatementHandler;
  23. import org.apache.ibatis.mapping.BoundSql;
  24. import org.apache.ibatis.mapping.MappedStatement;
  25. import org.apache.ibatis.mapping.SqlCommandType;
  26. import org.apache.ibatis.plugin.*;
  27. import org.apache.ibatis.reflection.MetaObject;
  28. import org.apache.ibatis.reflection.SystemMetaObject;
  29. import java.sql.Connection;
  30. import java.sql.DatabaseMetaData;
  31. import java.sql.ResultSet;
  32. import java.sql.SQLException;
  33. import java.util.*;
  34. import java.util.concurrent.ConcurrentHashMap;
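/**
 * Developers' skill levels vary, and even when a coding standard has been agreed on, many do not follow it.
 * <p>SQL is the most important factor affecting system performance, so this interceptor blocks poorly written SQL statements.</p>
 * <br>
 * <p>Kinds of SQL that are intercepted:</p>
 * <p>1. The statement must use an index, including the columns of a left join condition, and must follow the leftmost-prefix rule of the index.</p>
 * <p>Benefits of requiring an index:</p>
 * <p>1.1 If a dynamic-SQL bug drops the WHERE condition of an update, a full-table update of tens of thousands of rows is prevented.</p>
 * <p>1.2 If an index is confirmed to be used, SQL performance is generally not too bad.</p>
 * <br>
 * <p>2. SQL should run against a single table whenever possible. A query containing a left join must be explicitly allowed in its annotation/comment, otherwise it is intercepted; if a left-join statement cannot be split into single-table SQL, discuss it with your leader before proceeding.</p>
 * <p>https://gaoxianglong.github.io/shark</p>
 * <p>Benefits of single-table SQL:</p>
 * <p>2.1 Query conditions are simple and easy to understand and maintain;</p>
 * <p>2.2 Highly extensible (prepares for database and table sharding);</p>
 * <p>2.3 High cache utilization;</p>
 * <p>3. A function is applied to a column</p>
 * <br>
 * <p>4. The WHERE condition is empty</p>
 * <p>5. The WHERE condition uses !=</p>
 * <p>6. The WHERE condition uses the not keyword</p>
 * <p>7. The WHERE condition uses the or keyword</p>
 * <p>8. The WHERE condition uses a sub-query</p>
 *
 * @author willenfoo
 * @since 2018-03-22
 */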
  62. @Slf4j
  63. @Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
  64. public class IllegalSqlInterceptor implements Interceptor {
  65. /**
  66. * 缓存验证结果,提高性能
  67. */
  68. private final Set<String> cacheValidResult = new HashSet<>();
  69. /**
  70. * 缓存表的索引信息
  71. */
  72. private final Map<String, IndexInfo> indexInfoCacheMap = new ConcurrentHashMap<>();
  73. private final int indexValidColumnIndex = 8;
  74. private final int indexDBNameColumnIndex = 1;
  75. private final int indexTableNameColumnIndex = 3;
  76. private final int indexColumnNameColumnIndex = 9;
  77. private boolean failFast = false;
  78. private ThreadLocal<String> message = new ThreadLocal<>();
  79. private void set(String message) {
  80. this.message.set(message);
  81. }
  82. /**
  83. * 验证expression对象是不是 or、not等等
  84. *
  85. * @param expression ignore
  86. */
  87. private boolean validExpression(Expression expression) {
  88. //where条件使用了 or 关键字
  89. if (expression instanceof OrExpression) {
  90. OrExpression or = (OrExpression) expression;
  91. set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", or.getStringExpression(), or.toString()));
  92. return true;
  93. } else if (expression instanceof NotEqualsTo) {
  94. NotEqualsTo notEqTo = (NotEqualsTo) expression;
  95. set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", notEqTo.getStringExpression(), notEqTo.toString()));
  96. return true;
  97. } else if (expression instanceof BinaryExpression) {
  98. BinaryExpression binary = (BinaryExpression) expression;
  99. if (binary.isNot()) {
  100. set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", binary.getStringExpression(), binary.toString()));
  101. return true;
  102. }
  103. Expression leftExpression = binary.getLeftExpression();
  104. if (leftExpression instanceof Function) {
  105. Function function = (Function) leftExpression;
  106. set(String.format("Should not use db function in WHERE condition. illegal sql: %s", function.toString()));
  107. return true;
  108. }
  109. Expression rightExpression = binary.getRightExpression();
  110. if (rightExpression instanceof SubSelect) {
  111. SubSelect subSelect = (SubSelect) rightExpression;
  112. set(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", subSelect.toString()));
  113. return true;
  114. }
  115. } else if (expression instanceof InExpression) {
  116. InExpression in = (InExpression) expression;
  117. ItemsList rightItemsList = in.getRightItemsList();
  118. if (rightItemsList instanceof SubSelect) {
  119. SubSelect subSelect = (SubSelect) rightItemsList;
  120. set(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", subSelect.toString()));
  121. return true;
  122. }
  123. }
  124. return false;
  125. }
  126. /**
  127. * 如果SQL用了 left Join,验证是否有or、not等等,并且验证是否使用了索引
  128. *
  129. * @param joins ignore
  130. * @param table ignore
  131. * @param connection ignore
  132. */
  133. private void validJoins(List<Join> joins, Table table, Connection connection) {
  134. //允许执行join,验证join是否使用索引等等
  135. if (SqlUtils.notEmpty(joins)) {
  136. for (Join join : joins) {
  137. Table rightTable = (Table) join.getRightItem();
  138. Expression expression = join.getOnExpression();
  139. validWhere(expression, table, rightTable, connection);
  140. }
  141. }
  142. }
  143. /**
  144. * 检查是否使用索引
  145. *
  146. * @param table ignore
  147. * @param columnName ignore
  148. * @param connection ignore
  149. */
  150. private void validUseIndex(Table table, String columnName, Connection connection) {
  151. String tableInfo = table.getName();
  152. //表存在的索引
  153. String dbName = null;
  154. String tableName;
  155. String[] tableArray = tableInfo.split("\\.");
  156. if (tableArray.length == 1) {
  157. tableName = tableArray[0];
  158. } else {
  159. dbName = tableArray[0];
  160. tableName = tableArray[1];
  161. }
  162. String cacheKey = String.format("%s::%s::%s", dbName, tableName, columnName);
  163. boolean useIndexFlag = getIndexInfos(cacheKey, dbName,tableName, connection);
  164. if (!useIndexFlag) {
  165. throw new SqlAnalysisException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
  166. }
  167. }
  168. /**
  169. * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
  170. *
  171. * @param expression ignore
  172. * @param table ignore
  173. * @param connection ignore
  174. */
  175. private void validWhere(Expression expression, Table table, Connection connection) {
  176. validWhere(expression, table, null, connection);
  177. }
  178. /**
  179. * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
  180. *
  181. * @param expression ignore
  182. * @param table ignore
  183. * @param joinTable ignore
  184. * @param connection ignore
  185. */
  186. private boolean validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
  187. boolean valid = validExpression(expression);
  188. if(valid){
  189. return valid;
  190. }
  191. if (expression instanceof BinaryExpression) {
  192. //获得左边表达式
  193. Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
  194. boolean left = validExpression(leftExpression);
  195. if(left){
  196. return left;
  197. }
  198. //如果左边表达式为Column对象,则直接获得列名
  199. if (leftExpression instanceof Column) {
  200. Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
  201. if (joinTable != null && rightExpression instanceof Column) {
  202. if (Objects.equals(((Column) rightExpression).getTable().getName(), table.getAlias().getName())) {
  203. validUseIndex(table, ((Column) rightExpression).getColumnName(), connection);
  204. validUseIndex(joinTable, ((Column) leftExpression).getColumnName(), connection);
  205. } else {
  206. validUseIndex(joinTable, ((Column) rightExpression).getColumnName(), connection);
  207. validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
  208. }
  209. } else {
  210. //获得列名
  211. validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
  212. }
  213. }
  214. //如果BinaryExpression,进行迭代
  215. else if (leftExpression instanceof BinaryExpression) {
  216. validWhere(leftExpression, table, joinTable, connection);
  217. }
  218. //获得右边表达式,并分解
  219. Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
  220. validExpression(rightExpression);
  221. }
  222. return true;
  223. }
  224. /**
  225. * 得到表的索引信息
  226. *
  227. * @param dbName ignore
  228. * @param tableName ignore
  229. * @param conn ignore
  230. * @return ignore
  231. */
  232. public boolean getIndexInfos(String dbName, String tableName, Connection conn) {
  233. return getIndexInfos(null, dbName, tableName, conn);
  234. }
  235. /**
  236. * 得到表的索引信息
  237. *
  238. * @param key ignore
  239. * @param dbName ignore
  240. * @param tableName ignore
  241. * @param conn ignore
  242. * @return ignore
  243. */
  244. public boolean getIndexInfos(String key, String dbName, String tableName, Connection conn) {
  245. //if the indexInfoCacheMap is empty, must get index info from the connection.
  246. if(indexInfoCacheMap.isEmpty()){
  247. try {
  248. DatabaseMetaData metadata = conn.getMetaData();
  249. ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
  250. while (resultSet.next()) {
  251. //索引中的列序列号等于1,才有效
  252. log.info("resultSet: {}", resultSet);
  253. if ("1".equals(resultSet.getString(indexValidColumnIndex))) {
  254. IndexInfo indexInfo = new IndexInfo();
  255. indexInfo.setDbName(resultSet.getString(indexDBNameColumnIndex));
  256. indexInfo.setTableName(resultSet.getString(indexTableNameColumnIndex));
  257. indexInfo.setColumnName(resultSet.getString(indexColumnNameColumnIndex));
  258. //index
  259. indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
  260. }
  261. }
  262. } catch (SQLException e) {
  263. log.error("get index info from connection fail.");
  264. }
  265. }
  266. try {
  267. DatabaseMetaData metadata = conn.getMetaData();
  268. ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
  269. while (resultSet.next()) {
  270. //索引中的列序列号等于1,才有效
  271. if (Objects.equals(resultSet.getString(8), "1")) {
  272. IndexInfo indexInfo = new IndexInfo();
  273. indexInfo.setDbName(resultSet.getString(1));
  274. indexInfo.setTableName(resultSet.getString(3));
  275. indexInfo.setColumnName(resultSet.getString(9));
  276. //index
  277. indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
  278. }
  279. }
  280. } catch (SQLException e) {
  281. log.error("get index info from connection fail.");
  282. }
  283. return true;
  284. }
  285. @Override
  286. public Object intercept(Invocation invocation) throws Throwable {
  287. StatementHandler statementHandler = SqlUtils.realTarget(invocation.getTarget());
  288. MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
  289. // 如果是insert操作, 或者 @SqlParser(filter = true) 跳过该方法解析 , 不进行验证
  290. MappedStatement mappedStatement = SqlUtils.getMappedStatement(metaObject);
  291. if (SqlCommandType.INSERT.equals(mappedStatement.getSqlCommandType()) || SqlParserHelper.getSqlParserInfo(metaObject)) {
  292. return invocation.proceed();
  293. }
  294. BoundSql boundSql = SqlUtils.getBoundSql(metaObject);
  295. String originalSql = boundSql.getSql();
  296. log.info("检查SQL ID: {}", mappedStatement.getId());
  297. log.info("检查SQL是否合规,SQL: \n{}", originalSql);
  298. String md5Base64 = SqlUtils.md5Base64(originalSql);
  299. log.info("md5Base64: {}", md5Base64);
  300. if (cacheValidResult.contains(md5Base64)) {
  301. return invocation.proceed();
  302. }
  303. Statement statement = CCJSqlParserUtil.parse(originalSql);
  304. Expression where = null;
  305. Table table = null;
  306. List<Join> joins = null;
  307. if (statement instanceof Select) {
  308. PlainSelect plainSelect = (PlainSelect) ((Select) statement).getSelectBody();
  309. where = plainSelect.getWhere();
  310. //table = (Table) plainSelect.getFromItem();
  311. table = plainSelect.getForUpdateTable();
  312. joins = plainSelect.getJoins();
  313. } else if (statement instanceof Update) {
  314. Update update = (Update) statement;
  315. where = update.getWhere();
  316. table = update.getTables().get(0);
  317. joins = update.getJoins();
  318. } else if (statement instanceof Delete) {
  319. Delete delete = (Delete) statement;
  320. where = delete.getWhere();
  321. table = delete.getTable();
  322. joins = delete.getJoins();
  323. }
  324. //where条件不能为空
  325. if (SqlUtils.isNull(where)) {
  326. throw new SqlAnalysisException("非法SQL,必须要有where条件");
  327. }
  328. if(SqlUtils.isNull(table)){
  329. return invocation.proceed();
  330. }
  331. Connection connection = (Connection) invocation.getArgs()[0];
  332. validWhere(where, table, connection);
  333. validJoins(joins, table, connection);
  334. //缓存验证结果
  335. cacheValidResult.add(md5Base64);
  336. return invocation.proceed();
  337. }
  338. @Override
  339. public Object plugin(Object target) {
  340. if (target instanceof StatementHandler) {
  341. return Plugin.wrap(target, this);
  342. }
  343. return target;
  344. }
  345. @Override
  346. public void setProperties(Properties prop) {
  347. String failFast = prop.getProperty("failFast");
  348. if (SqlUtils.isBoolean(failFast)) {
  349. this.failFast = Boolean.valueOf(failFast);
  350. }
  351. }
  352. /**
  353. * the information of the index.
  354. */
  355. @Data
  356. private class IndexInfo {
  357. //the name of the db.
  358. private String dbName;
  359. //the name of the table.
  360. private String tableName;
  361. //the name of the column.
  362. private String columnName;
  363. public String cacheKey(){
  364. return String.format("%s::%s::%s",getDbName(), getTableName(), getColumnName());
  365. }
  366. }
  367. }