diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/DataSourceKey.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/DataSourceKey.java index bc87eb83..97da09f7 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/DataSourceKey.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/DataSourceKey.java @@ -25,7 +25,7 @@ public class DataSourceKey { private static final ThreadLocal keyThreadLocal = new ThreadLocal<>(); public static void use(String dataSourceKey) { - keyThreadLocal.set(dataSourceKey); + keyThreadLocal.set(dataSourceKey.trim()); } public static T use(String dataSourceKey, Supplier supplier) { diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/FlexDataSource.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/FlexDataSource.java index 32949192..29c7472b 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/FlexDataSource.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/FlexDataSource.java @@ -30,15 +30,15 @@ import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.sql.Connection; import java.sql.SQLException; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; /** * @author michael */ public class FlexDataSource extends AbstractDataSource { + private static final char LOAD_BALANCE_KEY_SUFFIX = '*'; private static final Log log = LogFactory.getLog(FlexDataSource.class); private final Map dataSourceMap = new HashMap<>(); @@ -194,9 +194,29 @@ public class FlexDataSource extends AbstractDataSource { if (dataSourceMap.size() > 1) { String dataSourceKey = DataSourceKey.get(); if (StringUtil.isNotBlank(dataSourceKey)) { - dataSource = dataSourceMap.get(dataSourceKey); - if (dataSource == null) { - throw new IllegalStateException("Cannot get target DataSource for dataSourceKey [" + dataSourceKey + "]"); + //负载均衡 key + if (dataSourceKey.charAt(dataSourceKey.length() - 1) == LOAD_BALANCE_KEY_SUFFIX) { + String prefix = dataSourceKey.substring(0, dataSourceKey.length() - 1); + List matchedKeys = new ArrayList<>(); + for (String key : dataSourceMap.keySet()) { + if (key.startsWith(prefix)) { + matchedKeys.add(key); + } + } + + if (matchedKeys.isEmpty()) { + throw new IllegalStateException("Can not matched dataSource by key: \"" + dataSourceKey + "\""); + } + + String randomKey = matchedKeys.get(ThreadLocalRandom.current().nextInt(matchedKeys.size())); + return dataSourceMap.get(randomKey); + } + //非负载均衡 key + else { + dataSource = dataSourceMap.get(dataSourceKey); + if (dataSource == null) { + throw new IllegalStateException("Cannot get target dataSource by key: \"" + dataSourceKey + "\""); + } } } }