diff --git a/mybatis-flex-spring/src/main/java/com/mybatisflex/spring/datasource/DataSourceInterceptor.java b/mybatis-flex-spring/src/main/java/com/mybatisflex/spring/datasource/DataSourceInterceptor.java index 74dd63ec..5face08d 100644 --- a/mybatis-flex-spring/src/main/java/com/mybatisflex/spring/datasource/DataSourceInterceptor.java +++ b/mybatis-flex-spring/src/main/java/com/mybatisflex/spring/datasource/DataSourceInterceptor.java @@ -19,6 +19,7 @@ package com.mybatisflex.spring.datasource; import com.mybatisflex.annotation.UseDataSource; import com.mybatisflex.core.datasource.DataSourceKey; +import com.mybatisflex.core.util.StringUtil; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; @@ -32,7 +33,17 @@ public class DataSourceInterceptor implements MethodInterceptor { @Override public Object invoke(MethodInvocation invocation) throws Throwable { - String dsKey = determineDataSourceKey(invocation); + + String dsKey = DataSourceKey.get(); + if (StringUtil.isNotBlank(dsKey)) { + return invocation.proceed(); + } + + dsKey = determineDataSourceKey(invocation); + if (StringUtil.isBlank(dsKey)) { + return invocation.proceed(); + } + DataSourceKey.use(dsKey); try { return invocation.proceed(); @@ -42,27 +53,35 @@ public class DataSourceInterceptor implements MethodInterceptor { } private String determineDataSourceKey(MethodInvocation invocation) { - UseDataSource annotation; - - Object aThis = invocation.getThis(); - - if (aThis != null) { - // 类上定义有 UseDataSource 注解 - Class aClass = aThis.getClass(); - annotation = aClass.getAnnotation(UseDataSource.class); - if (annotation != null) { - return annotation.value(); - } - } // 方法上定义有 UseDataSource 注解 - annotation = invocation.getMethod().getAnnotation(UseDataSource.class); + UseDataSource annotation = invocation.getMethod().getAnnotation(UseDataSource.class); if (annotation != null) { return annotation.value(); } - // 没有的话使用当前数据源 - return DataSourceKey.get(); + Object target = invocation.getThis(); + + if (target != null) { + // 类上定义有 UseDataSource 注解 + Class targetClass = target.getClass(); + annotation = targetClass.getAnnotation(UseDataSource.class); + if (annotation != null) { + return annotation.value(); + } + + // 接口上定义有 UseDataSource 注解 + Class[] interfaces = targetClass.getInterfaces(); + for (Class anInterface : interfaces) { + annotation = anInterface.getAnnotation(UseDataSource.class); + if (annotation != null) { + return annotation.value(); + } + } + } + + + return null; } } \ No newline at end of file