diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala index 4c15670a4..74a74882a 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala @@ -161,11 +161,25 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_)) val resolvedAssignments = if (isInsertOrUpdateStar(assignments)) { // assignments is empty means insert * or update set * - // we fill assign all the source fields to the target fields - target.output - .filter(attr => !HoodieSqlUtils.isMetaField(attr.name)) - .zip(resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))) - .map { case (targetAttr, sourceAttr) => Assignment(targetAttr, sourceAttr) } + val resolvedSourceOutputWithoutMetaFields = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name)) + val targetOutputWithoutMetaFields = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name)) + val resolvedSourceColumnNamesWithoutMetaFields = resolvedSourceOutputWithoutMetaFields.map(_.name) + val targetColumnNamesWithoutMetaFields = targetOutputWithoutMetaFields.map(_.name) + + if(targetColumnNamesWithoutMetaFields.toSet.subsetOf(resolvedSourceColumnNamesWithoutMetaFields.toSet)){ + // If sourceTable's columns contain all targetTable's columns, + // we will assign all the source fields to the target fields by column name matching.
+ val sourceColNameAttrMap = resolvedSourceOutputWithoutMetaFields.map(attr => (attr.name, attr)).toMap + targetOutputWithoutMetaFields.map(targetAttr => { + val sourceAttr = sourceColNameAttrMap(targetAttr.name) + Assignment(targetAttr, sourceAttr) + }) + } else { + // We will assign all the source fields to the target fields by order. + targetOutputWithoutMetaFields + .zip(resolvedSourceOutputWithoutMetaFields) + .map { case (targetAttr, sourceAttr) => Assignment(targetAttr, sourceAttr) } + } } else { assignments.map(assignment => { val resolvedKey = resolveExpressionFrom(target)(assignment.key) diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala index 827f91723..88d2e97e0 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala @@ -237,4 +237,87 @@ class TestMergeIntoTable2 extends TestHoodieSqlBase { } } + test("Test column name matching for insert * and update set *") { + withTempDir { tmp => + val tableName = generateTableName + // Create table + spark.sql( + s""" + |create table $tableName ( + | id int, + | name string, + | price double, + | ts long, + | dt string + |) using hudi + | location '${tmp.getCanonicalPath}/$tableName' + | options ( + | primaryKey ='id', + | preCombineField = 'ts' + | ) + """.stripMargin) + + // Insert data to source table + spark.sql(s"insert into $tableName select 1, 'a1', 1, 10, '2021-03-21'") + checkAnswer(s"select id, name, price, ts, dt from $tableName")( + Seq(1, "a1", 1.0, 10, "2021-03-21") + ) + + // Test the case where the column order of sourceTable is similar to that of targetTable + spark.sql( + s""" + |merge into $tableName as t0 + |using ( + | select 1 as id, '2021-05-05' as dt, 1002 as ts, 97 as price, 
'a1' as name union all + | select 1 as id, '2021-05-05' as dt, 1003 as ts, 98 as price, 'a2' as name union all + | select 2 as id, '2021-05-05' as dt, 1001 as ts, 99 as price, 'a3' as name union all + | ) as s0 + |on t0.id = s0.id + |when matched then update set * + |when not matched then insert * + |""".stripMargin) + checkAnswer(s"select id, name, price, ts, dt from $tableName")( + Seq(1, "a2", 98.0, 1003, "2021-05-05"), + Seq(2, "a3", 99.0, 1001, "2021-05-05") + ) + // Test the case where the column order of sourceTable is different from that of targetTable + spark.sql( + s""" + |merge into $tableName as t0 + |using ( + | select 1 as id, 'a1' as name, 1004 as ts, '2021-05-05' as dt, 100 as price union all + | select 2 as id, 'a5' as name, 1000 as ts, '2021-05-05' as dt, 101 as price union all + | select 3 as id, 'a3' as name, 1000 as ts, '2021-05-05' as dt, 102 as price + | ) as s0 + |on t0.id = s0.id + |when matched then update set * + |when not matched then insert * + |""".stripMargin) + checkAnswer(s"select id, name, price, ts, dt from $tableName")( + Seq(1, "a1", 100.0, 1004, "2021-05-05"), + Seq(2, "a3", 99.0, 1001, "2021-05-05"), + Seq(3, "a3", 102.0, 1000, "2021-05-05") + ) + + // Test an extra input field 'flag' + spark.sql( + s""" + |merge into $tableName as t0 + |using ( + | select 1 as id, 'a6' as name, 1006 as ts, '2021-05-05' as dt, 106 as price, '0' as flag union all + | select 4 as id, 'a4' as name, 1000 as ts, '2021-05-06' as dt, 100 as price, '1' as flag + | ) as s0 + |on t0.id = s0.id + |when matched and flag = '1' then update set * + |when not matched and flag = '1' then insert * + |""".stripMargin) + checkAnswer(s"select id, name, price, ts, dt from $tableName")( + Seq(1, "a1", 100.0, 1004, "2021-05-05"), + Seq(2, "a3", 99.0, 1001, "2021-05-05"), + Seq(3, "a3", 102.0, 1000, "2021-05-05"), + Seq(4, "a4", 100.0, 1000, "2021-05-06") + ) + } + } + }