fix: 修复自定义 TypeHandler 无法设置值的问题。

This commit is contained in:
Suomm 2024-02-29 18:58:22 +08:00
parent 5df8972abe
commit 2d00037a43
2 changed files with 50 additions and 42 deletions

View File

@ -16,54 +16,69 @@
package com.mybatisflex.core.mybatis; package com.mybatisflex.core.mybatis;
import com.mybatisflex.core.FlexConsts; import com.mybatisflex.core.FlexConsts;
import com.mybatisflex.core.exception.FlexExceptions;
import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler; import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.type.TypeHandler;
import org.apache.ibatis.type.TypeHandlerRegistry;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Date; import java.sql.Types;
import java.util.Map; import java.util.Map;
/**
* {@link PreparedStatement} 中的占位符设置值
*
* @author michael
* @author 王帅
*/
public class SqlArgsParameterHandler extends DefaultParameterHandler { public class SqlArgsParameterHandler extends DefaultParameterHandler {
private final Map parameterObject; private final TypeHandlerRegistry typeHandlerRegistry;
public SqlArgsParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
public SqlArgsParameterHandler(MappedStatement mappedStatement, Map parameterObject, BoundSql boundSql) {
super(mappedStatement, parameterObject, boundSql); super(mappedStatement, parameterObject, boundSql);
this.parameterObject = parameterObject; this.typeHandlerRegistry = mappedStatement.getConfiguration().getTypeHandlerRegistry();
} }
@Override @Override
public void setParameters(PreparedStatement ps) { public void setParameters(PreparedStatement ps) {
try { try {
doSetParameters(ps); doSetParameters(ps);
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException(e); throw FlexExceptions.wrap(e);
} }
} }
@SuppressWarnings({"rawtypes", "unchecked"})
private void doSetParameters(PreparedStatement ps) throws SQLException { private void doSetParameters(PreparedStatement ps) throws SQLException {
Object[] sqlArgs = (Object[]) ((Map<?, ?>) parameterObject).get(FlexConsts.SQL_ARGS); Object[] sqlArgs = (Object[]) ((Map) getParameterObject()).get(FlexConsts.SQL_ARGS);
if (sqlArgs != null && sqlArgs.length > 0) { if (sqlArgs != null && sqlArgs.length > 0) {
int index = 1; int index = 1;
for (Object value : sqlArgs) { for (Object value : sqlArgs) {
//通过配置的 TypeHandler 去设置内容 // 设置 NULL
if (value == null) {
ps.setNull(index++, Types.NULL);
continue;
}
// 通过配置的 TypeHandler 去设置值
if (value instanceof TypeHandlerObject) { if (value instanceof TypeHandlerObject) {
((TypeHandlerObject) value).setParameter(ps, index++); ((TypeHandlerObject) value).setParameter(ps, index++);
continue;
} }
// OracleSqlServer TIMESTAMPDATE 类型的数据是支持 java.util.Date 给值的
else if (value instanceof java.util.Date) { TypeHandler typeHandler = typeHandlerRegistry.getTypeHandler(value.getClass());
setDateParameter(ps, (Date) value, index++); if (typeHandler != null) {
} else if (value instanceof byte[]) { // 通过对应的 TypeHandler 去设置值
ps.setBytes(index++, (byte[]) value); typeHandler.setParameter(ps, index++, value, null);
} else { } else {
/** MySqlOracle 等驱动中通过 PreparedStatement.setObject 驱动会自动根据 value 内容进行转换 /*
* MySqlOracle 等驱动中通过 PreparedStatement.setObject 驱动会自动根据 value 内容进行转换
* 源码可参考 {{@link com.mysql.jdbc.PreparedStatement#setObject(int, Object)} * 源码可参考 {{@link com.mysql.jdbc.PreparedStatement#setObject(int, Object)}
**/ */
ps.setObject(index++, value); ps.setObject(index++, value);
} }
} }
@ -72,23 +87,4 @@ public class SqlArgsParameterHandler extends DefaultParameterHandler {
} }
} }
/**
* OracleSqlServer 需要主动设置下 date 类型
* MySql 通过 setObject 后会自动转换具体查看 MySql 驱动源码
*
* @param ps PreparedStatement
* @param value date value
* @param index set to index
* @throws SQLException
*/
private void setDateParameter(PreparedStatement ps, Date value, int index) throws SQLException {
if (value instanceof java.sql.Date) {
ps.setDate(index, (java.sql.Date) value);
} else if (value instanceof java.sql.Timestamp) {
ps.setTimestamp(index, (java.sql.Timestamp) value);
} else {
ps.setTimestamp(index, new java.sql.Timestamp(value.getTime()));
}
}
} }

View File

@ -17,15 +17,13 @@
package com.mybatisflex.test; package com.mybatisflex.test;
import com.mybatisflex.core.MybatisFlexBootstrap; import com.mybatisflex.core.MybatisFlexBootstrap;
import com.mybatisflex.core.audit.AuditManager;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.core.row.Db; import com.mybatisflex.core.row.Db;
import com.mybatisflex.core.row.Row; import com.mybatisflex.core.row.Row;
import com.mybatisflex.core.row.RowKey;
import com.mybatisflex.core.row.RowUtil; import com.mybatisflex.core.row.RowUtil;
import com.mybatisflex.core.update.RawValue;
import com.mybatisflex.core.update.UpdateWrapper; import com.mybatisflex.core.update.UpdateWrapper;
import com.mybatisflex.core.util.UpdateEntity; import com.mybatisflex.core.util.UpdateEntity;
import org.apache.ibatis.logging.stdout.StdOutImpl;
import org.apache.ibatis.session.Configuration; import org.apache.ibatis.session.Configuration;
import org.junit.Assert; import org.junit.Assert;
import org.junit.BeforeClass; import org.junit.BeforeClass;
@ -53,6 +51,7 @@ public class DbTest {
.build(); .build();
MybatisFlexBootstrap bootstrap = MybatisFlexBootstrap.getInstance() MybatisFlexBootstrap bootstrap = MybatisFlexBootstrap.getInstance()
.setLogImpl(StdOutImpl.class)
.setDataSource(dataSource) .setDataSource(dataSource)
.start(); .start();
@ -63,8 +62,6 @@ public class DbTest {
*/ */
Configuration configuration = bootstrap.getConfiguration(); Configuration configuration = bootstrap.getConfiguration();
configuration.setCallSettersOnNulls(true); configuration.setCallSettersOnNulls(true);
Db.updateBySql("update tb_account set options = null;");
} }
@SuppressWarnings("all") @SuppressWarnings("all")
@ -72,6 +69,8 @@ public class DbTest {
@Test @Test
public void test01() { public void test01() {
Db.updateBySql("update tb_account set options = null;");
List<Row> rows = Db.selectAll(tb_account); List<Row> rows = Db.selectAll(tb_account);
rows.stream() rows.stream()
@ -88,6 +87,7 @@ public class DbTest {
assert map.equals(map2); assert map.equals(map2);
} }
@Test @Test
public void test03() { public void test03() {
try { try {
@ -105,9 +105,21 @@ public class DbTest {
account3.setAge(4); account3.setAge(4);
accounts.add(account3); accounts.add(account3);
Db.updateEntitiesBatch(accounts); Db.updateEntitiesBatch(accounts);
}catch (Exception e){ } catch (Exception e) {
assert false; assert false;
} }
} }
@Test
public void testTypeHandler() {
QueryWrapper queryWrapper = QueryWrapper.create()
.select("*")
.from("tb_account")
.where("age = ?", 3);
List<Row> rows = Db.selectListByQuery(queryWrapper);
RowUtil.printPretty(rows);
}
} }