diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/MybatisFlexBootstrap.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/MybatisFlexBootstrap.java index db31351a..c48d1ed6 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/MybatisFlexBootstrap.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/MybatisFlexBootstrap.java @@ -27,25 +27,30 @@ import org.apache.ibatis.session.SqlSession; import org.apache.ibatis.session.SqlSessionFactory; import org.apache.ibatis.transaction.TransactionFactory; import org.apache.ibatis.transaction.jdbc.JdbcTransactionFactory; +import org.apache.ibatis.util.MapUtil; import javax.sql.DataSource; +import java.lang.reflect.Proxy; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import java.util.function.Supplier; /** * MybatisFlex 的启动类 * * * MybatisFlexBootstrap.getInstance() - * .setDatasource(...) - * .addMapper(...) - * .start(); + * .setDatasource(...) + * .addMapper(...) + * .start(); *

*

* MybatisFlexBootstrap.getInstance() - * .execute(...) + * .execute(...) * */ public class MybatisFlexBootstrap { @@ -62,6 +67,8 @@ public class MybatisFlexBootstrap { protected DbType dbType; protected SqlSessionFactory sqlSessionFactory; protected Class logImpl; + private Map, Object> mapperObjects = new ConcurrentHashMap<>(); + private ThreadLocal sessionThreadLocal = new ThreadLocal<>(); /** * 虽然提供了 getInstance,但也允许用户进行实例化, @@ -132,6 +139,7 @@ public class MybatisFlexBootstrap { } + @Deprecated public R execute(Class mapperClass, Function function) { try (SqlSession sqlSession = openSession()) { DialectFactory.setHintDbType(dbType); @@ -143,11 +151,66 @@ public class MybatisFlexBootstrap { } - private SqlSession openSession() { + protected SqlSession openSession() { + SqlSession sqlSession = sessionThreadLocal.get(); + if (sqlSession != null) { + return sqlSession; + } return sqlSessionFactory.openSession(configuration.getDefaultExecutorType(), true); } + /** + * 直接获取 mapper 对象执行 + * @param mapperClass + * @return mapperObject + */ + public T getMapper(Class mapperClass) { + Object mapperObject = MapUtil.computeIfAbsent(mapperObjects, mapperClass, clazz -> + Proxy.newProxyInstance(MybatisFlexBootstrap.class.getClassLoader() + , new Class[]{mapperClass} + , (proxy, method, args) -> { + try (SqlSession sqlSession = openSession()) { + DialectFactory.setHintDbType(dbType); + T mapper1 = sqlSession.getMapper(mapperClass); + return method.invoke(mapper1, args); + } finally { + DialectFactory.clearHintDbType(); + } + })); + return (T) mapperObject; + } + + + /** + * 执行事务操作,不支持嵌套事务 + * + * @param supplier + * @return false 回滚事务,true 正常执行 + */ + public boolean tx(Supplier supplier) { + SqlSession sqlSession = sqlSessionFactory.openSession(configuration.getDefaultExecutorType()); + boolean success = false; + boolean rollback = true; + try { + sessionThreadLocal.set(sqlSession); + success = supplier.get(); + } catch (Throwable e) { + rollback = false; + sqlSession.rollback(); + } finally { + sessionThreadLocal.remove(); + if (!success && rollback) { + sqlSession.rollback(); + } else if (success) { + sqlSession.commit(); + } + } + return success; + } + + + public String getEnvironmentId() { return environmentId; } diff --git a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/EntityTestStarter.java b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/EntityTestStarter.java index a479ec9f..733cf98c 100644 --- a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/EntityTestStarter.java +++ b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/EntityTestStarter.java @@ -39,10 +39,13 @@ public class EntityTestStarter { // //查询 ID 为 1 的数据 - Account account = bootstrap.execute(AccountMapper.class, accountMapper -> - accountMapper.selectOneById(1)); - System.out.println(account); +// Account account = bootstrap.execute(AccountMapper.class, accountMapper -> +// accountMapper.selectOneById(1)); +// System.out.println(account); + AccountMapper accountMapper = bootstrap.getMapper(AccountMapper.class); + Account account = accountMapper.selectOneById(1); + System.out.println(account); // // List allAccount = bootstrap.execute(AccountMapper.class, accountMapper -> // accountMapper.selectListByQuery(QueryWrapper.create())); diff --git a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/RowTestStarter.java b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/RowTestStarter.java index 4cca784a..5a95a270 100644 --- a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/RowTestStarter.java +++ b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/RowTestStarter.java @@ -52,6 +52,7 @@ public class RowTestStarter { .set("user_name", "lisi") .set("age", 22) .set("birthday", new Date()); + bootstrap.execute(RowMapper.class, rowMapper -> rowMapper.insert("tb_account", newRow));