diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala index 0cf8a6458..d906f0c50 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala @@ -51,7 +51,11 @@ class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord { val value = record.get(i) val avroFieldType = getSchema.getFields.get(i).schema() val sqlFieldType = sqlType.fields(i).dataType - convert(avroFieldType, sqlFieldType, value) + if (value == null) { + null + } else { + convert(avroFieldType, sqlFieldType, value) + } } private def convert(avroFieldType: Schema, sqlFieldType: DataType, value: AnyRef): AnyRef = { @@ -89,6 +93,7 @@ class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord { case (STRING, StringType) => value match { case s: String => UTF8String.fromString(s) case s: Utf8 => UTF8String.fromString(s.toString) + case o => throw new IllegalArgumentException(s"Cannot convert $o to StringType") } case (ENUM, StringType) => value.toString diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala index 28c47aa5d..1c20a82fd 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala @@ -642,4 +642,49 @@ class TestMergeIntoTable extends TestHoodieSqlBase { ) } } + + test("Test MergeInto With Null Fields") { + withTempDir { tmp => + val types = Seq( + "string" , + "int", + "bigint", + "double", + "float", + 
"timestamp", + "date", + "decimal" + ) + types.foreach { dataType => + val tableName = generateTableName + spark.sql( + s""" + |create table $tableName ( + | id int, + | name string, + | value $dataType, + | ts long + |) using hudi + | location '${tmp.getCanonicalPath}/$tableName' + | options ( + | primaryKey ='id', + | preCombineField = 'ts' + | ) + """.stripMargin) + + spark.sql( + s""" + |merge into $tableName h0 + |using ( + | select 1 as id, 'a1' as name, cast(null as $dataType) as value, 1000 as ts + | ) s0 + | on h0.id = s0.id + | when not matched then insert * + |""".stripMargin) + checkAnswer(s"select id, name, value, ts from $tableName")( + Seq(1, "a1", null, 1000) + ) + } + } + } }