add Db.tx() method

This commit is contained in:
开源海哥 2023-03-31 18:22:10 +08:00
parent 02cc1d0263
commit 5ff724e119
6 changed files with 132 additions and 6 deletions

View File

@ -19,13 +19,19 @@ import com.mybatisflex.core.dialect.DbType;
import com.mybatisflex.core.dialect.DbTypeUtil; import com.mybatisflex.core.dialect.DbTypeUtil;
import com.mybatisflex.core.transaction.TransactionContext; import com.mybatisflex.core.transaction.TransactionContext;
import com.mybatisflex.core.transaction.TransactionalManager; import com.mybatisflex.core.transaction.TransactionalManager;
import com.mybatisflex.core.util.ArrayUtil;
import com.mybatisflex.core.util.StringUtil; import com.mybatisflex.core.util.StringUtil;
import javax.sql.DataSource; import javax.sql.DataSource;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection; import java.sql.Connection;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects;
public class FlexDataSource extends AbstractDataSource { public class FlexDataSource extends AbstractDataSource {
@ -64,7 +70,7 @@ public class FlexDataSource extends AbstractDataSource {
if (connection != null) { if (connection != null) {
return connection; return connection;
} else { } else {
connection = getDataSource().getConnection(); connection = proxy(getDataSource().getConnection(), xid);
TransactionalManager.hold(xid, dataSourceKey, connection); TransactionalManager.hold(xid, dataSourceKey, connection);
return connection; return connection;
} }
@ -86,7 +92,7 @@ public class FlexDataSource extends AbstractDataSource {
if (connection != null) { if (connection != null) {
return connection; return connection;
} else { } else {
connection = getDataSource().getConnection(username, password); connection = proxy(getDataSource().getConnection(username, password), xid);
TransactionalManager.hold(xid, dataSourceKey, connection); TransactionalManager.hold(xid, dataSourceKey, connection);
return connection; return connection;
} }
@ -95,6 +101,12 @@ public class FlexDataSource extends AbstractDataSource {
} }
} }
public Connection proxy(Connection connection, String xid) {
return (Connection) Proxy.newProxyInstance(FlexDataSource.class.getClassLoader()
, new Class[]{Connection.class}
, new ConnectionHandler(connection, xid));
}
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -125,5 +137,31 @@ public class FlexDataSource extends AbstractDataSource {
return dataSource; return dataSource;
} }
private static class ConnectionHandler implements InvocationHandler {
private static String[] proxyMethods = new String[]{"commit", "rollback", "close",};
private Connection original;
private String xid;
public ConnectionHandler(Connection original, String xid) {
this.original = original;
this.xid = xid;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if (ArrayUtil.contains(proxyMethods, method.getName())
&& isTransactional()) {
//do nothing
return null;
}
System.out.println(">>>>>>invoke: " + method.getName() + " args: " + Arrays.toString(args));
return method.invoke(original, args);
}
private boolean isTransactional() {
return Objects.equals(xid, TransactionContext.getXID());
}
}
} }

View File

@ -399,23 +399,26 @@ public class Db {
try { try {
String xid = UUID.randomUUID().toString(); String xid = UUID.randomUUID().toString();
TransactionContext.hold(xid); TransactionContext.hold(xid);
boolean success = false; Boolean success = false;
boolean rollbacked = false; boolean rollbacked = false;
try { try {
success = supplier.get(); success = supplier.get();
} catch (Exception e) { } catch (Exception e) {
rollbacked = true; rollbacked = true;
TransactionContext.release();
TransactionalManager.rollback(xid); TransactionalManager.rollback(xid);
e.printStackTrace(); e.printStackTrace();
} finally { } finally {
if (success) { if (success != null && success) {
//必须优先 release xid才能正常 commit()
TransactionContext.release();
TransactionalManager.commit(xid); TransactionalManager.commit(xid);
} else if (!rollbacked) { } else if (!rollbacked) {
TransactionContext.release();
TransactionalManager.rollback(xid); TransactionalManager.rollback(xid);
} }
TransactionContext.release();
} }
return success; return success != null && success;
} finally { } finally {
//恢复上一级事务 //恢复上一级事务
if (prevXID != null) { if (prevXID != null) {

View File

@ -53,6 +53,7 @@ public class TransactionalManager {
log.debug("Error set AutoCommit to false. Cause: " + e); log.debug("Error set AutoCommit to false. Cause: " + e);
} }
} }
connMap.put(ds, connection);
} }

View File

@ -0,0 +1,73 @@
/**
* Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mybatisflex.test;
import com.mybatisflex.core.MybatisFlexBootstrap;
import com.mybatisflex.core.datasource.DataSourceKey;
import com.mybatisflex.core.row.Db;
import com.mybatisflex.core.row.Row;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import javax.sql.DataSource;
import java.util.List;
public class MultiDataSourceTester {
public static void main(String[] args) {
DataSource dataSource = new EmbeddedDatabaseBuilder()
.setType(EmbeddedDatabaseType.H2)
.setName("db1")
.addScript("schema.sql")
.addScript("data.sql")
.build();
DataSource dataSource2 = new EmbeddedDatabaseBuilder()
.setType(EmbeddedDatabaseType.H2)
.setName("db2")
.addScript("schema02.sql")
.addScript("data02.sql")
.build();
MybatisFlexBootstrap.getInstance()
.setDataSource(dataSource)
.addDataSource("ds2", dataSource2)
.start();
//默认查询 db1
List<Row> rows = Db.selectAll("tb_account");
System.out.println(rows);
System.out.println("------");
//查询数据源 ds2
DataSourceKey.use("ds2");
rows = Db.selectAll("tb_account");
System.out.println(rows);
boolean success = Db.tx(() -> {
Db.updateById("tb_account",Row.ofKey("id",1)
.set("user_name","测试的user"));
return false;
});
System.out.println("tx: " + success);
rows = Db.selectAll("tb_account");
System.out.println(rows);
}
}

View File

@ -0,0 +1,3 @@
INSERT INTO tb_account
VALUES (1, 'zhang', 18, '2020-01-11', null),
(2, 'wang', 19, '2021-03-21', null);

View File

@ -0,0 +1,8 @@
CREATE TABLE IF NOT EXISTS `tb_account`
(
`id` INTEGER PRIMARY KEY auto_increment,
`user_name` VARCHAR(100),
`age` Integer,
`birthday` DATETIME,
`options` VARCHAR(1024)
);