fix: count 查询优化有误。

This commit is contained in:
Suomm 2023-06-09 20:59:58 +08:00
parent b087dd0b92
commit 03628a0bb7

View File

@ -18,19 +18,111 @@ package com.mybatisflex.core.util;
import com.mybatisflex.core.BaseMapper;
import com.mybatisflex.core.field.FieldQuery;
import com.mybatisflex.core.field.FieldQueryBuilder;
import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.core.query.*;
import org.apache.ibatis.exceptions.TooManyResultsException;
import org.apache.ibatis.session.defaults.DefaultSqlSession;
import java.util.*;
import java.util.function.Consumer;
import static com.mybatisflex.core.query.QueryMethods.count;
public class MapperUtil {
private MapperUtil() {
}
/**
* <p>原生的未经过优化的 COUNT 查询抛开效率问题不谈只关注结果的准确性
* 这个 COUNT 查询查出来的分页总数据是 100% 正确的不接受任何反驳
*
* <p>为什么这么说因为是用子查询实现的生成的 SQL 如下
*
* <p><pre>
* {@code
* SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... );
* }
* </pre>
*
* <p>不进行 SQL 优化的时候返回的就是这样的 COUNT 查询语句
*/
public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) {
return QueryWrapper.create()
.select(count().as("total"))
.from(queryWrapper);
}
/**
* 优化 COUNT 查询语句
*/
public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) {
List<QueryColumn> selectColumns = CPI.getSelectColumns(queryWrapper);
List<QueryColumn> groupByColumns = CPI.getGroupByColumns(queryWrapper);
// 如果有 distinct 语句或者 group by 语句则不优化
// 这种一旦优化了就会造成 count 语句查询出来的值不对
if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns)) {
return rawCountQueryWrapper(queryWrapper);
}
// 判断能不能清除 join 语句
if (canClearJoins(queryWrapper)) {
CPI.setJoins(queryWrapper, null);
}
// 最后将最后面的 order by 移除掉
// select 里面的列换成 COUNT(*) AS `total` 就好了
CPI.setOrderBys(queryWrapper, null);
CPI.setSelectColumns(queryWrapper, Collections.singletonList(count().as("total")));
return queryWrapper;
}
private static boolean hasDistinct(List<QueryColumn> selectColumns) {
if (CollectionUtil.isEmpty(selectColumns)) {
return false;
}
for (QueryColumn selectColumn : selectColumns) {
if (selectColumn instanceof DistinctQueryColumn) {
return true;
}
}
return false;
}
private static boolean hasGroupBy(List<QueryColumn> groupByColumns) {
return CollectionUtil.isNotEmpty(groupByColumns);
}
private static boolean canClearJoins(QueryWrapper queryWrapper) {
List<Join> joins = CPI.getJoins(queryWrapper);
if (CollectionUtil.isEmpty(joins)) {
return false;
}
// 只有全是 left join 语句才会清除 join
// 因为如果是 inner join right join 往往都会放大记录数
for (Join join : joins) {
if (!Join.TYPE_LEFT.equals(CPI.getJoinType(join))) {
return false;
}
}
// 获取 join 语句中使用到的表名
List<String> joinTables = new ArrayList<>();
// Map<String, String> joinTables = new HashMap<>();
joins.forEach(join -> {
QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
if (joinQueryTable != null && StringUtil.isNotBlank(joinQueryTable.getName())) {
joinTables.add(joinQueryTable.getName());
}
});
// 获取 where 语句中的条件
QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
// 最后判断一下 where 中是否用到了 join 的表
return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables));
}
@SuppressWarnings({"rawtypes", "unchecked"})
public static <R> void queryFields(BaseMapper<?> mapper, List<R> list, Consumer<FieldQueryBuilder<R>>[] consumers) {
if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) {