feat: add Db.updateBatch method; close #I6ZMPO

This commit is contained in:
开源海哥 2023-05-15 17:49:43 +08:00
parent 03f8e9136c
commit 63231e519f
4 changed files with 107 additions and 36 deletions

View File

@ -0,0 +1,25 @@
/**
* Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mybatisflex.core.row;
public interface BatchArgsSetter {
Object[] NONE_ARGS = new Object[0];
int getBatchSize();
Object[] getSqlArgs(int index);
}

View File

@ -86,7 +86,7 @@ public class Db {
* @param tableName 表名 * @param tableName 表名
* @param rows 数据 * @param rows 数据
*/ */
public static int[] insertBatch(String tableName, Collection<Row> rows) { public static int[] insertBatch(String tableName, List<Row> rows) {
return insertBatch(tableName, rows, rows.size()); return insertBatch(tableName, rows, rows.size());
} }
@ -97,12 +97,12 @@ public class Db {
* @param rows 数据 * @param rows 数据
* @param batchSize 每次提交的数据量 * @param batchSize 每次提交的数据量
*/ */
public static int[] insertBatch(String tableName, Collection<Row> rows, int batchSize) { public static int[] insertBatch(String tableName, List<Row> rows, int batchSize) {
return invoker().insertBatch(tableName, rows, batchSize); return invoker().insertBatch(tableName, rows, batchSize);
} }
/** /**
* 批量插入数据根据第一条内容来构建插入的字段效率比 {@link #insertBatch(String, Collection, int)} * 批量插入数据根据第一条内容来构建插入的字段效率比 {@link #insertBatch(String, List, int)}
* *
* @param tableName 表名 * @param tableName 表名
* @param rows 数据 * @param rows 数据
@ -196,6 +196,17 @@ public class Db {
} }
/**
*
* @param sql
* @param batchArgsSetter
* @return
*/
public static int[] updateBatch(String sql, BatchArgsSetter batchArgsSetter){
return invoker().updateBatch(sql, batchArgsSetter);
}
/** /**
* 根据 id 来更新数据 * 根据 id 来更新数据
* *

View File

@ -24,6 +24,7 @@ import org.apache.ibatis.session.SqlSessionFactory;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function; import java.util.function.Function;
public class RowMapperInvoker { public class RowMapperInvoker {
@ -50,39 +51,11 @@ public class RowMapperInvoker {
return execute(mapper -> mapper.insertBySql(sql, args)); return execute(mapper -> mapper.insertBySql(sql, args));
} }
public int[] insertBatch(String tableName, List<Row> rows, int batchSize) {
public int[] insertBatch(String tableName, Collection<Row> rows, int batchSize) { return executeBatch(rows.size(), batchSize, (mapper, index) -> {
int[] results = new int[rows.size()]; Row row = rows.get(index);
try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, true)) { mapper.insert(tableName, row);
RowMapper mapper = sqlSession.getMapper(RowMapper.class); });
int counter = 0;
int resultsPos = 0;
for (Row row : rows) {
if (++counter > batchSize) {
counter = 0;
List<BatchResult> batchResults = sqlSession.flushStatements();
for (BatchResult batchResult : batchResults) {
int[] updateCounts = batchResult.getUpdateCounts();
for (int updateCount : updateCounts) {
results[resultsPos++] = updateCount;
}
}
} else {
mapper.insert(tableName, row);
}
}
if (counter != 0) {
List<BatchResult> batchResults = sqlSession.flushStatements();
for (BatchResult batchResult : batchResults) {
int[] updateCounts = batchResult.getUpdateCounts();
for (int updateCount : updateCounts) {
results[resultsPos++] = updateCount;
}
}
}
}
return results;
} }
public int insertBatchWithFirstRowColumns(String tableName, List<Row> rows) { public int insertBatchWithFirstRowColumns(String tableName, List<Row> rows) {
@ -114,6 +87,48 @@ public class RowMapperInvoker {
return execute(mapper -> mapper.updateBySql(sql, args)); return execute(mapper -> mapper.updateBySql(sql, args));
} }
public int[] updateBatch(String sql, BatchArgsSetter batchArgsSetter) {
int batchSize = batchArgsSetter.getBatchSize();
return executeBatch(batchSize, batchSize,
(mapper, index) -> mapper.updateBySql(sql, batchArgsSetter.getSqlArgs(index))
);
}
public int[] executeBatch(int totalSize, int batchSize, BiConsumer<RowMapper, Integer> consumer) {
int[] results = new int[totalSize];
try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, true)) {
RowMapper mapper = sqlSession.getMapper(RowMapper.class);
int counter = 0;
int resultsPos = 0;
for (int i = 0; i < batchSize; i++) {
if (++counter > batchSize) {
counter = 0;
List<BatchResult> batchResults = sqlSession.flushStatements();
for (BatchResult batchResult : batchResults) {
int[] updateCounts = batchResult.getUpdateCounts();
for (int updateCount : updateCounts) {
results[resultsPos++] = updateCount;
}
}
} else {
consumer.accept(mapper, i);
}
}
if (counter != 0) {
List<BatchResult> batchResults = sqlSession.flushStatements();
for (BatchResult batchResult : batchResults) {
int[] updateCounts = batchResult.getUpdateCounts();
for (int updateCount : updateCounts) {
results[resultsPos++] = updateCount;
}
}
}
}
return results;
}
public int updateById(String tableName, Row row) { public int updateById(String tableName, Row row) {
return execute(mapper -> mapper.updateById(tableName, row)); return execute(mapper -> mapper.updateById(tableName, row));
} }

View File

@ -16,6 +16,7 @@
package com.mybatisflex.test; package com.mybatisflex.test;
import com.mybatisflex.core.MybatisFlexBootstrap; import com.mybatisflex.core.MybatisFlexBootstrap;
import com.mybatisflex.core.row.BatchArgsSetter;
import com.mybatisflex.core.row.Db; import com.mybatisflex.core.row.Db;
import com.mybatisflex.core.row.Row; import com.mybatisflex.core.row.Row;
import com.mybatisflex.core.row.RowKey; import com.mybatisflex.core.row.RowKey;
@ -57,6 +58,25 @@ public class DbTestStarter {
//查看刚刚插入数据的主键 id //查看刚刚插入数据的主键 id
System.out.println(">>>>>>>>>id: " + row.get("id")); System.out.println(">>>>>>>>>id: " + row.get("id"));
//INSERT INTO tb_account
//VALUES (1, '张三', 18, 0,'2020-01-11', null,0),
Db.updateBatch("insert into tb_account(user_name,age,birthday) values (?,?,?)", new BatchArgsSetter() {
@Override
public int getBatchSize() {
return 10;
}
@Override
public Object[] getSqlArgs(int index) {
Object[] args = new Object[3];
args[0] = "michael yang";
args[1] = 18 + index;
args[2] = new Date();
return args;
}
});
//再次查询全部数据 //再次查询全部数据
rows = Db.selectAll("tb_account"); rows = Db.selectAll("tb_account");