[HUDI-2471] Add support ignoring case in merge into (#3700)
@@ -125,6 +125,7 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
     case mergeInto @ MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions)
       if isHoodieTable(target, sparkSession) && target.resolved =>

+      val resolver = sparkSession.sessionState.conf.resolver
       val resolvedSource = analyzer.execute(source)
       def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = {
         if (assignments.isEmpty) {
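For context on the line added above (not part of the patch): Spark's `Resolver` is just a function of type `(String, String) => Boolean`, and `sparkSession.sessionState.conf.resolver` returns a case-sensitive or case-insensitive comparison depending on `spark.sql.caseSensitive`. A minimal sketch of the behavior the rest of this commit relies on:

```scala
import org.apache.spark.sql.catalyst.analysis.Resolver

// With spark.sql.caseSensitive=false (Spark's default) conf.resolver behaves
// like equalsIgnoreCase; with it set to true, like exact string equality.
val caseInsensitive: Resolver = (a, b) => a.equalsIgnoreCase(b)
val caseSensitive: Resolver = (a, b) => a == b

assert(caseInsensitive("PRICE", "price"))  // same column, ignoring case
assert(!caseSensitive("PRICE", "price"))   // distinct columns when case-sensitive
```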
@@ -161,23 +162,21 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
       val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_))
       val resolvedAssignments = if (isInsertOrUpdateStar(assignments)) {
         // assignments is empty means insert * or update set *
-        val resolvedSourceOutputWithoutMetaFields = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
-        val targetOutputWithoutMetaFields = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
-        val resolvedSourceColumnNamesWithoutMetaFields = resolvedSourceOutputWithoutMetaFields.map(_.name)
-        val targetColumnNamesWithoutMetaFields = targetOutputWithoutMetaFields.map(_.name)
+        val resolvedSourceOutput = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
+        val targetOutput = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
+        val resolvedSourceColumnNames = resolvedSourceOutput.map(_.name)

-        if(targetColumnNamesWithoutMetaFields.toSet.subsetOf(resolvedSourceColumnNamesWithoutMetaFields.toSet)){
+        if(targetOutput.filter(attr => resolvedSourceColumnNames.exists(resolver(_, attr.name))).equals(targetOutput)){
           //If sourceTable's columns contains all targetTable's columns,
           //We fill assign all the source fields to the target fields by column name matching.
-          val sourceColNameAttrMap = resolvedSourceOutputWithoutMetaFields.map(attr => (attr.name, attr)).toMap
-          targetOutputWithoutMetaFields.map(targetAttr => {
-            val sourceAttr = sourceColNameAttrMap(targetAttr.name)
+          targetOutput.map(targetAttr => {
+            val sourceAttr = resolvedSourceOutput.find(f => resolver(f.name, targetAttr.name)).get
             Assignment(targetAttr, sourceAttr)
           })
         } else {
           // We fill assign all the source fields to the target fields by order.
-          targetOutputWithoutMetaFields
-            .zip(resolvedSourceOutputWithoutMetaFields)
+          targetOutput
+            .zip(resolvedSourceOutput)
             .map { case (targetAttr, sourceAttr) => Assignment(targetAttr, sourceAttr) }
         }
       } else {
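The hunk above changes the `insert *` / `update set *` expansion from an exact, case-sensitive `Set.subsetOf` test to resolver-based matching. A standalone sketch of the new rule with hypothetical column names (plain strings stand in for Spark attributes):

```scala
val resolver: (String, String) => Boolean = _ equalsIgnoreCase _
val targetCols = Seq("id", "name", "price")
val sourceCols = Seq("ID", "NAME", "PRICE", "flag")

// The containment check: every target column resolves to some source column.
val byNamePossible = targetCols.forall(t => sourceCols.exists(resolver(_, t)))

// If it holds, pair each target column with its (case-insensitively) matching
// source column — the .get is safe because of the check above; otherwise the
// code falls back to pairing by position.
val pairs =
  if (byNamePossible) targetCols.map(t => t -> sourceCols.find(resolver(_, t)).get)
  else targetCols.zip(sourceCols)

assert(pairs == Seq("id" -> "ID", "name" -> "NAME", "price" -> "PRICE"))
```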
@@ -214,8 +213,9 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
           }.toMap

           // Validate if there are incorrect target attributes.
+          val targetColumnNames = removeMetaFields(target.output).map(_.name)
           val unKnowTargets = target2Values.keys
-            .filterNot(removeMetaFields(target.output).map(_.name).contains(_))
+            .filterNot(name => targetColumnNames.exists(resolver(_, name)))
           if (unKnowTargets.nonEmpty) {
             throw new AnalysisException(s"Cannot find target attributes: ${unKnowTargets.mkString(",")}.")
           }
@@ -224,19 +224,20 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
           // e.g. If the update action missing 'id' attribute, we fill a "id = target.id" to the update action.
           val newAssignments = removeMetaFields(target.output)
             .map(attr => {
+              val valueOption = target2Values.find(f => resolver(f._1, attr.name))
               // TODO support partial update for MOR.
-              if (!target2Values.contains(attr.name) && targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
+              if (valueOption.isEmpty && targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
                 throw new AnalysisException(s"Missing specify the value for target field: '${attr.name}' in merge into update action" +
                   s" for MOR table. Currently we cannot support partial update for MOR," +
                   s" please complete all the target fields just like '...update set id = s0.id, name = s0.name ....'")
               }
               if (preCombineField.isDefined && preCombineField.get.equalsIgnoreCase(attr.name)
-                && !target2Values.contains(attr.name)) {
+                && valueOption.isEmpty) {
                 throw new AnalysisException(s"Missing specify value for the preCombineField:" +
                   s" ${preCombineField.get} in merge-into update action. You should add" +
                   s" '... update set ${preCombineField.get} = xx....' to the when-matched clause.")
               }
-              Assignment(attr, target2Values.getOrElse(attr.name, attr))
+              Assignment(attr, if (valueOption.isEmpty) attr else valueOption.get._2)
             })
           UpdateAction(resolvedCondition, newAssignments)
         case DeleteAction(condition) =>
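The effect of `valueOption` above: each target column is looked up in the user's assignments with the resolver, and when it is absent (and the table is not MOR) the column keeps its current target value. A toy illustration with hypothetical assignments, using plain strings in place of Spark expressions:

```scala
val resolver: (String, String) => Boolean = _ equalsIgnoreCase _

// The user wrote `update set PRICE = s0.price`; `id` and `name` were omitted.
val userAssignments = Map("PRICE" -> "s0.price")
val targetCols = Seq("id", "name", "price")

val filled = targetCols.map { col =>
  userAssignments.find { case (k, _) => resolver(k, col) } match {
    case Some((_, value)) => col -> value          // user-specified value wins
    case None             => col -> s"target.$col" // keep the existing value
  }
}
assert(filled == Seq("id" -> "target.id", "name" -> "target.name", "price" -> "s0.price"))
```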
@@ -27,6 +27,7 @@ import org.apache.hudi.hive.ddl.HiveSyncMode
 import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, HoodieWriterUtils, SparkAdapterSupport}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Cast, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -90,6 +91,7 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
    * TODO Currently Non-equivalent conditions are not supported.
    */
   private lazy val targetKey2SourceExpression: Map[String, Expression] = {
+    val resolver = sparkSession.sessionState.conf.resolver
     val conditions = splitByAnd(mergeInto.mergeCondition)
     val allEqs = conditions.forall(p => p.isInstanceOf[EqualTo])
     if (!allEqs) {
@@ -101,11 +103,11 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
     val target2Source = conditions.map(_.asInstanceOf[EqualTo])
       .map {
         case EqualTo(left: AttributeReference, right)
-          if targetAttrs.indexOf(left) >= 0 => // left is the target field
-          left.name -> right
+          if targetAttrs.exists(f => attributeEqual(f, left, resolver)) => // left is the target field
+          targetAttrs.find(f => resolver(f.name, left.name)).get.name -> right
         case EqualTo(left, right: AttributeReference)
-          if targetAttrs.indexOf(right) >= 0 => // right is the target field
-          right.name -> left
+          if targetAttrs.exists(f => attributeEqual(f, right, resolver)) => // right is the target field
+          targetAttrs.find(f => resolver(f.name, right.name)).get.name -> left
         case eq =>
           throw new AnalysisException(s"Invalidate Merge-On condition: ${eq.sql}." +
             "The validate condition should be 'targetColumn = sourceColumnExpression', e.g." +
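Note the subtle behavior change in this hunk: the key placed into `target2Source` is now the target table's own spelling of the column, found via the resolver, rather than whatever casing the merge condition used. A small sketch with hypothetical names:

```scala
val resolver: (String, String) => Boolean = _ equalsIgnoreCase _
val targetAttrNames = Seq("id", "name", "price")

// The merge condition spelled the key as "ID"; normalize it to the
// target table's casing before using it as the map key.
val conditionSideName = "ID"
val normalizedKey = targetAttrNames.find(resolver(_, conditionSideName)).get
assert(normalizedKey == "id")
```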
@@ -196,15 +198,24 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
   }

   private def isEqualToTarget(targetColumnName: String, sourceExpression: Expression): Boolean = {
-    val sourceColNameMap = sourceDFOutput.map(attr => (attr.name.toLowerCase, attr.name)).toMap
+    val sourceColumnName = sourceDFOutput.map(_.name)
+    val resolver = sparkSession.sessionState.conf.resolver

     sourceExpression match {
-      case attr: AttributeReference if sourceColNameMap(attr.name.toLowerCase).equals(targetColumnName) => true
-      case Cast(attr: AttributeReference, _, _) if sourceColNameMap(attr.name.toLowerCase).equals(targetColumnName) => true
+      case attr: AttributeReference if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
+      case Cast(attr: AttributeReference, _, _) if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
       case _=> false
     }
   }

+  /**
+   * Compare a [[Attribute]] to another, return true if they have the same column name(by resolver) and exprId
+   */
+  private def attributeEqual(
+      attr: Attribute, other: Attribute, resolver: Resolver): Boolean = {
+    resolver(attr.name, other.name) && attr.exprId == other.exprId
+  }
+
   /**
    * Execute the update and delete action. All the matched and not-matched actions will
    * execute in one upsert write operation. We pushed down the matched condition and assignment
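The new `attributeEqual` helper compares attributes by resolved name *and* `exprId`, so two attributes that merely share a name (say, `id` from the source and `id` from the target) are still kept apart. A minimal sketch of the semantics (assuming a scope with Spark's catalyst classes on the classpath; not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

val resolver: (String, String) => Boolean = _ equalsIgnoreCase _

val targetId = AttributeReference("ID", IntegerType)()  // gets a fresh exprId
val sourceId = AttributeReference("id", IntegerType)()  // different exprId

// Same column name under the resolver...
assert(resolver(targetId.name, sourceId.name))
// ...but attributeEqual would still reject the pair: the exprIds differ,
// because the two attributes come from different relations.
assert(targetId.exprId != sourceId.exprId)
```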
@@ -361,9 +372,9 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
     mergeInto.targetTable.output
       .filterNot(attr => isMetaField(attr.name))
       .map(attr => {
-        val assignment = attr2Assignment.getOrElse(attr,
-          throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
-        castIfNeeded(assignment, attr.dataType, sparkSession.sqlContext.conf)
+        val assignment = attr2Assignment.find(f => attributeEqual(f._1, attr, sparkSession.sessionState.conf.resolver))
+          .getOrElse(throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
+        castIfNeeded(assignment._2, attr.dataType, sparkSession.sqlContext.conf)
       })
   }
@@ -432,4 +432,115 @@ class TestMergeIntoTable2 extends TestHoodieSqlBase {
     }
   }

+  test("Test ignoring case") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create table
+      spark.sql(
+        s"""
+           |create table $tableName (
+           |  ID int,
+           |  name string,
+           |  price double,
+           |  TS long,
+           |  DT string
+           |) using hudi
+           | location '${tmp.getCanonicalPath}/$tableName'
+           | options (
+           |  primaryKey ='ID',
+           |  preCombineField = 'TS'
+           | )
+       """.stripMargin)
+
+      // First merge with a extra input field 'flag' (insert a new record)
+      spark.sql(
+        s"""
+           | merge into $tableName
+           | using (
+           |  select 1 as id, 'a1' as name, 10 as PRICE, 1000 as ts, '2021-05-05' as dt, '1' as flag
+           | ) s0
+           | on s0.id = $tableName.id
+           | when matched and flag = '1' then update set
+           |  id = s0.id, name = s0.name, PRICE = s0.price, ts = s0.ts, dt = s0.dt
+           | when not matched and flag = '1' then insert *
+       """.stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 10.0, 1000, "2021-05-05")
+      )
+
+      // Second merge (update the record)
+      spark.sql(
+        s"""
+           | merge into $tableName
+           | using (
+           |  select 1 as id, 'a1' as name, 20 as PRICE, '2021-05-05' as dt, 1001 as ts
+           | ) s0
+           | on s0.id = $tableName.id
+           | when matched then update set
+           |  id = s0.id, name = s0.name, PRICE = s0.price, ts = s0.ts, dt = s0.dt
+           | when not matched then insert *
+       """.stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 20.0, 1001, "2021-05-05")
+      )
+
+      // Test ignoring case when column name matches
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 1 as id, 'a1' as name, 1111 as ts, '2021-05-05' as dt, 111 as PRICE union all
+           |  select 2 as id, 'a2' as name, 1112 as ts, '2021-05-05' as dt, 112 as PRICE
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set *
+           | when not matched then insert *
+           |""".stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 111.0, 1111, "2021-05-05"),
+        Seq(2, "a2", 112.0, 1112, "2021-05-05")
+      )
+    }
+  }
+
+  test("Test ignoring case for MOR table") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create a mor partitioned table.
+      spark.sql(
+        s"""
+           | create table $tableName (
+           |  ID int,
+           |  NAME string,
+           |  price double,
+           |  TS long,
+           |  dt string
+           | ) using hudi
+           | options (
+           |  type = 'mor',
+           |  primaryKey = 'ID',
+           |  preCombineField = 'TS'
+           | )
+           | partitioned by(dt)
+           | location '${tmp.getCanonicalPath}/$tableName'
+       """.stripMargin)
+
+      // Test ignoring case when column name matches
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 1 as id, 'a1' as NAME, 1111 as ts, '2021-05-05' as DT, 111 as price
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set *
+           | when not matched then insert *
+       """.stripMargin
+      )
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 111.0, 1111, "2021-05-05")
+      )
+    }
+  }
+
 }