From 012381510d251bfdfad8df8274a9bb1d1dd4d013 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 23 Aug 2024 16:04:18 +0800 Subject: [PATCH] feat: add support nested types for MappedStatementTypes.java --- .../core/mybatis/MappedStatementTypes.java | 17 +++++++--- .../core/util/MappedStatementTypesTest.java | 32 +++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 mybatis-flex-core/src/test/java/com/mybatisflex/core/util/MappedStatementTypesTest.java diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/mybatis/MappedStatementTypes.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/mybatis/MappedStatementTypes.java index bcba101c..903a2cbf 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/mybatis/MappedStatementTypes.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/mybatis/MappedStatementTypes.java @@ -15,23 +15,32 @@ */ package com.mybatisflex.core.mybatis; +import java.util.Stack; + public class MappedStatementTypes { private MappedStatementTypes() { } - private static final ThreadLocal> currentTypeTL = new ThreadLocal<>(); + private static final ThreadLocal>> currentTypeTL = ThreadLocal.withInitial(Stack::new); public static void setCurrentType(Class type) { - currentTypeTL.set(type); + currentTypeTL.get().push(type); } public static Class getCurrentType() { - return currentTypeTL.get(); + Stack> stack = currentTypeTL.get(); + return stack.isEmpty() ? null : stack.lastElement(); } public static void clear() { - currentTypeTL.remove(); + Stack> stack = currentTypeTL.get(); + if (!stack.isEmpty()) { + stack.pop(); + } + if (stack.isEmpty()) { + currentTypeTL.remove(); + } } } diff --git a/mybatis-flex-core/src/test/java/com/mybatisflex/core/util/MappedStatementTypesTest.java b/mybatis-flex-core/src/test/java/com/mybatisflex/core/util/MappedStatementTypesTest.java new file mode 100644 index 00000000..25e6330f --- /dev/null +++ b/mybatis-flex-core/src/test/java/com/mybatisflex/core/util/MappedStatementTypesTest.java @@ -0,0 +1,32 @@ +package com.mybatisflex.core.util; + +import com.mybatisflex.core.mybatis.MappedStatementTypes; +import org.junit.Assert; +import org.junit.Test; + +public class MappedStatementTypesTest { + + @Test + public void test() { + MappedStatementTypes.clear(); + + MappedStatementTypes.setCurrentType(String.class); + MappedStatementTypes.setCurrentType(MappedStatementTypesTest.class); + MappedStatementTypes.setCurrentType(StringUtilTest.class); + + Assert.assertEquals(StringUtilTest.class, MappedStatementTypes.getCurrentType()); + System.out.println(MappedStatementTypes.getCurrentType()); + MappedStatementTypes.clear(); + + Assert.assertEquals(MappedStatementTypesTest.class, MappedStatementTypes.getCurrentType()); + System.out.println(MappedStatementTypes.getCurrentType()); + MappedStatementTypes.clear(); + + Assert.assertEquals(String.class, MappedStatementTypes.getCurrentType()); + System.out.println(MappedStatementTypes.getCurrentType()); + MappedStatementTypes.clear(); + + Assert.assertNull(MappedStatementTypes.getCurrentType()); + System.out.println(MappedStatementTypes.getCurrentType()); + } +}