diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/util/MapperUtil.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/util/MapperUtil.java index 576eeec4..b10b3860 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/util/MapperUtil.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/util/MapperUtil.java @@ -25,13 +25,7 @@ import com.mybatisflex.core.field.FieldQuery; import com.mybatisflex.core.field.FieldQueryBuilder; import com.mybatisflex.core.field.FieldQueryManager; import com.mybatisflex.core.paginate.Page; -import com.mybatisflex.core.query.CPI; -import com.mybatisflex.core.query.DistinctQueryColumn; -import com.mybatisflex.core.query.Join; -import com.mybatisflex.core.query.QueryColumn; -import com.mybatisflex.core.query.QueryCondition; -import com.mybatisflex.core.query.QueryTable; -import com.mybatisflex.core.query.QueryWrapper; +import com.mybatisflex.core.query.*; import com.mybatisflex.core.relation.RelationManager; import com.mybatisflex.core.table.TableInfo; import com.mybatisflex.core.table.TableInfoFactory; @@ -75,13 +69,36 @@ public class MapperUtil { .select(count().as("total")) .from(queryWrapper).as("t"); } - + public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper,List customCountColumns) { + return customCountColumns!=null?QueryWrapper.create() + .select(customCountColumns) + .from(queryWrapper).as("t"):rawCountQueryWrapper(queryWrapper); + } /** * 优化 COUNT 查询语句。 */ public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) { + return optimizeCountQueryWrapper(queryWrapper, Collections.singletonList(count().as("total"))); + } + /** + * 优化 COUNT 查询语句。 + */ + public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper, List customCountColumns) { // 对克隆对象进行操作,不影响原来的 QueryWrapper 对象 QueryWrapper clone = queryWrapper.clone(); + + List unions = CPI.getUnions(clone); + if(!CollectionUtil.isEmpty(unions)){ + List newUnions = new ArrayList<>(unions.size()); + for (UnionWrapper union : unions) { + QueryWrapper unionQuery = optimizeCountQueryWrapper(union.getQueryWrapper().clone(),null); + UnionWrapper clone1 = union.clone(); + clone1.setQueryWrapper(unionQuery); + newUnions.add(clone1); + } + CPI.setUnions(clone, newUnions); + } + // 将最后面的 order by 移除掉 CPI.setOrderBys(clone, null); // 获取查询列和分组列,用于判断是否进行优化 @@ -91,14 +108,20 @@ public class MapperUtil { // 如果有 distinct、group by、having 等语句则不优化 // 这种一旦优化了就会造成 count 语句查询出来的值不对 if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns) || havingCondition != null) { - return rawCountQueryWrapper(clone); + return clone; } // 判断能不能清除 join 语句 if (canClearJoins(clone)) { CPI.setJoins(clone, null); } // 将 select 里面的列换成 COUNT(*) AS `total` - CPI.setSelectColumns(clone, Collections.singletonList(count().as("total"))); + if(customCountColumns!=null){ + if(hasUnion(clone)){ + return rawCountQueryWrapper(clone,customCountColumns); + }else { + CPI.setSelectColumns(clone, customCountColumns); + } + } return clone; } @@ -118,6 +141,10 @@ public class MapperUtil { return CollectionUtil.isNotEmpty(groupByColumns); } + private static boolean hasUnion(QueryWrapper countQueryWrapper) { + return CollectionUtil.isNotEmpty(CPI.getUnions(countQueryWrapper)); + } + private static boolean canClearJoins(QueryWrapper queryWrapper) { List joins = CPI.getJoins(queryWrapper); if (CollectionUtil.isEmpty(joins)) { diff --git a/mybatis-flex-test/mybatis-flex-spring-boot-test/src/test/java/com/mybatisflex/test/common/MapperUtilTest.java b/mybatis-flex-test/mybatis-flex-spring-boot-test/src/test/java/com/mybatisflex/test/common/MapperUtilTest.java index 0e48c2a3..e420010d 100644 --- a/mybatis-flex-test/mybatis-flex-spring-boot-test/src/test/java/com/mybatisflex/test/common/MapperUtilTest.java +++ b/mybatis-flex-test/mybatis-flex-spring-boot-test/src/test/java/com/mybatisflex/test/common/MapperUtilTest.java @@ -45,4 +45,81 @@ class MapperUtilTest { System.out.println(MapperUtil.optimizeCountQueryWrapper(queryWrapper).toSQL()); } + /** + * 测试 (sql1) union (sql2) + */ + @Test + void testOptimizeCountQueryWrapperOfUnion1() { + //简单union + //SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_id` = 1 order by user_id desc + QueryWrapper union1 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.USER_ID.eq(1)).orderBy(USER.USER_ID.desc()); + //SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_name` LIKE '%test%' + QueryWrapper union2 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.USER_NAME.like("test")); + + QueryWrapper query1 = union1.union(union2); + + String sql = MapperUtil.optimizeCountQueryWrapper(query1).toSQL(); + //SELECT COUNT(*) AS `total` FROM ((SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_id` = 1) UNION (SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_name` LIKE '%test%')) AS `t` + System.out.println(sql); + } + + /** + * 测试 (sql1 ) union (sql2 with group by) + */ + @Test + void testOptimizeCountQueryWrapperOfUnion2() { + //with group by union + //SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_id` = 1 order by user_id desc + QueryWrapper union1 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.USER_ID.eq(1)).orderBy(USER.USER_ID.desc()); + //SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_name` LIKE '%test%' group by user_id, user_name + QueryWrapper union2 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.USER_NAME.like("test")).groupBy(USER.USER_ID, USER.USER_NAME); + + QueryWrapper query1 = union1.union(union2); + + //SELECT COUNT(*) AS `total` FROM ((SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_id` = 1) UNION (SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_name` LIKE '%test%' GROUP BY `user_id`, `user_name`)) AS `t` + String sql = MapperUtil.optimizeCountQueryWrapper(query1).toSQL(); + System.out.println(sql); + } + + /** + * 测试 (sql1) union (sql2 union sql3) + */ + @Test + void testOptimizeCountQueryWrapperOfUnion3() { + //with sub query union + //SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_id` = 1 order by user_id desc + QueryWrapper union1 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.USER_ID.eq(1)).orderBy(USER.USER_ID.desc()); + //SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_name` LIKE '%test%' group by user_id, user_name + QueryWrapper union2 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.USER_NAME.like("test")).orderBy(USER.USER_NAME.desc()); + + QueryWrapper union3 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.PASSWORD.isNull()).orderBy(USER.USER_NAME.desc()); + + + QueryWrapper query1 = union1.union(union2.union(union3)); + + //SELECT COUNT(*) AS `total` FROM ((SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_id` = 1) UNION ((SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_name` LIKE '%test%') UNION (SELECT `user_id`, `user_name` FROM `tb_user` WHERE `password` IS NULL ))) AS `t` + String sql = MapperUtil.optimizeCountQueryWrapper(query1).toSQL(); + + System.out.println(sql); + } + + @Test + void testOptimizeCountQueryWrapperOfUnion4() { + //with sub query union + //SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_id` = 1 order by user_id desc + QueryWrapper union1 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.USER_ID.eq(1)).orderBy(USER.USER_ID.desc()); + //SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_name` LIKE '%test%' group by user_id, user_name + QueryWrapper union2 = QueryWrapper.create().select(USER.USER_ID, USER.USER_NAME).from(USER).where(USER.USER_NAME.like("test")).orderBy(USER.USER_NAME.desc()); + + QueryWrapper union3 = QueryWrapper.create().from(union2).as("a"); + + + QueryWrapper query1 = union1.union(union3); + + //SELECT COUNT(*) AS `total` FROM ((SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_id` = 1) UNION ((SELECT `user_id`, `user_name` FROM `tb_user` WHERE `user_name` LIKE '%test%') UNION (SELECT `user_id`, `user_name` FROM `tb_user` WHERE `password` IS NULL ))) AS `t` + String sql = MapperUtil.optimizeCountQueryWrapper(query1).toSQL(); + + System.out.println(sql); + } + }