optimize count sql in paginate

This commit is contained in:
开源海哥 2023-05-21 18:15:44 +08:00
parent 028f2c58bd
commit b966e372de
13 changed files with 177 additions and 92 deletions

View File

@ -19,15 +19,13 @@ import com.mybatisflex.core.exception.FlexExceptions;
import com.mybatisflex.core.mybatis.MappedStatementTypes;
import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.provider.EntitySqlProvider;
import com.mybatisflex.core.query.CPI;
import com.mybatisflex.core.query.QueryColumn;
import com.mybatisflex.core.query.QueryCondition;
import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.core.query.*;
import com.mybatisflex.core.table.TableInfo;
import com.mybatisflex.core.table.TableInfoFactory;
import com.mybatisflex.core.util.CollectionUtil;
import com.mybatisflex.core.util.ConvertUtil;
import com.mybatisflex.core.util.ObjectUtil;
import com.mybatisflex.core.util.StringUtil;
import org.apache.ibatis.annotations.*;
import org.apache.ibatis.builder.annotation.ProviderContext;
@ -483,7 +481,8 @@ public interface BaseMapper<T> {
if (CollectionUtil.isEmpty(selectColumns)) {
queryWrapper.select(count());
}
Object object = selectObjectByQuery(queryWrapper);
List<Object> objects = selectObjectListByQuery(queryWrapper);
Object object = objects == null || objects.isEmpty() ? null : objects.get(0);
if (object == null) {
return 0;
} else if (object instanceof Number) {
@ -570,53 +569,51 @@ public interface BaseMapper<T> {
* @return page 数据
*/
default Page<T> paginate(Page<T> page, QueryWrapper queryWrapper) {
List<QueryColumn> groupByColumns = CPI.getGroupByColumns(queryWrapper);
List<QueryColumn> selectColumns = CPI.getSelectColumns(queryWrapper);
// 只有 totalRow 小于 0 的时候才会去查询总量
// 这样方便用户做总数缓存而非每次都要去查询总量
// 一般的分页场景中只有第一页的时候有必要去查询总量第二页以后是不需要的
if (page.getTotalRow() < 0) {
//清除group by 去查询数据
CPI.setGroupByColumns(queryWrapper, null);
CPI.setSelectColumns(queryWrapper, Arrays.asList(count()));
long count = selectCountByQuery(queryWrapper);
page.setTotalRow(count);
}
if (page.getTotalRow() == 0 || page.getPageNumber() > page.getTotalPage()) {
return page;
}
//恢复数量查询清除的 groupBy
CPI.setGroupByColumns(queryWrapper, groupByColumns);
//重置 selectColumns
CPI.setSelectColumns(queryWrapper, selectColumns);
int offset = page.getPageSize() * (page.getPageNumber() - 1);
queryWrapper.limit(offset, page.getPageSize());
List<T> rows = selectListByQuery(queryWrapper);
page.setRecords(rows);
return page;
return paginateAs(page, queryWrapper, null);
}
default <R> Page<R> paginateAs(Page<R> page, QueryWrapper queryWrapper, Class<R> asType) {
List<QueryColumn> groupByColumns = CPI.getGroupByColumns(queryWrapper);
List<QueryColumn> selectColumns = CPI.getSelectColumns(queryWrapper);
List<Join> joins = CPI.getJoins(queryWrapper);
boolean removedJoins = true;
// 只有 totalRow 小于 0 的时候才会去查询总量
// 这样方便用户做总数缓存而非每次都要去查询总量
// 一般的分页场景中只有第一页的时候有必要去查询总量第二页以后是不需要的
if (page.getTotalRow() < 0) {
//清除group by 去查询数据
CPI.setGroupByColumns(queryWrapper, null);
CPI.setSelectColumns(queryWrapper, Collections.singletonList(count()));
CPI.setSelectColumns(queryWrapper, Collections.singletonList(count().as("total")));
if (joins != null && !joins.isEmpty()) {
for (Join join : joins) {
if (!Join.TYPE_LEFT.equals(CPI.getJoinType(join))) {
removedJoins = false;
break;
}
}
} else {
removedJoins = false;
}
if (removedJoins) {
List<String> joinTables = new ArrayList<>();
joins.forEach(join -> {
QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
if (joinQueryTable != null && StringUtil.isNotBlank(joinQueryTable.getName())) {
joinTables.add(joinQueryTable.getName());
}
});
QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
if (CPI.containsTable(where, CollectionUtil.toArrayString(joinTables))) {
removedJoins = false;
}
}
if (removedJoins) {
CPI.setJoins(queryWrapper, null);
}
long count = selectCountByQuery(queryWrapper);
page.setTotalRow(count);
@ -626,22 +623,30 @@ public interface BaseMapper<T> {
return page;
}
//恢复数量查询清除的 groupBy
CPI.setGroupByColumns(queryWrapper, groupByColumns);
//重置 selectColumns
CPI.setSelectColumns(queryWrapper, selectColumns);
//重置 join
if (removedJoins) {
CPI.setJoins(queryWrapper, joins);
}
int offset = page.getPageSize() * (page.getPageNumber() - 1);
queryWrapper.limit(offset, page.getPageSize());
try {
// 调用内部方法不走代理需要主动设置 MappedStatementType
// fixed https://gitee.com/mybatis-flex/mybatis-flex/issues/I73BP6
MappedStatementTypes.setCurrentType(asType);
List<R> records = selectListByQueryAs(queryWrapper, asType);
if (asType != null) {
try {
// 调用内部方法不走代理需要主动设置 MappedStatementType
// fixed https://gitee.com/mybatis-flex/mybatis-flex/issues/I73BP6
MappedStatementTypes.setCurrentType(asType);
List<R> records = selectListByQueryAs(queryWrapper, asType);
page.setRecords(records);
} finally {
MappedStatementTypes.clear();
}
} else {
List<R> records = (List<R>) selectListByQuery(queryWrapper);
page.setRecords(records);
}finally {
MappedStatementTypes.clear();
}
return page;
}

View File

@ -56,7 +56,7 @@ public class CommonsDialectImpl implements IDialect {
@Override
public String wrap(String keyword) {
return keywordWrap.wrap(keyword);
return "*".equals(keyword) ? keyword : keywordWrap.wrap(keyword);
}
@Override

View File

@ -69,8 +69,8 @@ public class OracleDialect extends CommonsDialectImpl {
@Override
public String wrap(String keyword) {
if (StringUtil.isBlank(keyword)) {
return "";
if (StringUtil.isBlank(keyword) || "*".equals(keyword)) {
return keyword;
}
if (caseSensitive || keywords.contains(keyword.toUpperCase(Locale.ENGLISH))) {
return "\"" + keyword + "\"";

View File

@ -25,11 +25,10 @@ import java.util.List;
*/
public class Brackets extends QueryCondition {
private final QueryCondition childCondition;
private final QueryCondition child;
public Brackets(QueryCondition childCondition) {
this.childCondition = childCondition;
this.child = childCondition;
}
@ -46,16 +45,16 @@ public class Brackets extends QueryCondition {
}
protected void connectToChild(QueryCondition nextCondition, SqlConnector connector) {
childCondition.connect(nextCondition, connector);
child.connect(nextCondition, connector);
}
@Override
public Object getValue() {
return checkEffective() ? WrapperUtil.getValues(childCondition) : null;
return checkEffective() ? WrapperUtil.getValues(child) : null;
}
public QueryCondition getChildCondition() {
return childCondition;
public QueryCondition getChild() {
return child;
}
@Override
@ -64,7 +63,7 @@ public class Brackets extends QueryCondition {
if (!effective) {
return false;
}
QueryCondition condition = this.childCondition;
QueryCondition condition = this.child;
while (condition != null) {
if (condition.checkEffective()) {
return true;
@ -81,7 +80,7 @@ public class Brackets extends QueryCondition {
StringBuilder sql = new StringBuilder();
if (checkEffective()) {
String childSql = childCondition.toSql(queryTables, dialect);
String childSql = child.toSql(queryTables, dialect);
if (StringUtil.isNotBlank(childSql)) {
QueryCondition effectiveBefore = getEffectiveBefore();
if (effectiveBefore != null) {
@ -101,10 +100,15 @@ public class Brackets extends QueryCondition {
}
@Override
boolean containsTable(String... tables) {
return child.containsTable(tables);
}
@Override
public String toString() {
return "Brackets{" +
"childCondition=" + childCondition +
"childCondition=" + child +
'}';
}
}

View File

@ -88,6 +88,13 @@ public class CPI {
queryWrapper.setJoins(joins);
}
public static String getJoinType(Join join){
return join.type;
}
public static QueryTable getJoinQueryTable(Join join){
return join.getQueryTable();
}
public static List<QueryTable> getJoinTables(QueryWrapper queryWrapper) {
return queryWrapper.getJoinTables();
@ -183,4 +190,8 @@ public class CPI {
queryWrapper.from(tableName);
}
}
public static boolean containsTable(QueryCondition condition,String ... tables){
return condition != null && condition.containsTable(tables);
}
}

View File

@ -30,18 +30,18 @@ public class Join implements Serializable {
private static final long serialVersionUID = 1L;
static final String TYPE_JOIN = " JOIN ";
static final String TYPE_LEFT = " LEFT JOIN ";
static final String TYPE_RIGHT = " RIGHT JOIN ";
static final String TYPE_INNER = " INNER JOIN ";
static final String TYPE_FULL = " FULL JOIN ";
static final String TYPE_CROSS = " CROSS JOIN ";
public static final String TYPE_JOIN = " JOIN ";
public static final String TYPE_LEFT = " LEFT JOIN ";
public static final String TYPE_RIGHT = " RIGHT JOIN ";
public static final String TYPE_INNER = " INNER JOIN ";
public static final String TYPE_FULL = " FULL JOIN ";
public static final String TYPE_CROSS = " CROSS JOIN ";
private final String type;
private final QueryTable queryTable;
private QueryCondition on;
private boolean effective;
protected final String type;
protected final QueryTable queryTable;
protected QueryCondition on;
protected boolean effective;
public Join(String type, String table, boolean when) {
this.type = type;
@ -56,7 +56,6 @@ public class Join implements Serializable {
}
QueryTable getQueryTable() {
return queryTable;
}

View File

@ -61,4 +61,9 @@ public class OperatorQueryCondition extends QueryCondition {
public Object getValue() {
return WrapperUtil.getValues(child);
}
@Override
boolean containsTable(String... tables) {
return child.containsTable(tables);
}
}

View File

@ -63,4 +63,10 @@ public class OperatorSelectCondition extends QueryCondition {
public Object getValue() {
return queryWrapper.getValueArray();
}
@Override
boolean containsTable(String... tables) {
QueryCondition condition = queryWrapper.getWhereQueryCondition();
return condition != null && condition.containsTable(tables);
}
}

View File

@ -264,6 +264,16 @@ public class QueryCondition implements Serializable {
return paramsCount;
}
boolean containsTable(String... tables){
for (String table : tables) {
if (column.table != null && table.equals(column.table.name)){
return true;
}
}
return false;
}
@Override
public String toString() {
return "QueryCondition{" +

View File

@ -41,7 +41,7 @@ class WrapperUtil {
while (condition != null) {
if (condition.checkEffective()) {
if (condition instanceof Brackets) {
List<QueryWrapper> childQueryWrapper = getChildSelect(((Brackets) condition).getChildCondition());
List<QueryWrapper> childQueryWrapper = getChildSelect(((Brackets) condition).getChild());
if (!childQueryWrapper.isEmpty()) {
if (list == null) {
list = new ArrayList<>();

View File

@ -19,18 +19,13 @@ import com.mybatisflex.core.FlexConsts;
import com.mybatisflex.core.exception.FlexExceptions;
import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.provider.RowSqlProvider;
import com.mybatisflex.core.query.CPI;
import com.mybatisflex.core.query.QueryColumn;
import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.core.query.*;
import com.mybatisflex.core.util.CollectionUtil;
import com.mybatisflex.core.util.StringUtil;
import org.apache.ibatis.annotations.*;
import org.apache.ibatis.exceptions.TooManyResultsException;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.*;
import static com.mybatisflex.core.query.QueryMethods.count;
@ -195,6 +190,7 @@ public interface RowMapper {
/**
* 更新 entity主要用于进行批量更新的场景
*
* @param entity 实体类
* @see RowSqlProvider#updateEntity(Map)
* @see Db#updateEntitiesBatch(Collection, int)
@ -382,7 +378,9 @@ public interface RowMapper {
if (CollectionUtil.isEmpty(selectColumns)) {
queryWrapper.select(count());
}
Object object = selectObjectByQuery(tableName, queryWrapper);
List<Object> objects = selectObjectListByQuery(tableName, queryWrapper);
Object object = objects == null || objects.isEmpty() ? null : objects.get(0);
if (object == null) {
return 0;
} else if (object instanceof Number) {
@ -403,17 +401,48 @@ public interface RowMapper {
*/
default Page<Row> paginate(String tableName, Page<Row> page, QueryWrapper queryWrapper) {
List<QueryColumn> groupByColumns = CPI.getGroupByColumns(queryWrapper);
CPI.setFromIfNecessary(queryWrapper, tableName);
List<QueryColumn> selectColumns = CPI.getSelectColumns(queryWrapper);
List<Join> joins = CPI.getJoins(queryWrapper);
boolean removedJoins = true;
// 只有 totalRow 小于 0 的时候才会去查询总量
// 这样方便用户做总数缓存而非每次都要去查询总量
// 一般的分页场景中只有第一页的时候有必要去查询总量第二页以后是不需要的
if (page.getTotalRow() < 0) {
//清除group by 去查询数据
CPI.setGroupByColumns(queryWrapper, null);
CPI.setSelectColumns(queryWrapper, Collections.singletonList(count()));
CPI.setSelectColumns(queryWrapper, Collections.singletonList(count().as("total")));
if (joins != null && !joins.isEmpty()) {
for (Join join : joins) {
if (!Join.TYPE_LEFT.equals(CPI.getJoinType(join))) {
removedJoins = false;
break;
}
}
} else {
removedJoins = false;
}
if (removedJoins) {
List<String> joinTables = new ArrayList<>();
joins.forEach(join -> {
QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
if (joinQueryTable != null && StringUtil.isNotBlank(joinQueryTable.getName())) {
joinTables.add(joinQueryTable.getName());
}
});
QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
if (CPI.containsTable(where, CollectionUtil.toArrayString(joinTables))) {
removedJoins = false;
}
}
if (removedJoins) {
CPI.setJoins(queryWrapper, null);
}
long count = selectCountByQuery(tableName, queryWrapper);
page.setTotalRow(count);
@ -423,17 +452,21 @@ public interface RowMapper {
return page;
}
//恢复数量查询清除的 groupBy
CPI.setGroupByColumns(queryWrapper, groupByColumns);
//重置 selectColumns
CPI.setSelectColumns(queryWrapper, selectColumns);
//重置 join
if (removedJoins) {
CPI.setJoins(queryWrapper, joins);
}
int offset = page.getPageSize() * (page.getPageNumber() - 1);
queryWrapper.limit(offset, page.getPageSize());
List<Row> records = selectListByQuery(tableName, queryWrapper);
page.setRecords(records);
return page;
}
}

View File

@ -91,4 +91,16 @@ public class CollectionUtil {
}
}
public static String[] toArrayString(Collection<?> collection) {
if (isEmpty(collection)) {
return new String[0];
}
String[] results = new String[collection.size()];
int index = 0;
for (Object o : collection) {
results[index++] = String.valueOf(o);
}
return results;
}
}

View File

@ -50,7 +50,7 @@ public class SqlUtil {
}
private static final char[] UN_SAFE_CHARS = "'`\"<>&*+=#-;".toCharArray();
private static final char[] UN_SAFE_CHARS = "'`\"<>&+=#-;".toCharArray();
private static boolean isUnSafeChar(char ch) {
for (char c : UN_SAFE_CHARS) {