diff --git a/mybatis-flex-core/src/main/java/com/mybatisflex/core/mybatis/SqlArgsParameterHandler.java b/mybatis-flex-core/src/main/java/com/mybatisflex/core/mybatis/SqlArgsParameterHandler.java index c286f32b..c2b05cb0 100644 --- a/mybatis-flex-core/src/main/java/com/mybatisflex/core/mybatis/SqlArgsParameterHandler.java +++ b/mybatis-flex-core/src/main/java/com/mybatisflex/core/mybatis/SqlArgsParameterHandler.java @@ -16,54 +16,69 @@ package com.mybatisflex.core.mybatis; import com.mybatisflex.core.FlexConsts; +import com.mybatisflex.core.exception.FlexExceptions; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.scripting.defaults.DefaultParameterHandler; +import org.apache.ibatis.type.TypeHandler; +import org.apache.ibatis.type.TypeHandlerRegistry; import java.sql.PreparedStatement; import java.sql.SQLException; -import java.util.Date; +import java.sql.Types; import java.util.Map; +/** + * 向 {@link PreparedStatement} 中的占位符设置值。 + * + * @author michael + * @author 王帅 + */ public class SqlArgsParameterHandler extends DefaultParameterHandler { - private final Map parameterObject; + private final TypeHandlerRegistry typeHandlerRegistry; - - public SqlArgsParameterHandler(MappedStatement mappedStatement, Map parameterObject, BoundSql boundSql) { + public SqlArgsParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) { super(mappedStatement, parameterObject, boundSql); - this.parameterObject = parameterObject; + this.typeHandlerRegistry = mappedStatement.getConfiguration().getTypeHandlerRegistry(); } - @Override public void setParameters(PreparedStatement ps) { try { doSetParameters(ps); } catch (SQLException e) { - throw new RuntimeException(e); + throw FlexExceptions.wrap(e); } } - + @SuppressWarnings({"rawtypes", "unchecked"}) private void doSetParameters(PreparedStatement ps) throws SQLException { - Object[] sqlArgs = (Object[]) ((Map) parameterObject).get(FlexConsts.SQL_ARGS); + Object[] sqlArgs = (Object[]) ((Map) getParameterObject()).get(FlexConsts.SQL_ARGS); if (sqlArgs != null && sqlArgs.length > 0) { int index = 1; for (Object value : sqlArgs) { - //通过配置的 TypeHandler 去设置内容 + // 设置 NULL 值 + if (value == null) { + ps.setNull(index++, Types.NULL); + continue; + } + + // 通过配置的 TypeHandler 去设置值 if (value instanceof TypeHandlerObject) { ((TypeHandlerObject) value).setParameter(ps, index++); + continue; } - //在 Oracle、SqlServer 中 TIMESTAMP、DATE 类型的数据是支持 java.util.Date 给值的 - else if (value instanceof java.util.Date) { - setDateParameter(ps, (Date) value, index++); - } else if (value instanceof byte[]) { - ps.setBytes(index++, (byte[]) value); + + TypeHandler typeHandler = typeHandlerRegistry.getTypeHandler(value.getClass()); + if (typeHandler != null) { + // 通过对应的 TypeHandler 去设置值 + typeHandler.setParameter(ps, index++, value, null); } else { - /** 在 MySql,Oracle 等驱动中,通过 PreparedStatement.setObject 后,驱动会自动根据 value 内容进行转换 + /* + * 在 MySql,Oracle 等驱动中,通过 PreparedStatement.setObject 后,驱动会自动根据 value 内容进行转换 * 源码可参考: {{@link com.mysql.jdbc.PreparedStatement#setObject(int, Object)} - **/ + */ ps.setObject(index++, value); } } @@ -72,23 +87,4 @@ public class SqlArgsParameterHandler extends DefaultParameterHandler { } } - /** - * Oracle、SqlServer 需要主动设置下 date 类型 - * MySql 通过 setObject 后会自动转换,具体查看 MySql 驱动源码 - * - * @param ps PreparedStatement - * @param value date value - * @param index set to index - * @throws SQLException - */ - private void setDateParameter(PreparedStatement ps, Date value, int index) throws SQLException { - if (value instanceof java.sql.Date) { - ps.setDate(index, (java.sql.Date) value); - } else if (value instanceof java.sql.Timestamp) { - ps.setTimestamp(index, (java.sql.Timestamp) value); - } else { - ps.setTimestamp(index, new java.sql.Timestamp(value.getTime())); - } - } - } diff --git a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTest.java b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTest.java index 4ddcf363..282d734d 100644 --- a/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTest.java +++ b/mybatis-flex-test/mybatis-flex-native-test/src/main/java/com/mybatisflex/test/DbTest.java @@ -17,15 +17,13 @@ package com.mybatisflex.test; import com.mybatisflex.core.MybatisFlexBootstrap; -import com.mybatisflex.core.audit.AuditManager; import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.row.Db; import com.mybatisflex.core.row.Row; -import com.mybatisflex.core.row.RowKey; import com.mybatisflex.core.row.RowUtil; -import com.mybatisflex.core.update.RawValue; import com.mybatisflex.core.update.UpdateWrapper; import com.mybatisflex.core.util.UpdateEntity; +import org.apache.ibatis.logging.stdout.StdOutImpl; import org.apache.ibatis.session.Configuration; import org.junit.Assert; import org.junit.BeforeClass; @@ -53,6 +51,7 @@ public class DbTest { .build(); MybatisFlexBootstrap bootstrap = MybatisFlexBootstrap.getInstance() + .setLogImpl(StdOutImpl.class) .setDataSource(dataSource) .start(); @@ -63,8 +62,6 @@ public class DbTest { */ Configuration configuration = bootstrap.getConfiguration(); configuration.setCallSettersOnNulls(true); - - Db.updateBySql("update tb_account set options = null;"); } @SuppressWarnings("all") @@ -72,6 +69,8 @@ public class DbTest { @Test public void test01() { + Db.updateBySql("update tb_account set options = null;"); + List rows = Db.selectAll(tb_account); rows.stream() @@ -88,6 +87,7 @@ public class DbTest { assert map.equals(map2); } + @Test public void test03() { try { @@ -105,9 +105,21 @@ public class DbTest { account3.setAge(4); accounts.add(account3); Db.updateEntitiesBatch(accounts); - }catch (Exception e){ + } catch (Exception e) { assert false; } } + @Test + public void testTypeHandler() { + QueryWrapper queryWrapper = QueryWrapper.create() + .select("*") + .from("tb_account") + .where("age = ?", 3); + + List rows = Db.selectListByQuery(queryWrapper); + + RowUtil.printPretty(rows); + } + }