diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/query/QueryCondition.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/query/QueryCondition.java index a21d4413..98d56480 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/query/QueryCondition.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/query/QueryCondition.java @@ -19,6 +19,7 @@ package com.mybatisflex.core.query; import com.mybatisflex.core.dialect.IDialect; import java.io.Serializable; +import java.lang.reflect.Array; import java.util.List; import java.util.function.Supplier; @@ -98,48 +99,6 @@ public class QueryCondition implements Serializable { this.logic = logic; } - /** - * 计算问号(?)的数量 - * - * @return 问号的数量 - */ - public int calculateQuestionMarkCount() { - if (LOGIC_IS_NULL.equals(logic) - || LOGIC_IS_NOT_NULL.equals(logic) - || value instanceof QueryColumn - || value instanceof QueryWrapper) { - return 0; - } - //between, not between - else if (LOGIC_BETWEEN.equals(logic) || LOGIC_NOT_BETWEEN.equals(logic)) { - return 2; - } - //in, not in - else if (LOGIC_IN.equals(logic) || LOGIC_NOT_IN.equals(logic)) { - return calculateValueArrayCount(); - } - // - else { - return 1; - } - } - - private int calculateValueArrayCount() { - Object[] values = (Object[]) value; - int paramsCount = 0; - for (Object v : values) { - if (v.getClass() == int[].class) { - paramsCount += ((int[]) v).length; - } else if (v.getClass() == long[].class) { - paramsCount += ((long[]) v).length; - } else if (v.getClass() == short[].class) { - paramsCount += ((short[]) v).length; - } else { - paramsCount++; - } - } - return paramsCount; - } public QueryCondition when(boolean effective) { this.effective = effective; @@ -194,7 +153,7 @@ public class QueryCondition implements Serializable { } //正常查询,构建问号 else { - appendQuestionMark(sql, calculateQuestionMarkCount()); + appendQuestionMark(sql); } } @@ -217,16 +176,21 @@ public class QueryCondition implements Serializable { } - protected static void appendQuestionMark(StringBuilder sqlBuilder, int paramsCount) { - if (paramsCount == 1) { - sqlBuilder.append(" ? "); + protected void appendQuestionMark(StringBuilder sqlBuilder) { + if (LOGIC_IS_NULL.equals(logic) + || LOGIC_IS_NOT_NULL.equals(logic) + || value instanceof QueryColumn + || value instanceof QueryWrapper) { + //do nothing } + //between, not between - else if (paramsCount == 2) { + else if (LOGIC_BETWEEN.equals(logic) || LOGIC_NOT_BETWEEN.equals(logic)) { sqlBuilder.append(" ? AND ? "); } //in, not in - else if (paramsCount > 0) { + else if (LOGIC_IN.equals(logic) || LOGIC_NOT_IN.equals(logic)) { + int paramsCount = calculateValueArrayCount(); sqlBuilder.append('('); for (int i = 0; i < paramsCount; i++) { sqlBuilder.append('?'); @@ -236,10 +200,29 @@ public class QueryCondition implements Serializable { } sqlBuilder.append(')'); } else { - // paramsCount == 0, ignore + sqlBuilder.append(" ? "); } } + + private int calculateValueArrayCount() { + Object[] values = (Object[]) value; + int paramsCount = 0; + for (Object object : values) { + if (object != null && (object.getClass().isArray() + || object.getClass() == int[].class + || object.getClass() == long[].class + || object.getClass() == short[].class + || object.getClass() == float[].class + || object.getClass() == double[].class)) { + paramsCount += Array.getLength(object); + } else { + paramsCount++; + } + } + return paramsCount; + } + @Override public String toString() { return "QueryCondition{" + diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/query/WrapperUtil.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/query/WrapperUtil.java index 7796d4bb..45a9930c 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/query/WrapperUtil.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/query/WrapperUtil.java @@ -19,6 +19,7 @@ package com.mybatisflex.core.query; import com.mybatisflex.core.util.CollectionUtil; import com.mybatisflex.core.util.StringUtil; +import java.lang.reflect.Array; import java.util.Arrays; import java.util.LinkedList; import java.util.List; @@ -52,15 +53,18 @@ class WrapperUtil { if (value != null) { if (value.getClass().isArray()) { Object[] values = (Object[]) value; - for (Object v : values) { - if (v.getClass() == int[].class) { - addAll(paras, (int[]) v); - } else if (v.getClass() == long[].class) { - addAll(paras, (long[]) v); - } else if (v.getClass() == short[].class) { - addAll(paras, (short[]) v); + for (Object object : values) { + if (object != null && (object.getClass().isArray() + || object.getClass() == int[].class + || object.getClass() == long[].class + || object.getClass() == short[].class + || object.getClass() == float[].class + || object.getClass() == double[].class)) { + for (int i = 0; i < Array.getLength(object); i++) { + paras.add(Array.get(object, i)); + } } else { - paras.add(v); + paras.add(object); } } } else if (value instanceof QueryWrapper) { @@ -75,26 +79,6 @@ class WrapperUtil { } - private static void addAll(List paras, int[] ints) { - for (int i : ints) { - paras.add(i); - } - } - - private static void addAll(List paras, long[] longs) { - for (long i : longs) { - paras.add(i); - } - } - - - private static void addAll(List paras, short[] shorts) { - for (short i : shorts) { - paras.add(i); - } - } - - public static String getColumnTableName(List queryTables, QueryTable queryTable) { if (queryTables == null) { return "";