result) {
+ if (depth == current.length) {
+ result.add(Arrays.copyOf(current, current.length));
return;
}
- // 递归选择下一个
for (int i = 0; i < datas.length; i++) {
- resultList[resultIndex] = datas[i];
- select(ArrayUtil.remove(datas, i), resultList, resultIndex + 1, result);
+ if (!visited[i]) {
+ visited[i] = true;
+ current[depth] = datas[i];
+
+ dfs(current, depth + 1, visited, result);
+ visited[i] = false;
+ }
}
}
}
diff --git a/hutool-core/src/main/java/cn/hutool/v7/core/math/Combination.java b/hutool-core/src/main/java/cn/hutool/v7/core/math/Combination.java
index e38a7737a..d3abff8db 100644
--- a/hutool-core/src/main/java/cn/hutool/v7/core/math/Combination.java
+++ b/hutool-core/src/main/java/cn/hutool/v7/core/math/Combination.java
@@ -20,6 +20,7 @@ import cn.hutool.v7.core.text.StrUtil;
import java.io.Serial;
import java.io.Serializable;
+import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -53,15 +54,52 @@ public class Combination implements Serializable {
/**
* 计算组合数,即C(n, m) = n!/((n-m)!* m!)
*
- * @param n 总数
- * @param m 选择的个数
- * @return 组合数
+ * @param n 总数 n(必须 >= 0)
+ * @param m 取出 m(必须 >= 0)
+ * @return 若结果超出 long 范围,会抛 ArithmeticException,而非溢出。
+ * @throws ArithmeticException 若结果超出 long 范围,会抛 ArithmeticException,而非溢出。
*/
- public static long count(final int n, final int m) {
- if (0 == m || n == m) {
- return 1;
+ public static long count(final int n, final int m) throws ArithmeticException {
+ final BigInteger big = countBig(n, m);
+ return big.longValueExact();
+ }
+
+ /**
+ * 计算组合数 C(n, m) 的 BigInteger 精确版本。
+ * 使用逐步累乘除法(非阶乘)保证不溢出、性能好。
+ *
+ * 数学定义:
+ * C(n, m) = n! / (m! (n - m)!)
+ *
+ * 优化方式:
+ * 1. 利用对称性 m = min(m, n-m)
+ * 2. 每一步先乘 BigInteger,再除以当前 i,保证数值不暴涨
+ *
+ * @param n 总数 n(必须 >= 0)
+ * @param m 取出 m(必须 >= 0)
+ * @return C(n, m) 的 BigInteger 精确值;当 m > n 时返回 BigInteger.ZERO
+ */
+ public static BigInteger countBig(final int n, int m) {
+ if (n < 0 || m < 0) {
+ throw new IllegalArgumentException("n and m must be non-negative. got n=" + n + ", m=" + m);
}
- return (n > m) ? MathUtil.factorial(n, n - m) / MathUtil.factorial(m) : 0;
+ if (m > n) {
+ return BigInteger.ZERO;
+ }
+ if (m == 0 || n == m) {
+ return BigInteger.ONE;
+ }
+ // 使用对称性:C(n, m) = C(n, n-m)
+ m = Math.min(m, n - m);
+ BigInteger result = BigInteger.ONE;
+ // 从 1 → m 累乘
+ for (int i = 1; i <= m; i++) {
+ final int numerator = n - m + i;
+ result = result.multiply(BigInteger.valueOf(numerator))
+ .divide(BigInteger.valueOf(i));
+ }
+
+ return result;
}
/**
diff --git a/hutool-core/src/test/java/cn/hutool/v7/core/math/ArrangementTest.java b/hutool-core/src/test/java/cn/hutool/v7/core/math/ArrangementTest.java
index 21898b412..f92f400c5 100644
--- a/hutool-core/src/test/java/cn/hutool/v7/core/math/ArrangementTest.java
+++ b/hutool-core/src/test/java/cn/hutool/v7/core/math/ArrangementTest.java
@@ -17,14 +17,16 @@
package cn.hutool.v7.core.math;
import cn.hutool.v7.core.lang.Console;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.*;
+
/**
* 排列单元测试
+ *
* @author Looly
*
*/
@@ -33,49 +35,121 @@ public class ArrangementTest {
@Test
public void arrangementTest() {
long result = Arrangement.count(4, 2);
- Assertions.assertEquals(12, result);
+ assertEquals(12, result);
result = Arrangement.count(4, 1);
- Assertions.assertEquals(4, result);
+ assertEquals(4, result);
result = Arrangement.count(4, 0);
- Assertions.assertEquals(1, result);
+ assertEquals(1, result);
final long resultAll = Arrangement.countAll(4);
- Assertions.assertEquals(64, resultAll);
+ assertEquals(64, resultAll);
}
@Test
public void selectTest() {
- final Arrangement arrangement = new Arrangement(new String[] { "1", "2", "3", "4" });
+ final Arrangement arrangement = new Arrangement(new String[]{"1", "2", "3", "4"});
final List list = arrangement.select(2);
- Assertions.assertEquals(Arrangement.count(4, 2), list.size());
- Assertions.assertArrayEquals(new String[] {"1", "2"}, list.get(0));
- Assertions.assertArrayEquals(new String[] {"1", "3"}, list.get(1));
- Assertions.assertArrayEquals(new String[] {"1", "4"}, list.get(2));
- Assertions.assertArrayEquals(new String[] {"2", "1"}, list.get(3));
- Assertions.assertArrayEquals(new String[] {"2", "3"}, list.get(4));
- Assertions.assertArrayEquals(new String[] {"2", "4"}, list.get(5));
- Assertions.assertArrayEquals(new String[] {"3", "1"}, list.get(6));
- Assertions.assertArrayEquals(new String[] {"3", "2"}, list.get(7));
- Assertions.assertArrayEquals(new String[] {"3", "4"}, list.get(8));
- Assertions.assertArrayEquals(new String[] {"4", "1"}, list.get(9));
- Assertions.assertArrayEquals(new String[] {"4", "2"}, list.get(10));
- Assertions.assertArrayEquals(new String[] {"4", "3"}, list.get(11));
+ // 校验数量一致
+ assertEquals(Arrangement.count(4, 2), list.size());
+ // 逐项严格校验顺序是否一致(按 DFS 顺序)
+ assertArrayEquals(new String[]{"1", "2"}, list.get(0));
+ assertArrayEquals(new String[]{"1", "3"}, list.get(1));
+ assertArrayEquals(new String[]{"1", "4"}, list.get(2));
+ assertArrayEquals(new String[]{"2", "1"}, list.get(3));
+ assertArrayEquals(new String[]{"2", "3"}, list.get(4));
+ assertArrayEquals(new String[]{"2", "4"}, list.get(5));
+ assertArrayEquals(new String[]{"3", "1"}, list.get(6));
+ assertArrayEquals(new String[]{"3", "2"}, list.get(7));
+ assertArrayEquals(new String[]{"3", "4"}, list.get(8));
+ assertArrayEquals(new String[]{"4", "1"}, list.get(9));
+ assertArrayEquals(new String[]{"4", "2"}, list.get(10));
+ assertArrayEquals(new String[]{"4", "3"}, list.get(11));
+ // 测试 selectAll
final List selectAll = arrangement.selectAll();
- Assertions.assertEquals(Arrangement.countAll(4), selectAll.size());
+ assertEquals(Arrangement.countAll(4), selectAll.size());
+ // m=0,应该返回一个空排列
final List list2 = arrangement.select(0);
- Assertions.assertEquals(1, list2.size());
+ assertEquals(1, list2.size());
}
+ // ----------------------------------------------------
+ // 扩展测试:边界、错误处理
+ // ----------------------------------------------------
@Test
- @Disabled
- public void selectTest2() {
- final List list = MathUtil.arrangementSelect(new String[] { "1", "1", "3", "4" });
- for (final String[] strings : list) {
- Console.log(strings);
+ public void boundaryTest() {
+ final Arrangement arr = new Arrangement(new String[]{"A", "B", "C"});
+
+ // m = n
+ final List full = arr.select(3);
+ assertEquals(6, full.size());
+
+ // m = 1
+ final List one = arr.select(1);
+ assertEquals(3, one.size());
+ assertArrayEquals(new String[]{"A"}, one.get(0));
+
+ // m > n → empty list
+ assertTrue(arr.select(10).isEmpty());
+
+ // m < 0 → empty list
+ assertTrue(arr.select(-1).isEmpty());
+ }
+
+ // ----------------------------------------------------
+ // 扩展测试:空数组
+ // ----------------------------------------------------
+ @Test
+ public void emptyTest() {
+ final Arrangement arrangement = new Arrangement(new String[]{});
+
+ assertEquals(1, arrangement.select(0).size());
+ assertTrue(arrangement.select(1).isEmpty());
+ assertTrue(arrangement.selectAll().isEmpty()); // A(0,m) = 0 for m>0,A(0,0)=1 → 全排列 = 1 个空排列
+ }
+
+ // ----------------------------------------------------
+ // 扩展测试:重复元素(用于验证去重算法)
+ // 默认 Arrangement 不去重,因此应该包含重复排列
+ // ----------------------------------------------------
+ @Test
+ @Disabled("默认 Arrangement 不支持去重;启用后手动检查")
+ public void duplicateElementTest() {
+ final Arrangement arrangement = new Arrangement(new String[]{"1", "1", "3"});
+ final List list = arrangement.select(2);
+
+ // 应该有 A(3,2) = 6 个
+ assertEquals(6, list.size());
+
+ for (final String[] s : list) {
+ Console.log(s);
}
}
+
+ // ----------------------------------------------------
+ // 扩展测试:selectAll 覆盖全部不重复排列(A(n,1..n))
+ // ----------------------------------------------------
+ @Test
+ public void selectAllTest() {
+ final Arrangement arrangement = new Arrangement(new String[]{"1", "2", "3"});
+
+ final List all = arrangement.selectAll();
+
+ // 打印用于观测
+// for (final String[] s : all) {
+// Console.log(s);
+// }
+
+ // A(3,1) + A(3,2) + A(3,3) = 3 + 6 + 6 = 15
+ assertEquals(Arrangement.countAll(3), all.size());
+ assertEquals(15, all.size());
+
+ // spot check 不重复排列
+ assertArrayEquals(new String[]{"1"}, all.get(0));
+ assertArrayEquals(new String[]{"1", "2"}, all.get(3));
+ assertArrayEquals(new String[]{"1", "2", "3"}, all.get(9));
+ }
}
diff --git a/hutool-core/src/test/java/cn/hutool/v7/core/math/CombinationTest.java b/hutool-core/src/test/java/cn/hutool/v7/core/math/CombinationTest.java
index 169c29278..c5e3ffb80 100644
--- a/hutool-core/src/test/java/cn/hutool/v7/core/math/CombinationTest.java
+++ b/hutool-core/src/test/java/cn/hutool/v7/core/math/CombinationTest.java
@@ -19,8 +19,11 @@ package cn.hutool.v7.core.math;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import java.math.BigInteger;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.*;
+
/**
* 组合单元测试
*
@@ -32,23 +35,23 @@ public class CombinationTest {
@Test
public void countTest() {
long result = Combination.count(5, 2);
- Assertions.assertEquals(10, result);
+ assertEquals(10, result);
result = Combination.count(5, 5);
- Assertions.assertEquals(1, result);
+ assertEquals(1, result);
result = Combination.count(5, 0);
- Assertions.assertEquals(1, result);
+ assertEquals(1, result);
final long resultAll = Combination.countAll(5);
- Assertions.assertEquals(31, resultAll);
+ assertEquals(31, resultAll);
}
@Test
public void selectTest() {
final Combination combination = new Combination(new String[] { "1", "2", "3", "4", "5" });
final List list = combination.select(2);
- Assertions.assertEquals(Combination.count(5, 2), list.size());
+ assertEquals(Combination.count(5, 2), list.size());
Assertions.assertArrayEquals(new String[] {"1", "2"}, list.get(0));
Assertions.assertArrayEquals(new String[] {"1", "3"}, list.get(1));
@@ -62,9 +65,86 @@ public class CombinationTest {
Assertions.assertArrayEquals(new String[] {"4", "5"}, list.get(9));
final List selectAll = combination.selectAll();
- Assertions.assertEquals(Combination.countAll(5), selectAll.size());
+ assertEquals(Combination.countAll(5), selectAll.size());
final List list2 = combination.select(0);
- Assertions.assertEquals(1, list2.size());
+ assertEquals(1, list2.size());
+ }
+
+ // -----------------------------
+ // countBig() 正确性测试
+ // -----------------------------
+ @Test
+ void testCountBig_basicCases() {
+ assertEquals(BigInteger.ONE, Combination.countBig(5, 0));
+ assertEquals(BigInteger.ONE, Combination.countBig(5, 5));
+ assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 3));
+ assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 2));
+ }
+
+ @Test
+ void testCountBig_mGreaterThanN() {
+ assertEquals(BigInteger.ZERO, Combination.countBig(5, 6));
+ }
+
+ @Test
+ void testCountBig_negativeInput() {
+ assertThrows(IllegalArgumentException.class, () -> Combination.countBig(-1, 3));
+ assertThrows(IllegalArgumentException.class, () -> Combination.countBig(5, -2));
+ }
+
+ @Test
+ void testCountBig_symmetry() {
+ assertEquals(Combination.countBig(20, 3), Combination.countBig(20, 17));
+ }
+
+ @Test
+ void testCountBig_largeNumbers() {
+ // C(50, 3) = 19600
+ assertEquals(new BigInteger("19600"), Combination.countBig(50, 3));
+
+ // C(100, 50) 的确切值(重要测试)
+ final BigInteger expected = new BigInteger(
+ "100891344545564193334812497256"
+ );
+ assertEquals(expected, Combination.countBig(100, 50));
+ }
+
+ @Test
+ void testCountBig_veryLargeCombination() {
+ // 不比较具体值,只断言不要抛错
+ final BigInteger result = Combination.countBig(2000, 1000);
+ assertTrue(result.signum() > 0);
+ }
+
+ // -----------------------------
+ // count(long) 兼容性测试
+ // -----------------------------
+ @Test
+ void testCount_basic() {
+ assertEquals(10L, Combination.count(5, 3));
+ assertEquals(1L, Combination.count(5, 0));
+ assertEquals(0L, Combination.count(5, 6));
+ }
+
+ // -----------------------------
+ // countSafe() 安全 long 版本测试
+ // -----------------------------
+ @Test
+ void testCountSafe_exactFitsLong() {
+ // C(50, 3) = 19600 fits long
+ assertEquals(19600L, Combination.count(50, 3));
+ }
+
+ @Test
+ void testCountSafe_overflowThrows() {
+ // C(100, 50) 超出 long → 应抛 ArithmeticException
+ assertThrows(ArithmeticException.class, () -> Combination.count(100, 50));
+ }
+
+ @Test
+ void testCountSafe_invalidInput() {
+ assertThrows(IllegalArgumentException.class, () -> Combination.count(-1, 3));
+ assertThrows(IllegalArgumentException.class, () -> Combination.count(3, -1));
}
}