0


MyBatis的动态拦截sql并修改

需求

因工作需求,需要根据用户的数据权限来查询并展示相应的数据,因此需要动态拦截 SQL,再根据用户权限做相应的处理。为此需要一个通用拦截器,并以注解的方式启用。本文只处理查询拦截,如有其他需求,可根据实际情况做相应更改。


步骤一

该注解是方法级,因此需要注解在dao层方法上,如有需要也可更改为类级
注解:

/**
 * Marks a mapper (DAO) query method whose results must be restricted by the current
 * user's data permissions.
 *
 * <p>The annotation is retained at runtime so the MyBatis interceptor can detect it via
 * reflection. It is method-level; widen {@link Target} if class-level use is needed.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Documented
public @interface Permission {
}

步骤二

定义拦截器:实现 MyBatis 的 Interceptor 接口,并重写其 intercept 方法。

@Intercepts({//        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})//        @Signature( type = Executor.class, method = "update",args = {MappedStatement.class, Object.class}),@Signature(type =Executor.class, method ="query",args ={MappedStatement.class,Object.class,RowBounds.class,ResultHandler.class}),@Signature(type =Executor.class, method ="query",args ={MappedStatement.class,Object.class,RowBounds.class,ResultHandler.class,CacheKey.class,BoundSql.class})})@ComponentpublicclassPermissionInterceptorimplementsInterceptor{@OverridepublicObjectintercept(Invocation invocation)throwsThrowable{}}

步骤三

拦截所有查询请求,获取其 SQL,并用 JSqlParser 解析得到相应的 Statement。

@Intercepts({//        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})//        @Signature( type = Executor.class, method = "update",args = {MappedStatement.class, Object.class}),@Signature(type =Executor.class, method ="query",args ={MappedStatement.class,Object.class,RowBounds.class,ResultHandler.class}),@Signature(type =Executor.class, method ="query",args ={MappedStatement.class,Object.class,RowBounds.class,ResultHandler.class,CacheKey.class,BoundSql.class})})@ComponentpublicclassPermissionInterceptorimplementsInterceptor{@OverridepublicObjectintercept(Invocation invocation)throwsThrowable{String processSql =ExecutorPluginUtils.getSqlByInvocation(invocation);// 执行自定义修改sql操作// 获取sqlString sql2Reset = processSql;Statement statement =CCJSqlParserUtil.parse(processSql);MappedStatement mappedStatement =(MappedStatement) invocation.getArgs()[0];}}

步骤四

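如果后端未使用分页插件(如 PageHelper),则可省略这一步;否则需在项目启动类中完成以下配置,以保证本拦截器先于 PageHelper 执行。注意:如果项目使用了 mybatis-spring-boot-starter,标注了 @Component 的拦截器可能已被自动注册,此处再次添加可能导致它执行两次,请根据实际配置确认。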

//得到spring上下文ConfigurableApplicationContext run =SpringApplication.run(Application.class, args);Interceptor permissionInterceptor =(Interceptor) run.getBean("permissionInterceptor");//这种方式添加mybatis拦截器保证在pageHelper前执行
         run.getBean(SqlSessionFactory.class).getConfiguration().addInterceptor(permissionInterceptor);

步骤五

工具类

packagecom.ydy.common.utils;importcom.ydy.common.annotation.Permission;importorg.apache.ibatis.mapping.BoundSql;importorg.apache.ibatis.mapping.MappedStatement;importorg.apache.ibatis.mapping.SqlCommandType;importorg.apache.ibatis.mapping.SqlSource;importorg.apache.ibatis.plugin.Invocation;importorg.apache.ibatis.reflection.DefaultReflectorFactory;importorg.apache.ibatis.reflection.MetaObject;importorg.apache.ibatis.reflection.factory.DefaultObjectFactory;importorg.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;importjava.lang.reflect.Method;importjava.lang.reflect.Type;importjava.sql.SQLException;importjava.util.Arrays;importjava.util.Objects;publicclassExecutorPluginUtils{/**
     * 获取sql语句
     * @param invocation
     * @return
     */publicstaticStringgetSqlByInvocation(Invocation invocation){finalObject[] args = invocation.getArgs();MappedStatement ms =(MappedStatement) args[0];Object parameterObject = args[1];BoundSql boundSql = ms.getBoundSql(parameterObject);return boundSql.getSql();}/**
     * 包装sql后,重置到invocation中
     * @param invocation
     * @param sql
     * @throws SQLException
     */publicstaticvoidresetSql2Invocation(Invocation invocation,String sql)throwsSQLException{finalObject[] args = invocation.getArgs();MappedStatement statement =(MappedStatement) args[0];Object parameterObject = args[1];BoundSql boundSql = statement.getBoundSql(parameterObject);MappedStatement newStatement =newMappedStatement(statement,newBoundSqlSqlSource(boundSql));MetaObject msObject =MetaObject.forObject(newStatement,newDefaultObjectFactory(),newDefaultObjectWrapperFactory(),newDefaultReflectorFactory());
        msObject.setValue("sqlSource.boundSql.sql", sql);
        args[0]= newStatement;}privatestaticMappedStatementnewMappedStatement(MappedStatement ms,SqlSource newSqlSource){MappedStatement.Builder builder =newMappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());if(ms.getKeyProperties()!=null&& ms.getKeyProperties().length !=0){StringBuilder keyProperties =newStringBuilder();for(String keyProperty : ms.getKeyProperties()){
                keyProperties.append(keyProperty).append(",");}
            keyProperties.delete(keyProperties.length()-1, keyProperties.length());
            builder.keyProperty(keyProperties.toString());}
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());return builder.build();}/**
     * 是否标记为区域字段
     * @return
     */publicstaticbooleanisAreaTag(MappedStatement mappedStatement)throwsClassNotFoundException{String id = mappedStatement.getId();//获取类名String className = id.substring(0, id.lastIndexOf("."));Class clazz =Class.forName(className);//获取方法名String methodName = id.substring(id.lastIndexOf(".")+1);//这里是博主工作需求,防止pagehelper那里未生效if(methodName.contains("_COUNT")){
            methodName=methodName.replace("_COUNT","");}String m=methodName;Class<?> classType =Class.forName(id.substring(0,mappedStatement.getId().lastIndexOf(".")));//获取对应拦截方法名String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".")+1);//这里是博主工作需求,防止pagehelper那里未生效if(mName.contains("_COUNT")){
            mName=mName.replace("_COUNT","");}boolean ignore =false;//获取该类(接口)的所有方法,如果你查询的方法就写在该类,就不需要下面的if判断Method[] declaredMethods = classType.getDeclaredMethods();Method declaredMethod =Arrays.stream(declaredMethods).filter(it -> it.getName().equals(m)).findFirst().orElse(null);//该判断是拿到该接口的超类的方法,博主的查询方法就在超类里,因此需要利用下面代码来获取对应方法if(declaredMethod ==null){Type[] genericInterfaces = clazz.getGenericInterfaces();
            declaredMethod =Arrays.stream(genericInterfaces).map(e ->{Method[] declaredMethods1 =((Class) e).getDeclaredMethods();returnArrays.stream(declaredMethods1).filter(it -> it.getName().equals(m)).findFirst().orElse(null);}).filter(Objects::nonNull).findFirst().orElse(null);}if(declaredMethod!=null){//查询方法是否被permission标记注解
            ignore = declaredMethod.isAnnotationPresent(Permission.class);}return ignore;}/**
     * 是否标记为区域字段
     * @return
     */publicstaticbooleanisAreaTagIngore(MappedStatement mappedStatement)throwsClassNotFoundException{String id = mappedStatement.getId();String className = id.substring(0, id.lastIndexOf("."));Class clazz =Class.forName(className);String methodName = id.substring(id.lastIndexOf(".")+1);Class<?> classType =Class.forName(id.substring(0,mappedStatement.getId().lastIndexOf(".")));//获取对应拦截方法名String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".")+1);boolean ignore =false;Method[] declaredMethods = classType.getDeclaredMethods();Method declaredMethod =Arrays.stream(declaredMethods).filter(it -> it.getName().equals(methodName)).findFirst().orElse(null);if(declaredMethod ==null){Type[] genericInterfaces = clazz.getGenericInterfaces();
            declaredMethod =Arrays.stream(genericInterfaces).map(e ->{Method[] declaredMethods1 =((Class) e).getDeclaredMethods();returnArrays.stream(declaredMethods1).filter(it -> it.getName().equals(methodName)).findFirst().orElse(null);}).filter(Objects::nonNull).findFirst().orElse(null);}
        ignore = declaredMethod.isAnnotationPresent(Permission.class);return ignore;}publicstaticStringgetOperateType(Invocation invocation){finalObject[] args = invocation.getArgs();MappedStatement ms =(MappedStatement) args[0];SqlCommandType commondType = ms.getSqlCommandType();if(commondType.compareTo(SqlCommandType.SELECT)==0){return"select";}returnnull;}//    定义一个内部辅助类,作用是包装sqstaticclassBoundSqlSqlSourceimplementsSqlSource{privateBoundSql boundSql;publicBoundSqlSqlSource(BoundSql boundSql){this.boundSql = boundSql;}@OverridepublicBoundSqlgetBoundSql(Object parameterObject){return boundSql;}}}

步骤六

如果方法被 @Permission 注解标记,则进入 if 分支:查询当前用户的数据权限,拼接 SQL 条件并替换原 SQL;否则直接放行。

if(ExecutorPluginUtils.isAreaTag(mappedStatement)){//获取该用户所具有的角色的数据权限dataScope//因数据敏感省略//获取该用户的所在公司或部门下的所有人//例如 StringBuffer orgBuffer = new StringBuffer();// orgBuffer.append("(");//String collect = allUserByOrgs.stream().map(String::valueOf).collect(Collectors.joining(","));//orgBuffer.append(collect).append(")");//String orgsUser = orgBuffer.toString();try{if(statement instanceofSelect){Select selectStatement =(Select) statement;//其中的PlainSelect 可以拿到sql语句的全部节点信息,具体各位可以看源码PlainSelect plain =(PlainSelect) selectStatement.getSelectBody();//获取所有外连接List<Join> joins = plain.getJoins();//获取到原始sql语句String sql = processSql;StringBuffer whereSql =newStringBuffer();switch(dataScope){//这里dataScope  范围 1 所有数据权限  2 本人  3,部门及分部门(递归)  4.公司及分公司(递归)//所有数据权限作用在人上,因此sql用 in case1:
                            whereSql.append("1=1");break;case2:for(Join join : joins){Table rightItem =(Table) join.getRightItem();//匹配表名if(rightItem.getName().equals("sec_user")){//获取别名if(rightItem.getAlias()!=null){
                                        whereSql.append(rightItem.getAlias().getName()).append(".id = ").append(SecurityUtils.getLoginUser().getId());}else{
                                        whereSql.append("id = ").append(deptsUser);}}}break;case3:for(Join join : joins){Table rightItem =(Table) join.getRightItem();if(rightItem.getName().equals("sec_user")){if(rightItem.getAlias()!=null){
                                        whereSql.append(rightItem.getAlias().getName()).append(".id in ").append(deptsUser);}else{
                                        whereSql.append("id in ").append(deptsUser);}}}break;case4:for(Join join : joins){Table rightItem =(Table) join.getRightItem();if(rightItem.getName().equals("sec_user")){if(rightItem.getAlias()!=null){
                                        whereSql.append(rightItem.getAlias().getName()).append(".id in ").append(orgsUser);}else{
                                        whereSql.append("id in ").append(deptsUser);}}}break;}//获取where节点Expression where = plain.getWhere();if(where ==null){if(whereSql.length()>0){Expression expression =CCJSqlParserUtil.parseCondExpression(whereSql.toString());Expression whereExpression =(Expression) expression;
                            plain.setWhere(whereExpression);}}else{if(whereSql.length()>0){//where条件之前存在,需要重新进行拼接
                            whereSql.append(" and ( "+ where.toString()+" )");}else{//新增片段不存在,使用之前的sql
                            whereSql.append(where.toString());}Expression expression =CCJSqlParserUtil.parseCondExpression(whereSql.toString());
                        plain.setWhere(expression);}
                    sql2Reset = selectStatement.toString();}}catch(Exception e){
                e.printStackTrace();}}// 替换sqlExecutorPluginUtils.resetSql2Invocation(invocation, sql2Reset);//放行Object proceed = invocation.proceed();return proceed;
标签: mybatis sql java

本文转载自: https://blog.csdn.net/ewcc_ycl/article/details/131001517
版权归原作者 元芳,你怎么看 所有, 如有侵权,请联系我们删除。

“MyBatis的动态拦截sql并修改”的评论:

还没有评论