diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/util/MapperUtil.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/util/MapperUtil.java
index e6dbd342..271d4a62 100644
--- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/util/MapperUtil.java
+++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/util/MapperUtil.java
@@ -18,19 +18,111 @@ package com.mybatisflex.core.util;
import com.mybatisflex.core.BaseMapper;
import com.mybatisflex.core.field.FieldQuery;
import com.mybatisflex.core.field.FieldQueryBuilder;
-import com.mybatisflex.core.query.QueryWrapper;
+import com.mybatisflex.core.query.*;
import org.apache.ibatis.exceptions.TooManyResultsException;
import org.apache.ibatis.session.defaults.DefaultSqlSession;
import java.util.*;
import java.util.function.Consumer;
+import static com.mybatisflex.core.query.QueryMethods.count;
+
public class MapperUtil {
private MapperUtil() {
}
+ /**
+ *
原生的、未经过优化的 COUNT 查询。抛开效率问题不谈,只关注结果的准确性,
+ * 这个 COUNT 查询查出来的分页总数据是 100% 正确的,不接受任何反驳。
+ *
+ *
为什么这么说,因为是用子查询实现的,生成的 SQL 如下:
+ *
+ *
+ * {@code
+ * SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... );
+ * }
+ *
+ *
+ * 不进行 SQL 优化的时候,返回的就是这样的 COUNT 查询语句。
+ */
+ public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) {
+ return QueryWrapper.create()
+ .select(count().as("total"))
+ .from(queryWrapper);
+ }
+
+ /**
+ * 优化 COUNT 查询语句。
+ */
+ public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) {
+ List selectColumns = CPI.getSelectColumns(queryWrapper);
+ List groupByColumns = CPI.getGroupByColumns(queryWrapper);
+ // 如果有 distinct 语句或者 group by 语句则不优化
+ // 这种一旦优化了就会造成 count 语句查询出来的值不对
+ if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns)) {
+ return rawCountQueryWrapper(queryWrapper);
+ }
+ // 判断能不能清除 join 语句
+ if (canClearJoins(queryWrapper)) {
+ CPI.setJoins(queryWrapper, null);
+ }
+ // 最后将最后面的 order by 移除掉
+ // 将 select 里面的列换成 COUNT(*) AS `total` 就好了
+ CPI.setOrderBys(queryWrapper, null);
+ CPI.setSelectColumns(queryWrapper, Collections.singletonList(count().as("total")));
+ return queryWrapper;
+ }
+
+ private static boolean hasDistinct(List selectColumns) {
+ if (CollectionUtil.isEmpty(selectColumns)) {
+ return false;
+ }
+ for (QueryColumn selectColumn : selectColumns) {
+ if (selectColumn instanceof DistinctQueryColumn) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private static boolean hasGroupBy(List groupByColumns) {
+ return CollectionUtil.isNotEmpty(groupByColumns);
+ }
+
+ private static boolean canClearJoins(QueryWrapper queryWrapper) {
+ List joins = CPI.getJoins(queryWrapper);
+ if (CollectionUtil.isEmpty(joins)) {
+ return false;
+ }
+
+ // 只有全是 left join 语句才会清除 join
+ // 因为如果是 inner join 或 right join 往往都会放大记录数
+ for (Join join : joins) {
+ if (!Join.TYPE_LEFT.equals(CPI.getJoinType(join))) {
+ return false;
+ }
+ }
+
+ // 获取 join 语句中使用到的表名
+ List joinTables = new ArrayList<>();
+// Map joinTables = new HashMap<>();
+ joins.forEach(join -> {
+ QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
+ if (joinQueryTable != null && StringUtil.isNotBlank(joinQueryTable.getName())) {
+ joinTables.add(joinQueryTable.getName());
+ }
+ });
+
+ // 获取 where 语句中的条件
+ QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
+
+ // 最后判断一下 where 中是否用到了 join 的表
+ return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables));
+ }
+
+
@SuppressWarnings({"rawtypes", "unchecked"})
public static void queryFields(BaseMapper> mapper, List list, Consumer>[] consumers) {
if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) {