Przeglądaj źródła

Merge branch 'develop' of http://132.232.206.88:3000/lift-manager/lift-server into feature-bieao

别傲 5 lat temu
rodzic
commit
5820a952bb

Plik diff jest za duży
+ 95 - 110
lift-common/src/main/java/cn.com.ty.lift.common/judge/Judgement.java


+ 0 - 330
lift-common/src/main/java/cn.com.ty.lift.common/sql/AbstractSqlParser.java

@@ -1,330 +0,0 @@
-package cn.com.ty.lift.common.sql;
-
-import lombok.extern.slf4j.Slf4j;
-import org.apache.ibatis.mapping.BoundSql;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.apache.ibatis.plugin.Invocation;
-import org.apache.ibatis.reflection.MetaObject;
-import org.apache.ibatis.reflection.SystemMetaObject;
-
-import java.lang.reflect.Method;
-import java.lang.reflect.Proxy;
-import java.nio.charset.Charset;
-import java.security.MessageDigest;
-import java.sql.Statement;
-import java.util.*;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
-
-/**
- * SQL解析
- *
- * @author wcz
- * @since 2020/2/27
- */
-@Slf4j
-public abstract class AbstractSqlParser {
-
-    private static final String NEWLINE                        = "\n";
-    private static final String SPACE                          = " ";
-    private static final String DruidPooledPreparedStatement   = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
-    private static final String T4CPreparedStatement           = "oracle.jdbc.driver.T4CPreparedStatement";
-    private static final String OraclePreparedStatementWrapper = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
-    private static final String h_statement                    = "h.statement";
-    private static final String stmt_statement                 = "stmt.statement";
-    private static final String delegate                       = "delegate";
-    private static final String getSql                         = "getSql";
-    private static final String getOriginalSql                 = "getOriginalSql";
-    private              Method oracleGetOriginalSqlMethod;
-    private              Method druidGetSQLMethod;
-
-    private static final SqlFormatter SQL_FORMATTER             = new SqlFormatter();
-    private static final String       h_target                  = "h.target";
-    /**
-     * MD5
-     */
-    private static final String       MD5                       = "MD5";
-    /**
-     * Eight-bit UCS Transformation Format
-     */
-    private static final  Charset      UTF_8                     = Charset.forName("UTF-8");
-    private static final  String       DELEGATE_MAPPED_STATEMENT = "delegate.mappedStatement";
-    private static final  String       DELEGATE_BOUND_SQL        = "delegate.boundSql";
-
-    private Statement hStatement(Invocation invocation) {
-        Statement statement;
-        Object firstArg = invocation.getArgs()[0];
-        if (Proxy.isProxyClass(firstArg.getClass())) {
-            statement = (Statement) SystemMetaObject.forObject(firstArg).getValue(h_statement);
-        } else {
-            statement = (Statement) firstArg;
-        }
-        return statement;
-    }
-
-    private Statement stmtStatement(MetaObject metaObject) {
-        Statement statement = null;
-        try {
-            statement = (Statement) metaObject.getValue(stmt_statement);
-        } catch (Exception e) {
-            // do nothing
-            log.warn("Get stmt.statement from MetaObject fail.");
-        }
-        return statement;
-    }
-
-    private Statement delegateStatement(MetaObject metaObject) {
-        Statement statement = null;
-        if (metaObject.hasGetter(delegate)) {
-            //Hikari
-            try {
-                statement = (Statement) metaObject.getValue(delegate);
-            } catch (Exception ignored) {
-                // do nothing
-                log.warn("Get delegate statement from MetaObject fail.");
-            }
-        }
-        return statement;
-    }
-
-    private String originalSqlInDruidPooled(Statement statement) {
-        String originalSql = null;
-        try {
-            if (isNull(druidGetSQLMethod)) {
-                Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
-                druidGetSQLMethod = clazz.getMethod(getSql);
-                druidGetSQLMethod.setAccessible(true);
-            }
-            Object stmtSql = druidGetSQLMethod.invoke(statement);
-            if (isString(stmtSql)) {
-                originalSql = (String) stmtSql;
-            }
-        } catch (Exception e) {
-            log.error("Get original Sql from DruidPooledPreparedStatement fail.", e);
-        }
-        return originalSql;
-    }
-
-    private String originalSqlInT4COrOracleWrapper(Statement statement) {
-        String originalSql = null;
-        try {
-            if(isNull(oracleGetOriginalSqlMethod)){
-                Class<?> clazz = Class.forName(statement.getClass().getName());
-                oracleGetOriginalSqlMethod = getMethodRegular(clazz, getOriginalSql);
-            }
-            if(notNull(oracleGetOriginalSqlMethod)){
-                //OraclePreparedStatementWrapper is not a public class, need set this.
-                oracleGetOriginalSqlMethod.setAccessible(true);
-                Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
-                if (isString(stmtSql)) {
-                    originalSql = (String) stmtSql;
-                }
-            }
-        } catch (Exception e) {
-            //ignore
-            log.warn("Get original Sql from T4CPreparedStatement | OraclePreparedStatementWrapper fail.");
-        }
-        return originalSql;
-    }
-
-    public String originalSql(Invocation invocation) {
-        Statement statement = hStatement(invocation);
-
-        MetaObject stmtMetaObject = SystemMetaObject.forObject(statement);
-
-        statement = stmtStatement(stmtMetaObject);
-        statement = delegateStatement(stmtMetaObject);
-
-        String originalSql = null;
-        String stmtClassName = statement.getClass().getName();
-        if (DruidPooledPreparedStatement.equals(stmtClassName)) {
-            originalSql = originalSqlInDruidPooled(statement);
-        } else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
-            originalSql = originalSqlInT4COrOracleWrapper(statement);
-        }
-        if (isNull(originalSql)) {
-            originalSql = statement.toString();
-        }
-        return originalSql;
-    }
-
-    /**
-     * 获取此方法名的具体 Method
-     *
-     * @param clazz      class 对象
-     * @param methodName 方法名
-     * @return 方法
-     */
-    private Method getMethodRegular(Class<?> clazz, String methodName) {
-        if (Object.class.equals(clazz)) {
-            return null;
-        }
-        for (Method method : clazz.getDeclaredMethods()) {
-            if (method.getName().equals(methodName)) {
-                return method;
-            }
-        }
-        return getMethodRegular(clazz.getSuperclass(), methodName);
-    }
-
-    /**
-     * 获取sql语句开头部分
-     *
-     * @param sql ignore
-     * @return ignore
-     */
-    private int indexOfSqlStart(String sql) {
-        String upperCaseSql = sql.toUpperCase();
-        Set<Integer> set = new HashSet<>();
-        set.add(upperCaseSql.indexOf("SELECT "));
-        set.add(upperCaseSql.indexOf("UPDATE "));
-        set.add(upperCaseSql.indexOf("INSERT "));
-        set.add(upperCaseSql.indexOf("DELETE "));
-        set.remove(-1);
-        if (SqlUtils.isEmpty(set)) {
-            return -1;
-        }
-        List<Integer> list = new ArrayList<>(set);
-        list.sort(Comparator.naturalOrder());
-        return list.get(0);
-    }
-
-    /**
-     * 格式sql
-     *
-     * @param boundSql
-     * @param format
-     * @return
-     */
-    public String sqlFormat(String boundSql, boolean format) {
-        if (format) {
-            try {
-                return SQL_FORMATTER.format(boundSql);
-            } catch (Exception ignored) {
-            }
-        }
-        return boundSql;
-    }
-
-    public boolean isNull(Object object){
-        return null == object;
-    }
-
-    public boolean notNull(Object object){
-        return null != object;
-    }
-
-    public boolean isString(Object object){
-        return (object instanceof String);
-    }
-    /**
-     * 判断字符串是否为空
-     *
-     * @param cs 需要判断字符串
-     * @return 判断结果
-     */
-    public boolean isEmpty(final CharSequence cs) {
-        int strLen;
-        if (isNull(cs) || (strLen = cs.length()) == 0) {
-            return true;
-        }
-        for (int i = 0; i < strLen; i++) {
-            if (!Character.isWhitespace(cs.charAt(i))) {
-                return false;
-            }
-        }
-        return true;
-    }
-
-    /**
-     * 判断字符串是否不为空
-     *
-     * @param cs 需要判断字符串
-     * @return 判断结果
-     */
-    public boolean isNotEmpty(final CharSequence cs) {
-        return !isEmpty(cs);
-    }
-
-    /**
-     * 校验集合是否为空
-     *
-     * @param coll 入参
-     * @return boolean
-     */
-    public boolean isEmpty(Collection<?> coll) {
-        return (coll == null || coll.isEmpty());
-    }
-
-    /**
-     * 判断是不是数字
-     *
-     * @param cs 输入
-     * @return boolean
-     */
-    public boolean isNumeric(final CharSequence cs) {
-        if (isEmpty(cs)) {
-            return false;
-        }
-        Pattern pattern = Pattern.compile("[0-9]+");
-        Matcher isNum = pattern.matcher(cs);
-        return isNum.matches();
-    }
-
-    /**
-     * 判断是否为boolean类型的字符串
-     *
-     * @param cs 输入
-     * @return boolean
-     */
-    public boolean isBoolean(final CharSequence cs) {
-        if (isEmpty(cs)) {
-            return false;
-        }
-        String input = cs.toString().toLowerCase();
-        return "true".equals(input) || "false".equals(input);
-    }
-
-    /**
-     * 获得真正的处理对象,可能多层代理.
-     */
-    @SuppressWarnings("unchecked")
-    public <T> T realTarget(Object target) {
-        if (Proxy.isProxyClass(target.getClass())) {
-            MetaObject metaObject = SystemMetaObject.forObject(target);
-            return realTarget(metaObject.getValue(h_target));
-        }
-        return (T) target;
-    }
-
-    /**
-     * MD5 Base64 加密
-     *
-     * @param str 待加密的字符串
-     * @return 加密后的字符串
-     */
-    public String md5Base64(String str) {
-        //确定计算方法
-        try {
-            MessageDigest md5 = MessageDigest.getInstance(MD5);
-            //加密后的字符串
-            byte[] src = md5.digest(str.getBytes(UTF_8));
-            return Base64.getEncoder().encodeToString(src);
-        } catch (Exception e) {
-            throw new SqlAnalysisException(e);
-        }
-    }
-
-    /**
-     * 获取当前执行 MappedStatement
-     *
-     * @param metaObject 元对象
-     */
-    public MappedStatement getMappedStatement(MetaObject metaObject) {
-        return (MappedStatement) metaObject.getValue(DELEGATE_MAPPED_STATEMENT);
-    }
-
-    public BoundSql getBoundSql(MetaObject metaObject) {
-        return (BoundSql) metaObject.getValue(DELEGATE_BOUND_SQL);
-    }
-}

+ 96 - 68
lift-common/src/main/java/cn.com.ty.lift.common/sql/IllegalSqlInterceptor.java

@@ -40,12 +40,12 @@ import java.util.concurrent.ConcurrentHashMap;
  * <p>SQL是影响系统性能最重要的因素,所以拦截掉垃圾SQL语句</p>
  * <br>
  * <p>拦截SQL类型的场景</p>
- * <p>1.必须使用到索引,包含left jion连接字段,符合索引最左原则</p>
+ * <p>1.必须使用到索引,包含left join连接字段,符合索引最左原则</p>
  * <p>必须使用索引好处,</p>
  * <p>1.1 如果因为动态SQL,bug导致update的where条件没有带上,全表更新上万条数据</p>
  * <p>1.2 如果检查到使用了索引,SQL性能基本不会太差</p>
  * <br>
- * <p>2.SQL尽量单表执行,有查询left jion的语句,必须在注释里面允许该SQL运行,否则会被拦截,有left jion的语句,如果不能拆成单表执行的SQL,请leader商量在做</p>
+ * <p>2.SQL尽量单表执行,有查询left join的语句,必须在注释里面允许该SQL运行,否则会被拦截,有left join的语句,如果不能拆成单表执行的SQL,请leader商量在做</p>
  * <p>https://gaoxianglong.github.io/shark</p>
  * <p>SQL尽量单表执行的好处</p>
  * <p>2.1 查询条件简单、易于开理解和维护;</p>
@@ -73,45 +73,61 @@ public class IllegalSqlInterceptor implements Interceptor {
     /**
      * 缓存表的索引信息
      */
-    private final Map<String, List<IndexInfo>> indexInfoMap     = new ConcurrentHashMap<>();
+    private final Map<String, IndexInfo> indexInfoCacheMap     = new ConcurrentHashMap<>();
 
+    private final int indexValidColumnIndex = 8;
+    private final int indexDBNameColumnIndex = 1;
+    private final int indexTableNameColumnIndex = 3;
+    private final int indexColumnNameColumnIndex = 9;
+    private boolean             failFast = false;
+    private ThreadLocal<String> message  = new ThreadLocal<>();
+
+    private void set(String message) {
+        this.message.set(message);
+    }
     /**
      * 验证expression对象是不是 or、not等等
      *
      * @param expression ignore
      */
-    private void validExpression(Expression expression) {
+    private boolean validExpression(Expression expression) {
         //where条件使用了 or 关键字
         if (expression instanceof OrExpression) {
-            OrExpression orExpression = (OrExpression) expression;
-            throw new SqlAnalysisException("非法SQL,where条件中不能使用[or]关键字,错误or信息:" + orExpression.toString());
+            OrExpression or = (OrExpression) expression;
+            set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", or.getStringExpression(), or.toString()));
+            return true;
         } else if (expression instanceof NotEqualsTo) {
-            NotEqualsTo notEqualsTo = (NotEqualsTo) expression;
-            throw new SqlAnalysisException("非法SQL,where条件中不能使用[!=]关键字,错误!=信息:" + notEqualsTo.toString());
+            NotEqualsTo notEqTo = (NotEqualsTo) expression;
+            set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", notEqTo.getStringExpression(), notEqTo.toString()));
+            return true;
         } else if (expression instanceof BinaryExpression) {
-            BinaryExpression binaryExpression = (BinaryExpression) expression;
-            if (binaryExpression.isNot()) {
-                throw new SqlAnalysisException("非法SQL,where条件中不能使用[not]关键字,错误not信息:" + binaryExpression.toString());
+            BinaryExpression binary = (BinaryExpression) expression;
+            if (binary.isNot()) {
+                set(String.format("Should not use [%s] in WHERE condition. illegal sql: %s", binary.getStringExpression(), binary.toString()));
+                return true;
             }
-            Expression leftExpression = binaryExpression.getLeftExpression();
+            Expression leftExpression = binary.getLeftExpression();
             if (leftExpression instanceof Function) {
                 Function function = (Function) leftExpression;
-                throw new SqlAnalysisException("非法SQL,where条件中不能使用数据库函数,错误函数信息:" + function.toString());
+                set(String.format("Should not use db function in WHERE condition. illegal sql: %s", function.toString()));
+                return true;
             }
-            Expression rightExpression = binaryExpression.getRightExpression();
+            Expression rightExpression = binary.getRightExpression();
             if (rightExpression instanceof SubSelect) {
                 SubSelect subSelect = (SubSelect) rightExpression;
-                throw new SqlAnalysisException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
+                set(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", subSelect.toString()));
+                return true;
             }
         } else if (expression instanceof InExpression) {
-            InExpression inExpression = (InExpression) expression;
-            ItemsList rightItemsList = inExpression.getRightItemsList();
+            InExpression in = (InExpression) expression;
+            ItemsList rightItemsList = in.getRightItemsList();
             if (rightItemsList instanceof SubSelect) {
                 SubSelect subSelect = (SubSelect) rightItemsList;
-                throw new SqlAnalysisException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
+                set(String.format("Should not use subSelect in WHERE condition. illegal sql: %s", subSelect.toString()));
+                return true;
             }
         }
-
+        return false;
     }
 
     /**
@@ -122,8 +138,8 @@ public class IllegalSqlInterceptor implements Interceptor {
      * @param connection ignore
      */
     private void validJoins(List<Join> joins, Table table, Connection connection) {
-        //允许执行join,验证jion是否使用索引等等
-        if (joins != null) {
+        //允许执行join,验证join是否使用索引等等
+        if (SqlUtils.notEmpty(joins)) {
             for (Join join : joins) {
                 Table rightTable = (Table) join.getRightItem();
                 Expression expression = join.getOnExpression();
@@ -140,9 +156,6 @@ public class IllegalSqlInterceptor implements Interceptor {
      * @param connection ignore
      */
     private void validUseIndex(Table table, String columnName, Connection connection) {
-        //是否使用索引
-        boolean useIndexFlag = false;
-
         String tableInfo = table.getName();
         //表存在的索引
         String dbName = null;
@@ -154,13 +167,9 @@ public class IllegalSqlInterceptor implements Interceptor {
             dbName = tableArray[0];
             tableName = tableArray[1];
         }
-        List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
-        for (IndexInfo indexInfo : indexInfos) {
-            if (Objects.equals(columnName, indexInfo.getColumnName())) {
-                useIndexFlag = true;
-                break;
-            }
-        }
+        String cacheKey = String.format("%s::%s::%s", dbName, tableName, columnName);
+        boolean useIndexFlag = getIndexInfos(cacheKey, dbName,tableName, connection);
+
         if (!useIndexFlag) {
             throw new SqlAnalysisException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
         }
@@ -185,13 +194,18 @@ public class IllegalSqlInterceptor implements Interceptor {
      * @param joinTable  ignore
      * @param connection ignore
      */
-    private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
-        validExpression(expression);
+    private boolean validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
+        boolean valid = validExpression(expression);
+        if(valid){
+            return valid;
+        }
         if (expression instanceof BinaryExpression) {
             //获得左边表达式
             Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
-            validExpression(leftExpression);
-
+            boolean left = validExpression(leftExpression);
+            if(left){
+                return left;
+            }
             //如果左边表达式为Column对象,则直接获得列名
             if (leftExpression instanceof Column) {
                 Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
@@ -217,6 +231,7 @@ public class IllegalSqlInterceptor implements Interceptor {
             Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
             validExpression(rightExpression);
         }
+        return true;
     }
 
     /**
@@ -227,7 +242,7 @@ public class IllegalSqlInterceptor implements Interceptor {
      * @param conn      ignore
      * @return ignore
      */
-    public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
+    public boolean getIndexInfos(String dbName, String tableName, Connection conn) {
         return getIndexInfos(null, dbName, tableName, conn);
     }
 
@@ -240,35 +255,46 @@ public class IllegalSqlInterceptor implements Interceptor {
      * @param conn      ignore
      * @return ignore
      */
-    public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
-        List<IndexInfo> indexInfos = null;
-        if (SqlUtils.notEmpty(key)) {
-            indexInfos = indexInfoMap.get(key);
-        }
-        if (SqlUtils.isEmpty(indexInfos)) {
-            ResultSet rs;
+    public boolean getIndexInfos(String key, String dbName, String tableName, Connection conn) {
+        //if the indexInfoCacheMap is empty, must get index info from the connection.
+        if(indexInfoCacheMap.isEmpty()){
             try {
                 DatabaseMetaData metadata = conn.getMetaData();
-                rs = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
-                indexInfos = new ArrayList<>();
-                while (rs.next()) {
+                ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
+                while (resultSet.next()) {
                     //索引中的列序列号等于1,才有效
-                    if (Objects.equals(rs.getString(8), "1")) {
+                    log.info("resultSet: {}", resultSet);
+                    if ("1".equals(resultSet.getString(indexValidColumnIndex))) {
                         IndexInfo indexInfo = new IndexInfo();
-                        indexInfo.setDbName(rs.getString(1));
-                        indexInfo.setTableName(rs.getString(3));
-                        indexInfo.setColumnName(rs.getString(9));
-                        indexInfos.add(indexInfo);
+                        indexInfo.setDbName(resultSet.getString(indexDBNameColumnIndex));
+                        indexInfo.setTableName(resultSet.getString(indexTableNameColumnIndex));
+                        indexInfo.setColumnName(resultSet.getString(indexColumnNameColumnIndex));
+                        //index
+                        indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
                     }
                 }
-                if (SqlUtils.notEmpty(key)) {
-                    indexInfoMap.put(key, indexInfos);
-                }
             } catch (SQLException e) {
-                e.printStackTrace();
+                log.error("get index info from connection fail.");
+            }
+        }
+        try {
+            DatabaseMetaData metadata = conn.getMetaData();
+            ResultSet resultSet = metadata.getIndexInfo(dbName, dbName, tableName, false, true);
+            while (resultSet.next()) {
+                //索引中的列序列号等于1,才有效
+                if (Objects.equals(resultSet.getString(8), "1")) {
+                    IndexInfo indexInfo = new IndexInfo();
+                    indexInfo.setDbName(resultSet.getString(1));
+                    indexInfo.setTableName(resultSet.getString(3));
+                    indexInfo.setColumnName(resultSet.getString(9));
+                    //index
+                    indexInfoCacheMap.put(indexInfo.cacheKey(), indexInfo);
+                }
             }
+        } catch (SQLException e) {
+            log.error("get index info from connection fail.");
         }
-        return indexInfos;
+        return true;
     }
 
     @Override
@@ -282,10 +308,11 @@ public class IllegalSqlInterceptor implements Interceptor {
         }
         BoundSql boundSql = SqlUtils.getBoundSql(metaObject);
         String originalSql = boundSql.getSql();
-        log.info("检查SQL是否合规,SQL: {}", originalSql);
+        log.info("检查SQL ID: {}", mappedStatement.getId());
+        log.info("检查SQL是否合规,SQL: \n{}", originalSql);
         String md5Base64 = SqlUtils.md5Base64(originalSql);
+        log.info("md5Base64: {}", md5Base64);
         if (cacheValidResult.contains(md5Base64)) {
-            log.info("该SQL已验证,无需再次验证, SQL: {}", originalSql);
             return invocation.proceed();
         }
         Statement statement = CCJSqlParserUtil.parse(originalSql);
@@ -334,7 +361,10 @@ public class IllegalSqlInterceptor implements Interceptor {
 
     @Override
     public void setProperties(Properties prop) {
-
+        String failFast = prop.getProperty("failFast");
+        if (SqlUtils.isBoolean(failFast)) {
+            this.failFast = Boolean.valueOf(failFast);
+        }
     }
 
     /**
@@ -342,17 +372,15 @@ public class IllegalSqlInterceptor implements Interceptor {
      */
     @Data
     private class IndexInfo {
-        /**
-         * the name of the db.
-         */
+        //the name of the db.
         private String dbName;
-        /**
-         * the name of the table.
-         */
+        //the name of the table.
         private String tableName;
-        /**
-         * the name of the column.
-         */
+        //the name of the column.
         private String columnName;
+
+        public String cacheKey(){
+            return String.format("%s::%s::%s",getDbName(), getTableName(), getColumnName());
+        }
     }
 }

+ 85 - 62
lift-common/src/main/java/cn.com.ty.lift.common/sql/SqlAnalysisInterceptor.java

@@ -87,72 +87,17 @@ public class SqlAnalysisInterceptor implements Interceptor {
 
     @Override
     public Object intercept(Invocation invocation) throws Throwable {
-        Statement statement;
-        Object firstArg = invocation.getArgs()[0];
-        if (Proxy.isProxyClass(firstArg.getClass())) {
-            statement = (Statement) SystemMetaObject.forObject(firstArg).getValue(h_statement);
-        } else {
-            statement = (Statement) firstArg;
-        }
-        MetaObject stmtMetaObject = SystemMetaObject.forObject(statement);
-        if (stmtMetaObject.hasGetter(stmt)) {
-            try {
-                statement = (Statement) stmtMetaObject.getValue(stmt_statement);
-            } catch (Exception ignored) {
-                // do nothing
-                log.warn("Get stmt.statement from MetaObject fail.");
-            }
-        }
-        if (stmtMetaObject.hasGetter(delegate)) {
-            //Hikari
-            try {
-                statement = (Statement) stmtMetaObject.getValue(delegate);
-            } catch (Exception ignored) {
-                // do nothing
-                log.warn("Get delegate statement from MetaObject fail.");
-            }
-        }
+        // get statement from the invocation.
+        Statement statement = getStatement(invocation);
+        // if statement is null, just return.
         if (SqlUtils.isNull(statement)) {
             return invocation.proceed();
         }
-        String originalSql = null;
-        String stmtClassName = statement.getClass().getName();
-        if (DruidPooledPreparedStatement.equals(stmtClassName)) {
-            try {
-                if (SqlUtils.isNull(druidGetSQLMethod)) {
-                    Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
-                    druidGetSQLMethod = clazz.getMethod(getSql);
-                }
-                druidGetSQLMethod.setAccessible(true);
-                Object stmtSql = druidGetSQLMethod.invoke(statement);
-                if (SqlUtils.isString(stmtSql)) {
-                    originalSql = (String) stmtSql;
-                }
-            } catch (Exception e) {
-                log.error("Get original Sql from DruidPooledPreparedStatement fail.", e);
-            }
-        } else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
-            try {
-                if (SqlUtils.isNull(oracleGetOriginalSqlMethod)) {
-                    Class<?> clazz = Class.forName(stmtClassName);
-                    oracleGetOriginalSqlMethod = getMethodRegular(clazz, getOriginalSql);
-                }
+        // get originalSql from the statement.
+        String originalSql = getOriginalSql(statement);
 
-                if (SqlUtils.notNull(oracleGetOriginalSqlMethod)) {
-                    //OraclePreparedStatementWrapper is not a public class, need set this.
-                    oracleGetOriginalSqlMethod.setAccessible(true);
-                    Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
-                    if (SqlUtils.isString(stmtSql)) {
-                        originalSql = (String) stmtSql;
-                    }
-                }
-            } catch (Exception ignored) {
-                //ignore
-                log.warn("Get original Sql from T4CPreparedStatement | OraclePreparedStatementWrapper fail.");
-            }
-        }
-        if (SqlUtils.isNull(originalSql)) {
-            originalSql = statement.toString();
+        if (SqlUtils.isEmpty(originalSql)) {
+            return invocation.proceed();
         }
         //\s用于匹配空白字符, \\s用于匹配字符串中的\和s,两个字符
         //去掉原始sql中的格式,换行,全部格式化为一行sql
@@ -258,4 +203,82 @@ public class SqlAnalysisInterceptor implements Interceptor {
         return list.get(0);
     }
 
+    /**
+     * get sql {@link Statement} from the {@link Invocation}
+     *
+     * @param invocation the invocation inject into.
+     * @return the sql statement.
+     */
+    private Statement getStatement(Invocation invocation) {
+        Statement statement;
+        Object firstArg = invocation.getArgs()[0];
+        if (Proxy.isProxyClass(firstArg.getClass())) {
+            statement = (Statement) SystemMetaObject.forObject(firstArg).getValue(h_statement);
+        } else {
+            statement = (Statement) firstArg;
+        }
+        MetaObject stmtMetaObject = SystemMetaObject.forObject(statement);
+        if (stmtMetaObject.hasGetter(stmt)) {
+            try {
+                statement = (Statement) stmtMetaObject.getValue(stmt_statement);
+            } catch (Exception ignored) {
+                // do nothing
+            }
+        }
+        if (stmtMetaObject.hasGetter(delegate)) {
+            //Hikari
+            try {
+                statement = (Statement) stmtMetaObject.getValue(delegate);
+            } catch (Exception ignored) {
+                // do nothing
+            }
+        }
+        return statement;
+    }
+
+    /**
+     * get Original Sql with parameter from the sql {@link Statement}
+     *
+     * @param statement the sql statement.
+     * @return the Original Sql string.
+     */
+    private String getOriginalSql(Statement statement) {
+        Object stmtSql = null;
+        String stmtClassName = statement.getClass().getName();
+        if (DruidPooledPreparedStatement.equals(stmtClassName)) {
+            try {
+                if (SqlUtils.isNull(druidGetSQLMethod)) {
+                    Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
+                    druidGetSQLMethod = clazz.getMethod(getSql);
+                }
+                druidGetSQLMethod.setAccessible(true);
+                stmtSql = druidGetSQLMethod.invoke(statement);
+            } catch (Exception e) {
+                log.error("Get original Sql from DruidPooledPreparedStatement fail.", e);
+            }
+        } else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
+            try {
+                if (SqlUtils.isNull(oracleGetOriginalSqlMethod)) {
+                    Class<?> clazz = Class.forName(stmtClassName);
+                    oracleGetOriginalSqlMethod = getMethodRegular(clazz, getOriginalSql);
+                }
+                if (SqlUtils.notNull(oracleGetOriginalSqlMethod)) {
+                    //OraclePreparedStatementWrapper is not a public class, need set this.
+                    oracleGetOriginalSqlMethod.setAccessible(true);
+                    stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
+                }
+            } catch (Exception ignored) {
+                //ignore
+                log.warn("Get original Sql from T4CPreparedStatement | OraclePreparedStatementWrapper fail.");
+            }
+        }
+        String originalSql = null;
+        if (SqlUtils.isString(stmtSql)) {
+            originalSql = (String) stmtSql;
+        }
+        if (SqlUtils.isNull(originalSql)) {
+            originalSql = statement.toString();
+        }
+        return originalSql;
+    }
 }

+ 1 - 1
lift-common/src/main/java/cn.com.ty.lift.common/sql/SqlUtils.java

@@ -74,7 +74,7 @@ public class SqlUtils {
     }
 
     public static boolean isString(Object object){
-        return (object instanceof String);
+        return notNull(object) && (object instanceof String);
     }
 
     /**

+ 1 - 1
lift-enterprise-service/src/main/java/cn/com/ty/lift/enterprise/oa/controller/AttendanceController.java

@@ -186,7 +186,7 @@ public class AttendanceController {
         Integer type = entity.getType();
         //先统计当天是否有对应的打卡记录
         int count = attendanceService.countByUserAndType(mtCompanyId,userId,type);
-        Judge.lt0(count,Judge.Attend.hadClock);
+        Judge.nogt0(count,Judge.Attend.hadClock);
         MaintenanceCompany maintenanceCompany = maintenanceCompanyService.getById(mtCompanyId);
         Judge.notNull(maintenanceCompany);
 

+ 1 - 1
lift-enterprise-service/src/main/java/cn/com/ty/lift/enterprise/oa/controller/LiftCertificateController.java

@@ -106,7 +106,7 @@ public class LiftCertificateController {
         Long mtCompanyId = entity.getMtCompanyId();
 
         int count = liftCertificateService.countByUserAndMtCompany(ownerId, mtCompanyId);
-        Judge.lt0(count, Judge.LiftCert.CertHadExist);
+        Judge.nogt0(count, Judge.LiftCert.CertHadExist);
 
         entity.setStatus(Judge.LiftCert.Status_ToAudit);
 

+ 0 - 1
lift-enterprise-service/src/main/java/cn/com/ty/lift/enterprise/oa/entity/LiftCertificate.java

@@ -79,7 +79,6 @@ public class LiftCertificate extends BaseEntity {
     /**
      * 0:无证,1:待审核,2:审核未通过,3:审核通过,4:超期
      */
-    @NotNull(message = "操作证状态不能为空")
     @Range(max = 4, message = "操作证状态有误")
     private Integer status;
     

+ 0 - 5
lift-system-service/src/main/java/cn/com/ty/lift/system/homepage/dao/dto/request/PlatformCalendarRequest.java

@@ -30,11 +30,6 @@ public class PlatformCalendarRequest {
      */
     private String monthBeginStr;
 
-    /**
-     * 当前时间
-     */
-    private String nowStr;
-
     /**
      * 月末时间 年月日
      */

+ 4 - 2
lift-system-service/src/main/java/cn/com/ty/lift/system/homepage/service/PlatformCalendarService.java

@@ -12,7 +12,6 @@ import org.apache.commons.lang3.StringUtils;
 import org.springframework.stereotype.Service;
 
 import javax.annotation.Resource;
-import java.time.LocalDate;
 import java.util.*;
 import java.util.stream.Collectors;
 
@@ -105,8 +104,11 @@ public class PlatformCalendarService {
         //获取年检记录,并将年检记录转化为 天 -> 年检数据(电梯 -> 年检数据)
         Map<Integer, Map<Long, List<AnnualInspectionDataModel>>> dayToAnnualInspection = getDayToAnnualInspection(
                 platformCalendarRequest);
+        //获取请求日期的天数
+        int monthDayNum = CommonUtil.getLastDayNumOfMonthFromDateStr(
+                platformCalendarRequest.getRequestDateDayStr(), CommonConstants.PlatformCalendarConstants.REQUEST_DATE_FORMAT);
         //循环设置每日的数据
-        for (int day = 1; day <= LocalDate.now().getDayOfMonth(); day++) {
+        for (int day = 1; day <= monthDayNum; day++) {
             addCalendarLift(dayToCalendarLift, dayToMtRecord, dayToMtPlan, dayToEmergencyRecord, dayToAnnualInspection, day);
         }
         return dayToCalendarLift;

+ 81 - 35
lift-system-service/src/main/java/cn/com/ty/lift/system/user/service/impl/LoginService.java

@@ -144,30 +144,17 @@ public class LoginService implements ILoginService {
 
     @Override
     public RestResponse changeTeam(HttpServletRequest request, UserRequest userRequest) {
-        Long companyId = userRequest.getCompanyId();
         UserResponse userResponse = new UserResponse();
-        //校验团队信息
-        MaintenanceCompany maintenanceCompany = maintenanceCompanyService.getById(companyId);
-        if (maintenanceCompany == null) {
-            return RestResponse.success(ApiConstants.RESULT_SUCCESS, "切换团队不存在");
+        //设置token信息
+        userResponse.setToken(userRequest.getToken());
+        //校验用户信息,并设置用户信息
+        RestResponse restResponse = checkAndSetChangeInfo(userRequest,userResponse);
+        //如果校验信息不为空返回校验信息
+        if (restResponse != null) {
+            return restResponse;
         }
-        maintenanceCompany.setCurrentTeamFlag(true);
-        userResponse.setMaintenanceCompany(maintenanceCompany);
-        UserRole userRole = userRoleService.getOne(new QueryWrapper<UserRole>()
-                .eq("company_id", companyId)
-                .eq("user_id", userRequest.getUserId())
-        );
-        if (userRole != null) {
-            Role role = roleService.getById(userRole.getRoleId());
-            userResponse.setRole(role);
-            //获取菜单信息
-            if (role != null) {
-                //设置菜单树
-                List<Menu> menuTree = roleMenuService.getMenuTree(role.getId());
-                userResponse.setMenus(menuTree);
-            }
-        }
-        //更新redis中数据值
+        //更新redis中的信息数据
+        updateUserInfoInRedis(userResponse);
         return RestResponse.success(userResponse, ApiConstants.RESULT_SUCCESS, "切换团队成功");
     }
 
@@ -189,6 +176,19 @@ public class LoginService implements ILoginService {
         return RestResponse.success(userResponse, ApiConstants.RESULT_SUCCESS, "用户登录成功");
     }
 
+    @Override
+    public RestResponse verifySmsCode(String mobile, String inputSmsCode) {
+        Object smsCode = redisTemplate.opsForValue().get(mobile + AliConstants.SmsConstants.SMS_CODE_FIELD);
+        if (smsCode != null) {
+            if (!(smsCode.equals(inputSmsCode))) {
+                return RestResponse.fail(ApiConstants.RESULT_ERROR, "短信验证码输入有误");
+            }
+        } else {
+            return RestResponse.fail(ApiConstants.RESULT_ERROR, "验证码过期,请重新获取验证码");
+        }
+        return null;
+    }
+
     /**
      * @param
      * @return
@@ -212,19 +212,6 @@ public class LoginService implements ILoginService {
         return userResponse;
     }
 
-    @Override
-    public RestResponse verifySmsCode(String mobile, String inputSmsCode) {
-        Object smsCode = redisTemplate.opsForValue().get(mobile + AliConstants.SmsConstants.SMS_CODE_FIELD);
-        if (smsCode != null) {
-            if (!(smsCode.equals(inputSmsCode))) {
-                return RestResponse.fail(ApiConstants.RESULT_ERROR, "短信验证码输入有误");
-            }
-        } else {
-            return RestResponse.fail(ApiConstants.RESULT_ERROR, "验证码过期,请重新获取验证码");
-        }
-        return null;
-    }
-
     /**
      * @param userResponse 用户信息
      * @return map 用户信息map
@@ -258,4 +245,63 @@ public class LoginService implements ILoginService {
         currentUserInfoMap.put(ApiConstants.CURRENT_PERMISSION_URL, permissionUrl);
         return currentUserInfoMap;
     }
+
+    /**
+     * @param userRequest  请求参数
+     * @param userResponse 用户信息
+     * @return
+     * @description 检验并设置更换团队中角色的信息
+     * @date 2020/2/26 1:53 下午
+     */
+    private RestResponse checkAndSetChangeInfo(UserRequest userRequest, UserResponse userResponse) {
+        //校验团队信息
+        MaintenanceCompany maintenanceCompany = maintenanceCompanyService.getById(userRequest.getCompanyId());
+        if (maintenanceCompany == null) {
+            return RestResponse.success(ApiConstants.RESULT_SUCCESS, "要切换的团队不存在");
+        }
+        maintenanceCompany.setCurrentTeamFlag(true);
+        userResponse.setMaintenanceCompany(maintenanceCompany);
+        //获取用户在切换团队的角色信息
+        UserRole userRole = userRoleService.getOne(new QueryWrapper<UserRole>()
+                .eq("company_id", userRequest.getCompanyId())
+                .eq("user_id", userRequest.getUserId())
+        );
+        //校验角色信息
+        if (userRole == null) {
+            return RestResponse.success(ApiConstants.RESULT_ERROR, "用户在要切换的团队没有角色,无法切换");
+        }
+        Role role = roleService.getById(userRole.getRoleId());
+        if (role == null) {
+            return RestResponse.success(ApiConstants.RESULT_ERROR, "用户在要切换的团队没有角色,无法切换");
+        }
+        userResponse.setRole(role);
+        //获取角色对应的菜单信息
+        List<Menu> menuTree = roleMenuService.getMenuTree(role.getId());
+        //校验菜单信息
+        if (menuTree == null || menuTree.size() <= 0) {
+            return RestResponse.success(ApiConstants.RESULT_ERROR, "切换团队中用户没有任何菜单,无法切换");
+        }
+        userResponse.setMenus(menuTree);
+        return null;
+    }
+
+    /**
+     * @param
+     * @return
+     * @description 更新用户在redis中的信息
+     * @date 2020/2/26 12:52 下午
+     */
+    private void updateUserInfoInRedis(UserResponse userResponse) {
+        //从redis中获取用户信息
+        String mobile = (String) redisTemplate.opsForValue().get(userResponse.getToken());
+        Map<String, Object> userInfoMap = JSONUtil.parseObj(redisTemplate.opsForValue().get(mobile));
+        //设置用户角色id
+        userInfoMap.put(ApiConstants.CURRENT_ROLE_ID, userResponse.getRole().getId());
+        //设置用户公司id
+        userInfoMap.put(ApiConstants.CURRENT_COMPANY_ID, userResponse.getMaintenanceCompany().getId());
+        //设置用户菜单信息
+        userInfoMap.put(ApiConstants.CURRENT_PERMISSION_URL, menuService.menuToUrl(userResponse.getMenus()));
+        //重新将用户信息放入到redis中
+        redisTemplate.opsForValue().set(mobile, JSONUtil.toJsonPrettyStr(userInfoMap), 30L, TimeUnit.MINUTES);
+    }
 }

+ 4 - 2
lift-system-service/src/main/java/cn/com/ty/lift/system/user/service/impl/MtCompanyUserService.java

@@ -269,9 +269,11 @@ public class MtCompanyUserService extends ServiceImpl<MtCompanyUserMapper, MtCom
             List<UserInfo> userInfoList = (List<UserInfo>) userInfoService.listByIds(userIdList);
             Map<Long, UserInfo> userIdToUserInfo = ProjectUtils.attrToObjMap(userInfoList, "userId", null);
             //获取团队用户角色信息
-            Map<Long, Role> userIdToRole = roleService.getUserIdToRoleByUserIdsAndCompanyId(userIdList, mtCompanyUserRequest.getCompanyId());
+            Map<Long, Role> userIdToRole = roleService.getUserIdToRoleByUserIdsAndCompanyId(userIdList,
+                    mtCompanyUserRequest.getCompanyId());
             //获取团队用户操作证信息
-            Map<Long, LiftCertificate> userIdToLiftCertificate = projectService.getUserIdToLiftCertificateByUserIdListAndCompanyId(userIdList, mtCompanyUserRequest.getCompanyId());
+            Map<Long, LiftCertificate> userIdToLiftCertificate = projectService
+                    .getUserIdToLiftCertificateByUserIdListAndCompanyId(userIdList, mtCompanyUserRequest.getCompanyId());
             for (MtCompanyUser mtCompanyUser : mtCompanyUserList) {
                 AppCompanyUserResponse appCompanyUserResponse = new AppCompanyUserResponse();
                 //设置用户信息

+ 29 - 5
lift-system-service/src/main/java/cn/com/ty/lift/system/utils/CommonUtil.java

@@ -94,7 +94,7 @@ public class CommonUtil {
     public static void setMonthDate(PlatformCalendarRequest platformCalendarRequest) {
         String dateStr = platformCalendarRequest.getRequestDateStr();
         //如果请求时间为空,就设置带天的请求时间
-        if(StringUtils.isBlank(dateStr)) {
+        if (StringUtils.isBlank(dateStr)) {
             dateStr = platformCalendarRequest.getRequestDateDayStr();
         }
         Map<String, String> dateTypeToDateValue = CommonUtil.getMonthDate(dateStr,
@@ -102,9 +102,6 @@ public class CommonUtil {
         //设置月初时间
         platformCalendarRequest.setMonthBeginStr(dateTypeToDateValue.get(
                 CommonConstants.PlatformCalendarConstants.MONTH_BEGIN_DATE));
-        //设置当期时间
-        platformCalendarRequest.setNowStr(LocalDate.now().format(DateTimeFormatter.ofPattern(
-                CommonConstants.PlatformCalendarConstants.TRANS_DATE_FORMAT)));
         //设置月末时间
         platformCalendarRequest.setMonthEndStr(dateTypeToDateValue.get(
                 CommonConstants.PlatformCalendarConstants.MONTH_END_DATE));
@@ -121,7 +118,8 @@ public class CommonUtil {
         LocalDate transDate = LocalDate.parse(timeDateStr, DateTimeFormatter.ofPattern(timeFormat));
         Map<Integer, Map<Integer, Long>> dayToLiftStatusToNum = new HashMap<>();
         if (transDate != null) {
-            for (int monthDay = 1; monthDay <= transDate.getMonthValue(); monthDay++) {
+            int dayNum = getLastDayNumOfMonth(transDate);
+            for (int monthDay = 1; monthDay <= dayNum; monthDay++) {
                 Map<Integer, Long> liftStatusToNum = new HashMap<>();
                 for (int liftStatus : CommonConstants.PlatformCalendarConstants.LIFT_STATUS_ARRAY) {
                     liftStatusToNum.put(liftStatus, 0L);
@@ -205,4 +203,30 @@ public class CommonUtil {
         return dayToDataMap;
     }
 
+    /**
+     * @param
+     * @return
+     * @description 获取某月的天数
+     */
+    public static int getLastDayNumOfMonth(LocalDate date) {
+        if (date != null) {
+            return date.with(TemporalAdjusters.lastDayOfMonth()).getDayOfMonth();
+        }
+        return 0;
+    }
+
+    /**
+     * @param
+     * @return
+     * @description 获取给点日期字符串指定格式的月的天数
+     * @date 2020/2/28 2:57 下午
+     */
+    public static int getLastDayNumOfMonthFromDateStr(String dateStr, String format) {
+        if (StringUtils.isNotBlank(dateStr)) {
+            LocalDate date = LocalDate.parse(dateStr, DateTimeFormatter.ofPattern(format));
+            return getLastDayNumOfMonth(date);
+        }
+        return 0;
+    }
+
 }