diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/BatchArgsSetter.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/BatchArgsSetter.java new file mode 100644 index 00000000..5a58e8e6 --- /dev/null +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/BatchArgsSetter.java @@ -0,0 +1,25 @@ +/** + * Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com). + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mybatisflex.core.row; + +public interface BatchArgsSetter { + + Object[] NONE_ARGS = new Object[0]; + + int getBatchSize(); + + Object[] getSqlArgs(int index); +} diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/Db.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/Db.java index ce0ff1b1..428eb127 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/Db.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/Db.java @@ -86,7 +86,7 @@ public class Db { * @param tableName 表名 * @param rows 数据 */ - public static int[] insertBatch(String tableName, Collection rows) { + public static int[] insertBatch(String tableName, List rows) { return insertBatch(tableName, rows, rows.size()); } @@ -97,12 +97,12 @@ public class Db { * @param rows 数据 * @param batchSize 每次提交的数据量 */ - public static int[] insertBatch(String tableName, Collection rows, int batchSize) { + public static int[] insertBatch(String tableName, List rows, int batchSize) { return invoker().insertBatch(tableName, rows, batchSize); } /** - * 批量插入数据,根据第一条内容来构建插入的字段,效率比 {@link #insertBatch(String, Collection, int)} 高 + * 批量插入数据,根据第一条内容来构建插入的字段,效率比 {@link #insertBatch(String, List, int)} 高 * * @param tableName 表名 * @param rows 数据 @@ -196,6 +196,17 @@ public class Db { } + /** + * + * @param sql + * @param batchArgsSetter + * @return + */ + public static int[] updateBatch(String sql, BatchArgsSetter batchArgsSetter){ + return invoker().updateBatch(sql, batchArgsSetter); + } + + /** * 根据 id 来更新数据 * diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/RowMapperInvoker.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/RowMapperInvoker.java index 354c2e40..a6279056 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/RowMapperInvoker.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/row/RowMapperInvoker.java @@ -24,6 +24,7 @@ import org.apache.ibatis.session.SqlSessionFactory; import java.util.Collection; import java.util.List; +import java.util.function.BiConsumer; import java.util.function.Function; public class RowMapperInvoker { @@ -50,39 +51,11 @@ public class RowMapperInvoker { return execute(mapper -> mapper.insertBySql(sql, args)); } - - public int[] insertBatch(String tableName, Collection rows, int batchSize) { - int[] results = new int[rows.size()]; - try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, true)) { - RowMapper mapper = sqlSession.getMapper(RowMapper.class); - int counter = 0; - int resultsPos = 0; - for (Row row : rows) { - if (++counter > batchSize) { - counter = 0; - List batchResults = sqlSession.flushStatements(); - for (BatchResult batchResult : batchResults) { - int[] updateCounts = batchResult.getUpdateCounts(); - for (int updateCount : updateCounts) { - results[resultsPos++] = updateCount; - } - } - } else { - mapper.insert(tableName, row); - } - } - - if (counter != 0) { - List batchResults = sqlSession.flushStatements(); - for (BatchResult batchResult : batchResults) { - int[] updateCounts = batchResult.getUpdateCounts(); - for (int updateCount : updateCounts) { - results[resultsPos++] = updateCount; - } - } - } - } - return results; + public int[] insertBatch(String tableName, List rows, int batchSize) { + return executeBatch(rows.size(), batchSize, (mapper, index) -> { + Row row = rows.get(index); + mapper.insert(tableName, row); + }); } public int insertBatchWithFirstRowColumns(String tableName, List rows) { @@ -114,6 +87,48 @@ public class RowMapperInvoker { return execute(mapper -> mapper.updateBySql(sql, args)); } + public int[] updateBatch(String sql, BatchArgsSetter batchArgsSetter) { + int batchSize = batchArgsSetter.getBatchSize(); + return executeBatch(batchSize, batchSize, + (mapper, index) -> mapper.updateBySql(sql, batchArgsSetter.getSqlArgs(index)) + ); + } + + + public int[] executeBatch(int totalSize, int batchSize, BiConsumer consumer) { + int[] results = new int[totalSize]; + try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, true)) { + RowMapper mapper = sqlSession.getMapper(RowMapper.class); + int counter = 0; + int resultsPos = 0; + for (int i = 0; i < batchSize; i++) { + if (++counter > batchSize) { + counter = 0; + List batchResults = sqlSession.flushStatements(); + for (BatchResult batchResult : batchResults) { + int[] updateCounts = batchResult.getUpdateCounts(); + for (int updateCount : updateCounts) { + results[resultsPos++] = updateCount; + } + } + } else { + consumer.accept(mapper, i); + } + } + + if (counter != 0) { + List batchResults = sqlSession.flushStatements(); + for (BatchResult batchResult : batchResults) { + int[] updateCounts = batchResult.getUpdateCounts(); + for (int updateCount : updateCounts) { + results[resultsPos++] = updateCount; + } + } + } + } + return results; + } + public int updateById(String tableName, Row row) { return execute(mapper -> mapper.updateById(tableName, row)); } diff --git a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTestStarter.java b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTestStarter.java index 7dfb70ac..b38459bf 100644 --- a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTestStarter.java +++ b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTestStarter.java @@ -16,6 +16,7 @@ package com.mybatisflex.test; import com.mybatisflex.core.MybatisFlexBootstrap; +import com.mybatisflex.core.row.BatchArgsSetter; import com.mybatisflex.core.row.Db; import com.mybatisflex.core.row.Row; import com.mybatisflex.core.row.RowKey; @@ -57,6 +58,25 @@ public class DbTestStarter { //查看刚刚插入数据的主键 id System.out.println(">>>>>>>>>id: " + row.get("id")); + //INSERT INTO tb_account + //VALUES (1, '张三', 18, 0,'2020-01-11', null,0), + + Db.updateBatch("insert into tb_account(user_name,age,birthday) values (?,?,?)", new BatchArgsSetter() { + @Override + public int getBatchSize() { + return 10; + } + + @Override + public Object[] getSqlArgs(int index) { + Object[] args = new Object[3]; + args[0] = "michael yang"; + args[1] = 18 + index; + args[2] = new Date(); + return args; + } + }); + //再次查询全部数据 rows = Db.selectAll("tb_account");