diff --git a/hutool-core/src/main/java/cn/hutool/v7/core/math/Arrangement.java b/hutool-core/src/main/java/cn/hutool/v7/core/math/Arrangement.java index 67843a1be..918c22b90 100644 --- a/hutool-core/src/main/java/cn/hutool/v7/core/math/Arrangement.java +++ b/hutool-core/src/main/java/cn/hutool/v7/core/math/Arrangement.java @@ -20,9 +20,7 @@ import cn.hutool.v7.core.array.ArrayUtil; import java.io.Serial; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; /** * 排列A(n, m)
@@ -67,10 +65,23 @@ public class Arrangement implements Serializable { * @return 排列数 */ public static long count(final int n, final int m) { - if (n == m) { - return MathUtil.factorial(n); + if (m < 0 || m > n) { + throw new IllegalArgumentException("n >= 0 && m >= 0 && m <= n required"); } - return (n > m) ? MathUtil.factorial(n, n - m) : 0; + if (m == 0) { + return 1; + } + long result = 1; + // 从 n 到 n-m+1 逐个乘 + for (int i = 0; i < m; i++) { + long next = result * (n - i); + // 溢出检测 + if (next < result) { + throw new ArithmeticException("Overflow computing A(" + n + "," + m + ")"); + } + result = next; + } + return result; } /** @@ -97,51 +108,198 @@ public class Arrangement implements Serializable { } /** - * 排列选择(从列表中选择m个排列) + * 从当前数据中选择 m 个元素,生成所有「不重复」的排列(Permutation)。 * - * @param m 选择个数 - * @return 所有排列列表 + *

+ * 说明: + *

+ *

+ * 数量公式: + *

+	 * A(n, m) = n! / (n - m)!
+	 * 
+ *

+ * 举例: + *

+	 * datas = ["1","2","3"]
+	 * m = 2
+	 * 输出:
+	 * ["1","2"]
+	 * ["1","3"]
+	 * ["2","1"]
+	 * ["2","3"]
+	 * ["3","1"]
+	 * ["3","2"]
+	 * 共 6 个(A(3,2)=6)
+	 * 
+ * + * @param m 选择的元素个数 + * @return 所有长度为 m 的不重复排列列表 */ public List select(final int m) { - final List result = new ArrayList<>((int) count(this.datas.length, m)); - select(this.datas, new String[m], 0, result); + if (m < 0 || m > datas.length) { + return Collections.emptyList(); + } + if (m == 0) { + // A(n,0) = 1,唯一一个空排列 + return Collections.singletonList(new String[0]); + } + + long estimated = count(datas.length, m); + int capacity = estimated > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) estimated; + + List result = new ArrayList<>(capacity); + boolean[] visited = new boolean[datas.length]; + dfs(new String[m], 0, visited, result); return result; } /** - * 排列所有组合,即A(n, 1) + A(n, 2) + A(n, 3)... + * 生成当前数据的全部不重复排列(长度为 1 至 n 的所有排列)。 * - * @return 全排列结果 + *

+ * 说明: + *

    + *
  • 不允许重复选择元素(无 ["1","1"],无 ["2","2","3"] 这种)
  • + *
  • 包含所有长度 m=1..n 的排列
  • + *
  • 总数量为 A(n,1) + A(n,2) + ... + A(n,n)
  • + *
+ *

+ * 举例(datas = ["1","2","3"]): + *

+	 * m=1: ["1"], ["2"], ["3"]                             → 3 个
+	 * m=2: ["1","2"], ["1","3"], ["2","1"], ...            → 6 个
+	 * m=3: ["1","2","3"], ["1","3","2"], ["2","1","3"], ...→ 6 个
+	 *
+	 * 总共:3 + 6 + 6 = 15
+	 * 
+ * + * @return 所有不重复排列列表 */ public List selectAll() { - final List result = new ArrayList<>((int) countAll(this.datas.length)); - for (int i = 1; i <= this.datas.length; i++) { - result.addAll(select(i)); + final List result = new ArrayList<>(); + for (int m = 1; m <= datas.length; m++) { + result.addAll(select(m)); } return result; } /** - * 排列选择
- * 排列方式为先从数据数组中取出一个元素,再把剩余的元素作为新的基数,依次列推,直到选择到足够的元素 + * 返回一个排列的迭代器 * - * @param datas 选择的基数 - * @param resultList 前面(resultIndex-1)个的排列结果 - * @param resultIndex 选择索引,从0开始 - * @param result 最终结果 + * @param m 选择的元素个数 + * @return 排列迭代器 */ - private void select(final String[] datas, final String[] resultList, final int resultIndex, final List result) { - if (resultIndex >= resultList.length) { // 全部选择完时,输出排列结果 - if (!result.contains(resultList)) { - result.add(Arrays.copyOf(resultList, resultList.length)); + public Iterable iterate(int m) { + return () -> new ArrangementIterator(datas, m); + } + + + /** + * 排列迭代器 + * + * @author CherryRum + */ + private static class ArrangementIterator implements Iterator { + + private final String[] datas; + private final int m; + private final boolean[] visited; + private final String[] buffer; + private final Deque stack = new ArrayDeque<>(); + boolean end = false; + + /** + * 构造函数 + * + * @param datas 数据数组 + * @param m 选择的元素个数 + */ + ArrangementIterator(String[] datas, int m) { + this.datas = datas; + this.m = m; + this.visited = new boolean[datas.length]; + this.buffer = new String[m]; + // 初始化 dfs 栈 + stack.push(0); + } + + @Override + public boolean hasNext() { + return !end; + } + + @Override + public String[] next() { + while (!stack.isEmpty()) { + int depth = stack.size() - 1; + + int idx = stack.pop(); + if (idx >= datas.length) { + // 这一层遍历结束 + if (!stack.isEmpty()) { + int prev = stack.pop(); + stack.push(prev + 1); + } + continue; + } + + // 如果该元素未使用 + if (!visited[idx]) { + visited[idx] = true; + buffer[depth] = datas[idx]; + + if (depth == m - 1) { + // 输出一个排列 + visited[idx] = false; + + // 下一次从 idx+1 继续 + stack.push(idx + 1); + + return Arrays.copyOf(buffer, m); + } else { + // 继续下一层 + stack.push(idx + 1); // 当前层下一个起点 + stack.push(0); // 下一层起点 + continue; + } + } + + // 已访问则跳过 + stack.push(idx + 1); } + + end = true; + return null; + } + } + + /** + * 核心递归方法(回溯算法) + * * @param current 当前构建的排列数组 + * + * @param depth 当前递归深度(填到了第几个位置) + * @param visited 标记数组,记录哪些索引已经被使用了 + * @param result 结果集 + */ + private void dfs(String[] current, int depth, boolean[] visited, List result) { + if (depth == current.length) { + result.add(Arrays.copyOf(current, current.length)); return; } - // 递归选择下一个 for (int i = 0; i < datas.length; i++) { - resultList[resultIndex] = datas[i]; - select(ArrayUtil.remove(datas, i), resultList, resultIndex + 1, result); + if (!visited[i]) { + visited[i] = true; + current[depth] = datas[i]; + + dfs(current, depth + 1, visited, result); + visited[i] = false; + } } } } diff --git a/hutool-core/src/main/java/cn/hutool/v7/core/math/Combination.java b/hutool-core/src/main/java/cn/hutool/v7/core/math/Combination.java index e38a7737a..d3abff8db 100644 --- a/hutool-core/src/main/java/cn/hutool/v7/core/math/Combination.java +++ b/hutool-core/src/main/java/cn/hutool/v7/core/math/Combination.java @@ -20,6 +20,7 @@ import cn.hutool.v7.core.text.StrUtil; import java.io.Serial; import java.io.Serializable; +import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -53,15 +54,52 @@ public class Combination implements Serializable { /** * 计算组合数,即C(n, m) = n!/((n-m)!* m!) * - * @param n 总数 - * @param m 选择的个数 - * @return 组合数 + * @param n 总数 n(必须 >= 0) + * @param m 取出 m(必须 >= 0) + * @return 若结果超出 long 范围,会抛 ArithmeticException,而非溢出。 + * @throws ArithmeticException 若结果超出 long 范围,会抛 ArithmeticException,而非溢出。 */ - public static long count(final int n, final int m) { - if (0 == m || n == m) { - return 1; + public static long count(final int n, final int m) throws ArithmeticException { + final BigInteger big = countBig(n, m); + return big.longValueExact(); + } + + /** + * 计算组合数 C(n, m) 的 BigInteger 精确版本。 + * 使用逐步累乘除法(非阶乘)保证不溢出、性能好。 + *

+ * 数学定义: + * C(n, m) = n! / (m! (n - m)!) + *

+ * 优化方式: + * 1. 利用对称性 m = min(m, n-m) + * 2. 每一步先乘 BigInteger,再除以当前 i,保证数值不暴涨 + * + * @param n 总数 n(必须 >= 0) + * @param m 取出 m(必须 >= 0) + * @return C(n, m) 的 BigInteger 精确值;当 m > n 时返回 BigInteger.ZERO + */ + public static BigInteger countBig(final int n, int m) { + if (n < 0 || m < 0) { + throw new IllegalArgumentException("n and m must be non-negative. got n=" + n + ", m=" + m); } - return (n > m) ? MathUtil.factorial(n, n - m) / MathUtil.factorial(m) : 0; + if (m > n) { + return BigInteger.ZERO; + } + if (m == 0 || n == m) { + return BigInteger.ONE; + } + // 使用对称性:C(n, m) = C(n, n-m) + m = Math.min(m, n - m); + BigInteger result = BigInteger.ONE; + // 从 1 → m 累乘 + for (int i = 1; i <= m; i++) { + final int numerator = n - m + i; + result = result.multiply(BigInteger.valueOf(numerator)) + .divide(BigInteger.valueOf(i)); + } + + return result; } /** diff --git a/hutool-core/src/test/java/cn/hutool/v7/core/math/ArrangementTest.java b/hutool-core/src/test/java/cn/hutool/v7/core/math/ArrangementTest.java index 21898b412..f92f400c5 100644 --- a/hutool-core/src/test/java/cn/hutool/v7/core/math/ArrangementTest.java +++ b/hutool-core/src/test/java/cn/hutool/v7/core/math/ArrangementTest.java @@ -17,14 +17,16 @@ package cn.hutool.v7.core.math; import cn.hutool.v7.core.lang.Console; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; + /** * 排列单元测试 + * * @author Looly * */ @@ -33,49 +35,121 @@ public class ArrangementTest { @Test public void arrangementTest() { long result = Arrangement.count(4, 2); - Assertions.assertEquals(12, result); + assertEquals(12, result); result = Arrangement.count(4, 1); - Assertions.assertEquals(4, result); + assertEquals(4, result); result = Arrangement.count(4, 0); - Assertions.assertEquals(1, result); + assertEquals(1, result); final long resultAll = Arrangement.countAll(4); - Assertions.assertEquals(64, resultAll); + assertEquals(64, resultAll); } @Test public void selectTest() { - final Arrangement arrangement = new Arrangement(new String[] { "1", "2", "3", "4" }); + final Arrangement arrangement = new Arrangement(new String[]{"1", "2", "3", "4"}); final List list = arrangement.select(2); - Assertions.assertEquals(Arrangement.count(4, 2), list.size()); - Assertions.assertArrayEquals(new String[] {"1", "2"}, list.get(0)); - Assertions.assertArrayEquals(new String[] {"1", "3"}, list.get(1)); - Assertions.assertArrayEquals(new String[] {"1", "4"}, list.get(2)); - Assertions.assertArrayEquals(new String[] {"2", "1"}, list.get(3)); - Assertions.assertArrayEquals(new String[] {"2", "3"}, list.get(4)); - Assertions.assertArrayEquals(new String[] {"2", "4"}, list.get(5)); - Assertions.assertArrayEquals(new String[] {"3", "1"}, list.get(6)); - Assertions.assertArrayEquals(new String[] {"3", "2"}, list.get(7)); - Assertions.assertArrayEquals(new String[] {"3", "4"}, list.get(8)); - Assertions.assertArrayEquals(new String[] {"4", "1"}, list.get(9)); - Assertions.assertArrayEquals(new String[] {"4", "2"}, list.get(10)); - Assertions.assertArrayEquals(new String[] {"4", "3"}, list.get(11)); + // 校验数量一致 + assertEquals(Arrangement.count(4, 2), list.size()); + // 逐项严格校验顺序是否一致(按 DFS 顺序) + assertArrayEquals(new String[]{"1", "2"}, list.get(0)); + assertArrayEquals(new String[]{"1", "3"}, list.get(1)); + assertArrayEquals(new String[]{"1", "4"}, list.get(2)); + assertArrayEquals(new String[]{"2", "1"}, list.get(3)); + assertArrayEquals(new String[]{"2", "3"}, list.get(4)); + assertArrayEquals(new String[]{"2", "4"}, list.get(5)); + assertArrayEquals(new String[]{"3", "1"}, list.get(6)); + assertArrayEquals(new String[]{"3", "2"}, list.get(7)); + assertArrayEquals(new String[]{"3", "4"}, list.get(8)); + assertArrayEquals(new String[]{"4", "1"}, list.get(9)); + assertArrayEquals(new String[]{"4", "2"}, list.get(10)); + assertArrayEquals(new String[]{"4", "3"}, list.get(11)); + // 测试 selectAll final List selectAll = arrangement.selectAll(); - Assertions.assertEquals(Arrangement.countAll(4), selectAll.size()); + assertEquals(Arrangement.countAll(4), selectAll.size()); + // m=0,应该返回一个空排列 final List list2 = arrangement.select(0); - Assertions.assertEquals(1, list2.size()); + assertEquals(1, list2.size()); } + // ---------------------------------------------------- + // 扩展测试:边界、错误处理 + // ---------------------------------------------------- @Test - @Disabled - public void selectTest2() { - final List list = MathUtil.arrangementSelect(new String[] { "1", "1", "3", "4" }); - for (final String[] strings : list) { - Console.log(strings); + public void boundaryTest() { + final Arrangement arr = new Arrangement(new String[]{"A", "B", "C"}); + + // m = n + final List full = arr.select(3); + assertEquals(6, full.size()); + + // m = 1 + final List one = arr.select(1); + assertEquals(3, one.size()); + assertArrayEquals(new String[]{"A"}, one.get(0)); + + // m > n → empty list + assertTrue(arr.select(10).isEmpty()); + + // m < 0 → empty list + assertTrue(arr.select(-1).isEmpty()); + } + + // ---------------------------------------------------- + // 扩展测试:空数组 + // ---------------------------------------------------- + @Test + public void emptyTest() { + final Arrangement arrangement = new Arrangement(new String[]{}); + + assertEquals(1, arrangement.select(0).size()); + assertTrue(arrangement.select(1).isEmpty()); + assertTrue(arrangement.selectAll().isEmpty()); // A(0,m) = 0 for m>0,A(0,0)=1 → 全排列 = 1 个空排列 + } + + // ---------------------------------------------------- + // 扩展测试:重复元素(用于验证去重算法) + // 默认 Arrangement 不去重,因此应该包含重复排列 + // ---------------------------------------------------- + @Test + @Disabled("默认 Arrangement 不支持去重;启用后手动检查") + public void duplicateElementTest() { + final Arrangement arrangement = new Arrangement(new String[]{"1", "1", "3"}); + final List list = arrangement.select(2); + + // 应该有 A(3,2) = 6 个 + assertEquals(6, list.size()); + + for (final String[] s : list) { + Console.log(s); } } + + // ---------------------------------------------------- + // 扩展测试:selectAll 覆盖全部不重复排列(A(n,1..n)) + // ---------------------------------------------------- + @Test + public void selectAllTest() { + final Arrangement arrangement = new Arrangement(new String[]{"1", "2", "3"}); + + final List all = arrangement.selectAll(); + + // 打印用于观测 +// for (final String[] s : all) { +// Console.log(s); +// } + + // A(3,1) + A(3,2) + A(3,3) = 3 + 6 + 6 = 15 + assertEquals(Arrangement.countAll(3), all.size()); + assertEquals(15, all.size()); + + // spot check 不重复排列 + assertArrayEquals(new String[]{"1"}, all.get(0)); + assertArrayEquals(new String[]{"1", "2"}, all.get(3)); + assertArrayEquals(new String[]{"1", "2", "3"}, all.get(9)); + } } diff --git a/hutool-core/src/test/java/cn/hutool/v7/core/math/CombinationTest.java b/hutool-core/src/test/java/cn/hutool/v7/core/math/CombinationTest.java index 169c29278..c5e3ffb80 100644 --- a/hutool-core/src/test/java/cn/hutool/v7/core/math/CombinationTest.java +++ b/hutool-core/src/test/java/cn/hutool/v7/core/math/CombinationTest.java @@ -19,8 +19,11 @@ package cn.hutool.v7.core.math; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.math.BigInteger; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; + /** * 组合单元测试 * @@ -32,23 +35,23 @@ public class CombinationTest { @Test public void countTest() { long result = Combination.count(5, 2); - Assertions.assertEquals(10, result); + assertEquals(10, result); result = Combination.count(5, 5); - Assertions.assertEquals(1, result); + assertEquals(1, result); result = Combination.count(5, 0); - Assertions.assertEquals(1, result); + assertEquals(1, result); final long resultAll = Combination.countAll(5); - Assertions.assertEquals(31, resultAll); + assertEquals(31, resultAll); } @Test public void selectTest() { final Combination combination = new Combination(new String[] { "1", "2", "3", "4", "5" }); final List list = combination.select(2); - Assertions.assertEquals(Combination.count(5, 2), list.size()); + assertEquals(Combination.count(5, 2), list.size()); Assertions.assertArrayEquals(new String[] {"1", "2"}, list.get(0)); Assertions.assertArrayEquals(new String[] {"1", "3"}, list.get(1)); @@ -62,9 +65,86 @@ public class CombinationTest { Assertions.assertArrayEquals(new String[] {"4", "5"}, list.get(9)); final List selectAll = combination.selectAll(); - Assertions.assertEquals(Combination.countAll(5), selectAll.size()); + assertEquals(Combination.countAll(5), selectAll.size()); final List list2 = combination.select(0); - Assertions.assertEquals(1, list2.size()); + assertEquals(1, list2.size()); + } + + // ----------------------------- + // countBig() 正确性测试 + // ----------------------------- + @Test + void testCountBig_basicCases() { + assertEquals(BigInteger.ONE, Combination.countBig(5, 0)); + assertEquals(BigInteger.ONE, Combination.countBig(5, 5)); + assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 3)); + assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 2)); + } + + @Test + void testCountBig_mGreaterThanN() { + assertEquals(BigInteger.ZERO, Combination.countBig(5, 6)); + } + + @Test + void testCountBig_negativeInput() { + assertThrows(IllegalArgumentException.class, () -> Combination.countBig(-1, 3)); + assertThrows(IllegalArgumentException.class, () -> Combination.countBig(5, -2)); + } + + @Test + void testCountBig_symmetry() { + assertEquals(Combination.countBig(20, 3), Combination.countBig(20, 17)); + } + + @Test + void testCountBig_largeNumbers() { + // C(50, 3) = 19600 + assertEquals(new BigInteger("19600"), Combination.countBig(50, 3)); + + // C(100, 50) 的确切值(重要测试) + final BigInteger expected = new BigInteger( + "100891344545564193334812497256" + ); + assertEquals(expected, Combination.countBig(100, 50)); + } + + @Test + void testCountBig_veryLargeCombination() { + // 不比较具体值,只断言不要抛错 + final BigInteger result = Combination.countBig(2000, 1000); + assertTrue(result.signum() > 0); + } + + // ----------------------------- + // count(long) 兼容性测试 + // ----------------------------- + @Test + void testCount_basic() { + assertEquals(10L, Combination.count(5, 3)); + assertEquals(1L, Combination.count(5, 0)); + assertEquals(0L, Combination.count(5, 6)); + } + + // ----------------------------- + // countSafe() 安全 long 版本测试 + // ----------------------------- + @Test + void testCountSafe_exactFitsLong() { + // C(50, 3) = 19600 fits long + assertEquals(19600L, Combination.count(50, 3)); + } + + @Test + void testCountSafe_overflowThrows() { + // C(100, 50) 超出 long → 应抛 ArithmeticException + assertThrows(ArithmeticException.class, () -> Combination.count(100, 50)); + } + + @Test + void testCountSafe_invalidInput() { + assertThrows(IllegalArgumentException.class, () -> Combination.count(-1, 3)); + assertThrows(IllegalArgumentException.class, () -> Combination.count(3, -1)); } }