diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/FlexDataSource.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/FlexDataSource.java index 10dafd79..e28ca7e2 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/FlexDataSource.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/datasource/FlexDataSource.java @@ -53,16 +53,23 @@ public class FlexDataSource extends AbstractDataSource { } public FlexDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) { + this(dataSourceKey, dataSource, DbTypeUtil.getDbType(dataSource), needDecryptDataSource); + } + + public FlexDataSource(String dataSourceKey, DataSource dataSource, DbType dbType, boolean needDecryptDataSource){ if (needDecryptDataSource) { DataSourceManager.decryptDataSource(dataSource); } + // 处理dbType + dbType = Optional.ofNullable(dbType).orElse(DbTypeUtil.getDbType(dataSource)); + this.defaultDataSourceKey = dataSourceKey; this.defaultDataSource = dataSource; - this.defaultDbType = DbTypeUtil.getDbType(dataSource); + this.defaultDbType = dbType; dataSourceMap.put(dataSourceKey, dataSource); - dbTypeHashMap.put(dataSourceKey, defaultDbType); + dbTypeHashMap.put(dataSourceKey, dbType); } /** @@ -71,26 +78,37 @@ public class FlexDataSource extends AbstractDataSource { public void setDefaultDataSource(String dataSourceKey) { DataSource ds = dataSourceMap.get(dataSourceKey); - if (ds != null) { - this.defaultDataSourceKey = dataSourceKey; - this.defaultDataSource = ds; - this.defaultDbType = DbTypeUtil.getDbType(ds); - } else { + if (Objects.isNull(ds)) { throw new IllegalStateException("DataSource not found by key: \"" + dataSourceKey + "\""); } + + // 优先取缓存,否则根据数据源返回数据库类型 + DbType dbType = Optional.ofNullable(dbTypeHashMap.get(dataSourceKey)) + .orElse(DbTypeUtil.getDbType(ds)); + + this.defaultDataSourceKey = dataSourceKey; + this.defaultDataSource = ds; + this.defaultDbType = dbType; } public void addDataSource(String dataSourceKey, DataSource dataSource) { addDataSource(dataSourceKey, dataSource, true); } - public void addDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) { + addDataSource(dataSourceKey, dataSource, DbTypeUtil.getDbType(dataSource), needDecryptDataSource); + } + + public void addDataSource(String dataSourceKey, DataSource dataSource, DbType dbType,boolean needDecryptDataSource) { if (needDecryptDataSource) { DataSourceManager.decryptDataSource(dataSource); } + + dbType = Optional.ofNullable(dbTypeHashMap.get(dataSourceKey)) + .orElse(DbTypeUtil.getDbType(dataSource)); + dataSourceMap.put(dataSourceKey, dataSource); - dbTypeHashMap.put(dataSourceKey, DbTypeUtil.getDbType(dataSource)); + dbTypeHashMap.put(dataSourceKey, dbType); } diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/dialect/DbType.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/dialect/DbType.java index 4b260ee7..2f273804 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/dialect/DbType.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/dialect/DbType.java @@ -16,6 +16,10 @@ package com.mybatisflex.core.dialect; +import com.mybatisflex.core.util.StringUtil; + +import java.util.Arrays; + public enum DbType { /** @@ -263,4 +267,21 @@ public enum DbType { public String getName() { return name; } + + /** + * 根据数据库类型名称自动识别数据库类型 + * + * @param name 名称 + * @return 数据库类型 + */ + public static DbType findByName(String name) { + if (StringUtil.noText(name)) { + return null; + } + + return Arrays.stream(values()) + .filter(em -> em.getName().equalsIgnoreCase(name)) + .findFirst() + .orElse(null); + } } diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/dialect/impl/CommonsDialectImpl.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/dialect/impl/CommonsDialectImpl.java index 16eb206f..a4cdb4db 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/dialect/impl/CommonsDialectImpl.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/dialect/impl/CommonsDialectImpl.java @@ -19,6 +19,7 @@ import com.mybatisflex.core.dialect.IDialect; import com.mybatisflex.core.dialect.KeywordWrap; import com.mybatisflex.core.dialect.LimitOffsetProcessor; import com.mybatisflex.core.dialect.OperateType; +import com.mybatisflex.core.exception.FlexAssert; import com.mybatisflex.core.exception.FlexExceptions; import com.mybatisflex.core.exception.locale.LocalizedFormats; import com.mybatisflex.core.logicdelete.LogicDeleteManager; @@ -40,12 +41,7 @@ import com.mybatisflex.core.util.CollectionUtil; import com.mybatisflex.core.util.SqlUtil; import com.mybatisflex.core.util.StringUtil; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.StringJoiner; +import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -195,6 +191,7 @@ public class CommonsDialectImpl implements IDialect { @Override public String forDeleteById(String schema, String tableName, String[] primaryKeys) { + assertPrimaryKeysNotEmpty(primaryKeys); String table = getRealTable(tableName, OperateType.DELETE); StringBuilder sql = new StringBuilder(); sql.append(DELETE_FROM); @@ -216,6 +213,7 @@ public class CommonsDialectImpl implements IDialect { @Override public String forDeleteBatchByIds(String schema, String tableName, String[] primaryKeys, Object[] ids) { + assertPrimaryKeysNotEmpty(primaryKeys); String table = getRealTable(tableName, OperateType.DELETE); StringBuilder sql = new StringBuilder(); sql.append(DELETE_FROM); @@ -268,6 +266,7 @@ public class CommonsDialectImpl implements IDialect { Set modifyAttrs = RowCPI.getModifyAttrs(row); Map rawValueMap = RowCPI.getRawValueMap(row); String[] primaryKeys = RowCPI.obtainsPrimaryKeyStrings(row); + assertPrimaryKeysNotEmpty(primaryKeys); sql.append(UPDATE); if (StringUtil.hasText(schema)) { @@ -369,6 +368,7 @@ public class CommonsDialectImpl implements IDialect { @Override public String forSelectOneById(String schema, String tableName, String[] primaryKeys, Object[] primaryValues) { + assertPrimaryKeysNotEmpty(primaryKeys); String table = getRealTable(tableName, OperateType.SELECT); StringBuilder sql = new StringBuilder(SELECT_ALL_FROM); if (StringUtil.hasText(schema)) { @@ -701,16 +701,17 @@ public class CommonsDialectImpl implements IDialect { String logicDeleteColumn = tableInfo.getLogicDeleteColumnOrSkip(); Object[] tenantIdArgs = tableInfo.buildTenantIdArgs(); + String[] primaryKeys = tableInfo.getPrimaryColumns(); + assertPrimaryKeysNotEmpty(primaryKeys); + // 正常删除 if (StringUtil.noText(logicDeleteColumn)) { - String deleteByIdSql = forDeleteById(tableInfo.getSchema(), tableInfo.getTableName(), tableInfo.getPrimaryColumns()); + String deleteByIdSql = forDeleteById(tableInfo.getSchema(), tableInfo.getTableName(), primaryKeys); return tableInfo.buildTenantCondition(deleteByIdSql, tenantIdArgs, this); } // 逻辑删除 StringBuilder sql = new StringBuilder(); - String[] primaryKeys = tableInfo.getPrimaryColumns(); - sql.append(UPDATE).append(tableInfo.getWrapSchemaAndTableName(this, OperateType.UPDATE)); sql.append(SET).append(buildLogicDeletedSet(logicDeleteColumn, tableInfo)); sql.append(WHERE); @@ -735,9 +736,12 @@ public class CommonsDialectImpl implements IDialect { String logicDeleteColumn = tableInfo.getLogicDeleteColumnOrSkip(); Object[] tenantIdArgs = tableInfo.buildTenantIdArgs(); + String[] primaryKeys = tableInfo.getPrimaryColumns(); + assertPrimaryKeysNotEmpty(primaryKeys); + // 正常删除 if (StringUtil.noText(logicDeleteColumn)) { - String deleteSQL = forDeleteBatchByIds(tableInfo.getSchema(), tableInfo.getTableName(), tableInfo.getPrimaryColumns(), primaryValues); + String deleteSQL = forDeleteBatchByIds(tableInfo.getSchema(), tableInfo.getTableName(), primaryKeys, primaryValues); // 多租户 if (ArrayUtil.isNotEmpty(tenantIdArgs)) { @@ -754,8 +758,6 @@ public class CommonsDialectImpl implements IDialect { sql.append(WHERE); sql.append(BRACKET_LEFT); - String[] primaryKeys = tableInfo.getPrimaryColumns(); - // 多主键的场景 if (primaryKeys.length > 1) { for (int i = 0; i < primaryValues.length / primaryKeys.length; i++) { @@ -832,6 +834,7 @@ public class CommonsDialectImpl implements IDialect { Set updateColumns = tableInfo.obtainUpdateColumns(entity, ignoreNulls, false); Map rawValueMap = tableInfo.obtainUpdateRawValueMap(entity); String[] primaryKeys = tableInfo.getPrimaryColumns(); + assertPrimaryKeysNotEmpty(primaryKeys); sql.append(UPDATE).append(tableInfo.getWrapSchemaAndTableName(this, OperateType.UPDATE)).append(SET); @@ -966,6 +969,8 @@ public class CommonsDialectImpl implements IDialect { sql.append(FROM).append(tableInfo.getWrapSchemaAndTableName(this, OperateType.SELECT)); sql.append(WHERE); String[] pKeys = tableInfo.getPrimaryColumns(); + assertPrimaryKeysNotEmpty(pKeys); + for (int i = 0; i < pKeys.length; i++) { if (i > 0) { sql.append(AND); @@ -994,6 +999,7 @@ public class CommonsDialectImpl implements IDialect { sql.append(FROM).append(tableInfo.getWrapSchemaAndTableName(this, OperateType.SELECT)); sql.append(WHERE); String[] primaryKeys = tableInfo.getPrimaryColumns(); + assertPrimaryKeysNotEmpty(primaryKeys); String logicDeleteColumn = tableInfo.getLogicDeleteColumnOrSkip(); Object[] tenantIdArgs = tableInfo.buildTenantIdArgs(); @@ -1138,5 +1144,14 @@ public class CommonsDialectImpl implements IDialect { return LogicDeleteManager.getProcessor().buildLogicDeletedSet(logicColumn, tableInfo, this); } - + /** + * 断言主键非空 + * + * @param primaryKeys 主键 + */ + protected void assertPrimaryKeysNotEmpty(String[] primaryKeys) { + if (Objects.isNull(primaryKeys) || primaryKeys.length == 0 || Arrays.stream(primaryKeys).allMatch(String::isEmpty)) { + throw FlexExceptions.wrap("primary key not recognized! Please check the @com.mybatisflex.annotation.Id annotation"); + } + } } diff --git a/mybatis-flex-spring-boot-starter/src/main/java/com/mybatisflex/spring/boot/MultiDataSourceAutoConfiguration.java b/mybatis-flex-spring-boot-starter/src/main/java/com/mybatisflex/spring/boot/MultiDataSourceAutoConfiguration.java index cea0bf19..939da577 100644 --- a/mybatis-flex-spring-boot-starter/src/main/java/com/mybatisflex/spring/boot/MultiDataSourceAutoConfiguration.java +++ b/mybatis-flex-spring-boot-starter/src/main/java/com/mybatisflex/spring/boot/MultiDataSourceAutoConfiguration.java @@ -19,6 +19,8 @@ import com.mybatisflex.core.datasource.DataSourceBuilder; import com.mybatisflex.core.datasource.DataSourceDecipher; import com.mybatisflex.core.datasource.DataSourceManager; import com.mybatisflex.core.datasource.FlexDataSource; +import com.mybatisflex.core.dialect.DbType; +import com.mybatisflex.core.dialect.DbTypeUtil; import com.mybatisflex.core.exception.FlexExceptions; import com.mybatisflex.core.util.MapUtil; import com.mybatisflex.spring.boot.MybatisFlexProperties.SeataConfig; @@ -41,6 +43,7 @@ import org.springframework.context.annotation.Role; import javax.sql.DataSource; import java.util.Map; +import java.util.Optional; /** * MyBatis-Flex 多数据源的配置支持。 @@ -91,7 +94,8 @@ public class MultiDataSourceAutoConfiguration { if (master != null) { Map map = dataSourceProperties.remove(master); if (map != null) { - flexDataSource = addDataSource(MapUtil.entry(master, map), flexDataSource); + // 这里创建master时,flexDataSource一定是null + flexDataSource = addDataSource(MapUtil.entry(master, map), null); } else { throw FlexExceptions.wrap("没有找到默认数据源 \"%s\" 对应的配置,请检查您的多数据源配置。", master); } @@ -109,18 +113,26 @@ public class MultiDataSourceAutoConfiguration { DataSource dataSource = new DataSourceBuilder(entry.getValue()).build(); DataSourceManager.decryptDataSource(dataSource); + // 数据库类型 + DbType dbType = null; if (seataConfig != null && seataConfig.isEnable()) { if (seataConfig.getSeataMode() == MybatisFlexProperties.SeataMode.XA) { - dataSource = new DataSourceProxyXA(dataSource); + DataSourceProxyXA sourceProxyXa = new DataSourceProxyXA(dataSource); + dbType = DbType.findByName(sourceProxyXa.getDbType()); + dataSource = sourceProxyXa; } else { - dataSource = new DataSourceProxy(dataSource); + DataSourceProxy dataSourceProxy = new DataSourceProxy(dataSource); + dbType = DbType.findByName(dataSourceProxy.getDbType()); + dataSource = dataSourceProxy; } } + // 如果没有构建成功dbType,需要自解析 + dbType = Optional.ofNullable(dbType).orElse(DbTypeUtil.getDbType(dataSource)); if (flexDataSource == null) { - flexDataSource = new FlexDataSource(entry.getKey(), dataSource, false); + flexDataSource = new FlexDataSource(entry.getKey(), dataSource, dbType, false); } else { - flexDataSource.addDataSource(entry.getKey(), dataSource, false); + flexDataSource.addDataSource(entry.getKey(), dataSource, dbType, false); } return flexDataSource; }