MyBatis Interceptor 拦截器实现自定义扩展查询(兼容 MyBatis-Plus)
@Intercepts({
@Signature(type = Executor.class,method = "query",args = {MappedStatement.class,Object.class, RowBounds.class, ResultHandler.class})
})
@Slf4j
public class MybatisListSelectFilterInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
// 获取原始sql语句
MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
Object parameter = invocation.getArgs()[1];
if (parameter instanceof SelectFilter){
BoundSql boundSql = mappedStatement.getBoundSql(parameter);
SelectFilter selectFilter =(SelectFilter)boundSql.getParameterObject();
dealSelectFilter(mappedStatement,parameter,selectFilter,invocation);
}
// 继续执行
Object result = invocation.proceed();
return result;
}
private void dealSelectFilter(MappedStatement mappedStatement,Object parameter,SelectFilter selectFilter,Invocation invocation){
BoundSql boundSql = mappedStatement.getBoundSql(parameter);
String oldsql = boundSql.getSql();
log.info("old:"+oldsql);
Class entityClass = selectFilter.getEntityClass();
TableName tableNameAnnotation = (TableName) entityClass.getDeclaredAnnotation(TableName.class);
String tableName = tableNameAnnotation.value();
Matcher matcher = Pattern.compile("^select\\s+\\*\\s+from\\s+" + tableName).matcher(oldsql.toLowerCase());
if (matcher.find()){
StringBuffer sqlBuilder=new StringBuffer();
Field[] declaredFields = entityClass.getDeclaredFields();
sqlBuilder.append("select ");
List<String> fields=new ArrayList<>();
for (Field declaredField : declaredFields) {
TableField tableField = declaredField.getAnnotation(TableField.class);
if (Objects.nonNull(tableField)&&tableField.exist()){
String name = declaredField.getName();
fields.add(tableField.value()+" "+name);
}
}
String join = String.join(",", fields);
sqlBuilder.append(join);
sqlBuilder.append(" from ").append(tableName).append(" ");
oldsql=sqlBuilder.toString();
}
selectFilter.buildWithWhere();
List<ParameterMapping> parameterMappingList=new ArrayList<>();
for (Object key : selectFilter.getSqlParamsMap().keySet()) {
parameterMappingList.add(new ParameterMapping.Builder(mappedStatement.getConfiguration(),"sqlParamsMap."+key,Object.class).build());
}
BoundSql newBoundSql = new BoundSql(mappedStatement.getConfiguration(), oldsql + selectFilter.getSql(),
parameterMappingList, boundSql.getParameterObject());
MappedStatement newMs = copyFromMappedStatement(mappedStatement, new BoundSqlSqlSource(newBoundSql));
invocation.getArgs()[0] = newMs;
}
/**
* 复制原始MappedStatement
* @param ms
* @param newSqlSource
* @return
*/
private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource,
ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
if (ms.getKeyProperties() != null) {
for (String keyProperty : ms.getKeyProperties()) {
builder.keyProperty(keyProperty);
}
}
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.cache(ms.getCache());
builder.useCache(ms.isUseCache());
return builder.build();
}
public static class BoundSqlSqlSource implements SqlSource {
BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
}
以上拦截器通过拦截 SelectFilter 参数对象来实现自定义扩展查询:由 SelectFilter 构建 where 后面的语句。
拦截器注册:
/**
 * Spring configuration that, after injection, adds one list interceptor and one page
 * interceptor to the MyBatis configuration of every injected {@code SqlSessionFactory}.
 */
@Configuration
public class MybatisSelectFilterConfiguration {

    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    /**
     * Creates the two interceptors once and registers the same instances on each factory,
     * list interceptor first, page interceptor second.
     */
    @PostConstruct
    public void addInterceptor() {
        MybatisListSelectFilterInterceptor listInterceptor =
            new MybatisListSelectFilterInterceptor();
        MybatisPageSelectFilterInterceptor pageInterceptor =
            new MybatisPageSelectFilterInterceptor();
        for (SqlSessionFactory factory : sqlSessionFactoryList) {
            factory.getConfiguration().addInterceptor(listInterceptor);
            factory.getConfiguration().addInterceptor(pageInterceptor);
        }
    }
}
Mapper 方法
// Example mapper method: the @Select supplies only the base "select * from cy_xlxw"
// statement. Because the parameter is a SelectFilter, the interceptor above replaces the
// statement's SQL at query time with this base SQL followed by selectFilter.getSql().
@Select("select * from cy_xlxw")
List<Xlxw> list(SelectFilter<Xlxw> selectFilter);
SelectFilter 代码:
https://www.leftso.com/article/1028.html