接上一篇:mybatis Interceptor拦截器实现自定义扩展查询兼容mybatis plus-左搜 (leftso.com)

本文在此基础上实现自定义分页查询扩展。示例基于 MyBatis-Plus,同样适用于 MyBatis。

MyBatis(Plus)自定义分页拦截器:
@Intercepts({
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class,Object.class, RowBounds.class, ResultHandler.class}
        )
})
@Slf4j
public class MybatisPageSelectFilterInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {

        //这个可以得到当前执行的sql语句在xml文件中配置的id的值
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        Object parameter = invocation.getArgs()[1];
        if (parameter instanceof Paging){
            PageData<?> page=new PageData<>();

            BoundSql boundSql = mappedStatement.getBoundSql(parameter);
            Paging<?> paging =(Paging<?>)boundSql.getParameterObject();

            page.setCurrent(paging.getPageNum());
            page.setSize(paging.getPageSize());

            Connection connection = mappedStatement.getConfiguration().getEnvironment().getDataSource().getConnection();

            long count=count(connection,boundSql,paging.getSelectFilter());
            page.setTotal(count);
            page.setPages(count%paging.getPageSize()==0?(count/ paging.getPageSize()):(count/ paging.getPageSize())+1);
            if (paging.getPageNum()<=page.getPages()&&paging.getPageNum()>0){
                //先处理统计
                dealSelectFilter(mappedStatement,parameter,paging.getSelectFilter(),invocation,paging);
                // 继续执行
                List result = (List)invocation.proceed();
                page.setRecords(result);
            }else{
                page.setRecords(new ArrayList<>());
            }
            return Arrays.asList(page);
        }
        // 继续执行
        Object result = invocation.proceed();
        return result;
    }

    public int count(Connection connection,BoundSql boundSql, SelectFilter selectFilter) {

        Class entityClass = selectFilter.getEntityClass();
        if (Objects.nonNull(entityClass)){
            Field[] declaredFields = entityClass.getDeclaredFields();
            for (Field declaredField : declaredFields) {
                TableField tableField=(TableField) declaredField.getAnnotation(TableField.class);
                if (Objects.nonNull(tableField)){
                    String dbName=StringUtils.isEmpty(tableField.value())?declaredField.getName():tableField.value();
                    if (Objects.nonNull(tableField)){
                        TableLogic tableLogic=(TableLogic) declaredField.getAnnotation(TableLogic.class);
                        if (Objects.nonNull(tableLogic)){
                            selectFilter.buildWithWhere();
                            String sql = selectFilter.getSql();
                            if (!sql.contains(dbName)){
                                selectFilter.eq(dbName,0);
                            }
                        }
                    }
                }
            }
        }

        PreparedStatement countStmt = null;
        ResultSet rs = null;
        selectFilter.buildWithWhere();
        String sql = selectFilter.getSql();
        String countSql=boundSql.getSql().replaceAll("select.*from","select count(0) from ")+sql;
        log.info(countSql);
        log.info(Arrays.toString(selectFilter.getSqlParamsMap().values().toArray()));
        try {
            countStmt = connection.prepareStatement(countSql);
            Map sqlParamsMap = selectFilter.getSqlParamsMap();
            if (!CollectionUtils.isEmpty(sqlParamsMap)){
                int index = 1;//从1开始赋值
                for (Object value : sqlParamsMap.values()) {
                    countStmt.setObject(index,value);
                    index++;
                }
            }
            rs = countStmt.executeQuery();
            if (rs.next()) {
                return rs.getInt(1);
            }
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                if (null != countStmt) {
                    countStmt.close();
                }
                if (null != rs) {
                    rs.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        return 0;
    }


    private void dealSelectFilter(MappedStatement mappedStatement,Object parameter,SelectFilter selectFilter,Invocation invocation,Paging paging){
        BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        String oldsql = boundSql.getSql();
        log.info("old:"+oldsql);

        Class entityClass = selectFilter.getEntityClass();
        if (Objects.nonNull(entityClass)){
            TableName tableNameAnnotation = (TableName) entityClass.getDeclaredAnnotation(TableName.class);
            String tableName = tableNameAnnotation.value();
            Matcher matcher = Pattern.compile("^select\\s+\\*\\s+from\\s+" + tableName).matcher(oldsql.toLowerCase());
            if (matcher.find()){
                StringBuffer sqlBuilder=new StringBuffer();
                Field[] declaredFields = entityClass.getDeclaredFields();
                sqlBuilder.append("select ");
                List<String> fields=new ArrayList<>();

                for (Field declaredField : declaredFields) {
                    String fieldName = declaredField.getName();
                    TableField tableField = declaredField.getAnnotation(TableField.class);
                    if (Objects.nonNull(tableField)&&tableField.exist()){
                        String name = declaredField.getName();
                        String dbName=StringUtils.isEmpty(tableField.value())?fieldName:tableField.value();
                        fields.add(dbName+" "+ name);
                    }
                    TableId tableId = declaredField.getAnnotation(TableId.class);
                    if (Objects.nonNull(tableId)){
                        fields.add(StringUtils.isEmpty(tableId.value())?fieldName: tableId.value()+" "+declaredField.getName());
                    }

                }
                String join = String.join(",", fields);
                sqlBuilder.append(join);
                sqlBuilder.append(" from ").append(tableName).append(" ");
                oldsql=sqlBuilder.toString();

                selectFilter.buildWithWhere();
            }
        }

        selectFilter.buildWithWhere();

        List<ParameterMapping> parameterMappingList=new ArrayList<>();

        for (Object key : selectFilter.getSqlParamsMap().keySet()) {
            parameterMappingList.add(new ParameterMapping.Builder(mappedStatement.getConfiguration(),"selectFilter.sqlParamsMap."+key,Object.class).build());
        }

        String dealSql=oldsql+selectFilter.getSql()+" limit "+(paging.getPageNum()-1)* paging.getPageSize()+","+paging.getPageSize();

        BoundSql newBoundSql = new BoundSql(mappedStatement.getConfiguration(), dealSql,
                parameterMappingList, boundSql.getParameterObject());
        MappedStatement newMs = copyFromMappedStatement(mappedStatement, new BoundSqlSqlSource(newBoundSql),selectFilter.getEntityClass());
        invocation.getArgs()[0] = newMs;
    }


    /**
     * 复制原始MappedStatement
     * @param ms
     * @param newSqlSource
     * @return
     */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource,Class<?> entityClass) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource,
                ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null) {
            for (String keyProperty : ms.getKeyProperties()) {
                builder.keyProperty(keyProperty);
            }
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());

        //设置返回列表类型为实体对象类型
        ResultMap resultMap=new ResultMap.Builder(ms.getConfiguration(),ms.getId(), entityClass,new ArrayList<>()).build();
        List<ResultMap> resultMaps=new ArrayList<>();
        resultMaps.add(resultMap);
        builder.resultMaps(resultMaps);

        builder.cache(ms.getCache());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    public static class BoundSqlSqlSource implements SqlSource {
        BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }
}

拦截器注册:
@Configuration
public class MybatisSelectFilterConfiguration {
    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    /**
     * Registers the select-filter interceptors on every SqlSessionFactory in the context.
     *
     * <p>One list interceptor and one page interceptor are created and shared by all factories.
     * On each factory, the list interceptor is added first and the page interceptor second.
     * NOTE(review): MyBatis wraps plugins in registration order; confirm the resulting nesting is
     * the intended one.
     */
    @PostConstruct
    public void addInterceptor() {
        MybatisListSelectFilterInterceptor sharedListInterceptor = new MybatisListSelectFilterInterceptor();
        MybatisPageSelectFilterInterceptor sharedPageInterceptor = new MybatisPageSelectFilterInterceptor();
        for (SqlSessionFactory factory : sqlSessionFactoryList) {
            factory.getConfiguration().addInterceptor(sharedListInterceptor);
            factory.getConfiguration().addInterceptor(sharedPageInterceptor);
        }
    }
}


mapper 代码:
    /**
     * Paged query over table cy_xlxw.
     *
     * <p>The argument is a Paging, and MybatisPageSelectFilterInterceptor acts on
     * {@code parameter instanceof Paging}. That interceptor returns a one-element list holding a
     * PageData. NOTE(review): this assumes MyBatis passes the lone, unannotated parameter
     * unwrapped; confirm.
     *
     * <p>Keep the SQL in the exact {@code select * from <table>} form. The interceptor only expands
     * the column list from the entity annotations when the SQL starts that way.
     */
    @Select("select * from cy_xlxw")
    PageData<Xlxw> page(Paging<Xlxw> paging);


 

评论区域