feat: add "Load Balance" for multi dataSource

This commit is contained in:
开源海哥 2023-07-27 10:35:43 +08:00
parent 665f3fc7a4
commit 66d50d1621
2 changed files with 27 additions and 7 deletions

View File

@ -25,7 +25,7 @@ public class DataSourceKey {
private static final ThreadLocal<String> keyThreadLocal = new ThreadLocal<>(); private static final ThreadLocal<String> keyThreadLocal = new ThreadLocal<>();
public static void use(String dataSourceKey) { public static void use(String dataSourceKey) {
keyThreadLocal.set(dataSourceKey); keyThreadLocal.set(dataSourceKey.trim());
} }
public static <T> T use(String dataSourceKey, Supplier<T> supplier) { public static <T> T use(String dataSourceKey, Supplier<T> supplier) {

View File

@ -30,15 +30,15 @@ import java.lang.reflect.Method;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
import java.sql.Connection; import java.sql.Connection;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.HashMap; import java.util.*;
import java.util.Map; import java.util.concurrent.ThreadLocalRandom;
import java.util.Objects;
/** /**
* @author michael * @author michael
*/ */
public class FlexDataSource extends AbstractDataSource { public class FlexDataSource extends AbstractDataSource {
private static final char LOAD_BALANCE_KEY_SUFFIX = '*';
private static final Log log = LogFactory.getLog(FlexDataSource.class); private static final Log log = LogFactory.getLog(FlexDataSource.class);
private final Map<String, DataSource> dataSourceMap = new HashMap<>(); private final Map<String, DataSource> dataSourceMap = new HashMap<>();
@ -194,9 +194,29 @@ public class FlexDataSource extends AbstractDataSource {
if (dataSourceMap.size() > 1) { if (dataSourceMap.size() > 1) {
String dataSourceKey = DataSourceKey.get(); String dataSourceKey = DataSourceKey.get();
if (StringUtil.isNotBlank(dataSourceKey)) { if (StringUtil.isNotBlank(dataSourceKey)) {
//负载均衡 key
if (dataSourceKey.charAt(dataSourceKey.length() - 1) == LOAD_BALANCE_KEY_SUFFIX) {
String prefix = dataSourceKey.substring(0, dataSourceKey.length() - 1);
List<String> matchedKeys = new ArrayList<>();
for (String key : dataSourceMap.keySet()) {
if (key.startsWith(prefix)) {
matchedKeys.add(key);
}
}
if (matchedKeys.isEmpty()) {
throw new IllegalStateException("Can not matched dataSource by key: \"" + dataSourceKey + "\"");
}
String randomKey = matchedKeys.get(ThreadLocalRandom.current().nextInt(matchedKeys.size()));
return dataSourceMap.get(randomKey);
}
//非负载均衡 key
else {
dataSource = dataSourceMap.get(dataSourceKey); dataSource = dataSourceMap.get(dataSourceKey);
if (dataSource == null) { if (dataSource == null) {
throw new IllegalStateException("Cannot get target DataSource for dataSourceKey [" + dataSourceKey + "]"); throw new IllegalStateException("Cannot get target dataSource by key: \"" + dataSourceKey + "\"");
}
} }
} }
} }