diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/BatchArgsSetter.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/BatchArgsSetter.java
new file mode 100644
index 00000000..5a58e8e6
--- /dev/null
+++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/BatchArgsSetter.java
@@ -0,0 +1,25 @@
+/**
+ * Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.mybatisflex.core.row;
+
+public interface BatchArgsSetter {
+
+ Object[] NONE_ARGS = new Object[0];
+
+ int getBatchSize();
+
+ Object[] getSqlArgs(int index);
+}
diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/Db.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/Db.java
index ce0ff1b1..428eb127 100644
--- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/Db.java
+++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/Db.java
@@ -86,7 +86,7 @@ public class Db {
* @param tableName 表名
* @param rows 数据
*/
- public static int[] insertBatch(String tableName, Collection rows) {
+ public static int[] insertBatch(String tableName, List rows) {
return insertBatch(tableName, rows, rows.size());
}
@@ -97,12 +97,12 @@ public class Db {
* @param rows 数据
* @param batchSize 每次提交的数据量
*/
- public static int[] insertBatch(String tableName, Collection rows, int batchSize) {
+ public static int[] insertBatch(String tableName, List rows, int batchSize) {
return invoker().insertBatch(tableName, rows, batchSize);
}
/**
- * 批量插入数据,根据第一条内容来构建插入的字段,效率比 {@link #insertBatch(String, Collection, int)} 高
+ * 批量插入数据,根据第一条内容来构建插入的字段,效率比 {@link #insertBatch(String, List, int)} 高
*
* @param tableName 表名
* @param rows 数据
@@ -196,6 +196,17 @@ public class Db {
}
+ /**
+ *
+ * @param sql
+ * @param batchArgsSetter
+ * @return
+ */
+ public static int[] updateBatch(String sql, BatchArgsSetter batchArgsSetter){
+ return invoker().updateBatch(sql, batchArgsSetter);
+ }
+
+
/**
* 根据 id 来更新数据
*
diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/RowMapperInvoker.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/RowMapperInvoker.java
index 354c2e40..a6279056 100644
--- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/RowMapperInvoker.java
+++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/RowMapperInvoker.java
@@ -24,6 +24,7 @@ import org.apache.ibatis.session.SqlSessionFactory;
import java.util.Collection;
import java.util.List;
+import java.util.function.BiConsumer;
import java.util.function.Function;
public class RowMapperInvoker {
@@ -50,39 +51,11 @@ public class RowMapperInvoker {
return execute(mapper -> mapper.insertBySql(sql, args));
}
-
- public int[] insertBatch(String tableName, Collection rows, int batchSize) {
- int[] results = new int[rows.size()];
- try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, true)) {
- RowMapper mapper = sqlSession.getMapper(RowMapper.class);
- int counter = 0;
- int resultsPos = 0;
- for (Row row : rows) {
- if (++counter > batchSize) {
- counter = 0;
- List batchResults = sqlSession.flushStatements();
- for (BatchResult batchResult : batchResults) {
- int[] updateCounts = batchResult.getUpdateCounts();
- for (int updateCount : updateCounts) {
- results[resultsPos++] = updateCount;
- }
- }
- } else {
- mapper.insert(tableName, row);
- }
- }
-
- if (counter != 0) {
- List batchResults = sqlSession.flushStatements();
- for (BatchResult batchResult : batchResults) {
- int[] updateCounts = batchResult.getUpdateCounts();
- for (int updateCount : updateCounts) {
- results[resultsPos++] = updateCount;
- }
- }
- }
- }
- return results;
+ public int[] insertBatch(String tableName, List rows, int batchSize) {
+ return executeBatch(rows.size(), batchSize, (mapper, index) -> {
+ Row row = rows.get(index);
+ mapper.insert(tableName, row);
+ });
}
public int insertBatchWithFirstRowColumns(String tableName, List rows) {
@@ -114,6 +87,48 @@ public class RowMapperInvoker {
return execute(mapper -> mapper.updateBySql(sql, args));
}
+ public int[] updateBatch(String sql, BatchArgsSetter batchArgsSetter) {
+ int batchSize = batchArgsSetter.getBatchSize();
+ return executeBatch(batchSize, batchSize,
+ (mapper, index) -> mapper.updateBySql(sql, batchArgsSetter.getSqlArgs(index))
+ );
+ }
+
+
+ public int[] executeBatch(int totalSize, int batchSize, BiConsumer consumer) {
+ int[] results = new int[totalSize];
+ try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, true)) {
+ RowMapper mapper = sqlSession.getMapper(RowMapper.class);
+ int counter = 0;
+ int resultsPos = 0;
+ for (int i = 0; i < batchSize; i++) {
+ if (++counter > batchSize) {
+ counter = 0;
+ List batchResults = sqlSession.flushStatements();
+ for (BatchResult batchResult : batchResults) {
+ int[] updateCounts = batchResult.getUpdateCounts();
+ for (int updateCount : updateCounts) {
+ results[resultsPos++] = updateCount;
+ }
+ }
+ } else {
+ consumer.accept(mapper, i);
+ }
+ }
+
+ if (counter != 0) {
+ List batchResults = sqlSession.flushStatements();
+ for (BatchResult batchResult : batchResults) {
+ int[] updateCounts = batchResult.getUpdateCounts();
+ for (int updateCount : updateCounts) {
+ results[resultsPos++] = updateCount;
+ }
+ }
+ }
+ }
+ return results;
+ }
+
public int updateById(String tableName, Row row) {
return execute(mapper -> mapper.updateById(tableName, row));
}
diff --git a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTestStarter.java b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTestStarter.java
index 7dfb70ac..b38459bf 100644
--- a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTestStarter.java
+++ b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTestStarter.java
@@ -16,6 +16,7 @@
package com.mybatisflex.test;
import com.mybatisflex.core.MybatisFlexBootstrap;
+import com.mybatisflex.core.row.BatchArgsSetter;
import com.mybatisflex.core.row.Db;
import com.mybatisflex.core.row.Row;
import com.mybatisflex.core.row.RowKey;
@@ -57,6 +58,25 @@ public class DbTestStarter {
//查看刚刚插入数据的主键 id
System.out.println(">>>>>>>>>id: " + row.get("id"));
+ //INSERT INTO tb_account
+ //VALUES (1, '张三', 18, 0,'2020-01-11', null,0),
+
+ Db.updateBatch("insert into tb_account(user_name,age,birthday) values (?,?,?)", new BatchArgsSetter() {
+ @Override
+ public int getBatchSize() {
+ return 10;
+ }
+
+ @Override
+ public Object[] getSqlArgs(int index) {
+ Object[] args = new Object[3];
+ args[0] = "michael yang";
+ args[1] = 18 + index;
+ args[2] = new Date();
+ return args;
+ }
+ });
+
//再次查询全部数据
rows = Db.selectAll("tb_account");