From 5574e092fb40c24dfb417f160e4ac229d9cbcc0c Mon Sep 17 00:00:00 2001
From: pengzhiwei
Date: Wed, 4 Aug 2021 18:20:29 +0800
Subject: [PATCH] [HUDI-2232] [SQL] MERGE INTO fails with table having nested struct (#3379)

---
 .../apache/hudi/sql/IExpressionEvaluator.java |   6 +-
 .../command/payload/ExpressionCodeGen.scala   |  26 +--
 .../command/payload/ExpressionPayload.scala   |  66 ++------
 .../hudi/command/payload/SqlTypedRecord.scala | 155 +-----------------
 .../spark/sql/hudi/TestMergeIntoTable2.scala  |  65 ++++++++
 5 files changed, 111 insertions(+), 207 deletions(-)

diff --git a/hudi-spark-datasource/hudi-spark/src/main/java/org/apache/hudi/sql/IExpressionEvaluator.java b/hudi-spark-datasource/hudi-spark/src/main/java/org/apache/hudi/sql/IExpressionEvaluator.java
index 90b685a09..210f00351 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/java/org/apache/hudi/sql/IExpressionEvaluator.java
+++ b/hudi-spark-datasource/hudi-spark/src/main/java/org/apache/hudi/sql/IExpressionEvaluator.java
@@ -17,18 +17,18 @@
 
 package org.apache.hudi.sql;
 
+import org.apache.avro.generic.GenericRecord;
 import org.apache.avro.generic.IndexedRecord;
 
 /***
- * A interface for CodeGen to execute expressions on the record
- * and return the results with a array for each expression.
+ * An interface for CodeGen to execute expressions on the record.
  */
 public interface IExpressionEvaluator {
 
   /**
    * Evaluate the result of the expressions based on the record.
    */
-  Object[] eval(IndexedRecord record);
+  GenericRecord eval(IndexedRecord record);
 
   /**
    * Get the code of the expressions. This is used for debug.
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala
index 4ee479697..509746bae 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala
@@ -18,15 +18,15 @@
 package org.apache.spark.sql.hudi.command.payload
 
 import java.util.UUID
-
-import org.apache.avro.generic.IndexedRecord
+import org.apache.avro.generic.{GenericRecord, IndexedRecord}
 import org.apache.hudi.sql.IExpressionEvaluator
 import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.avro.AvroSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.hudi.command.payload.ExpressionCodeGen.RECORD_NAME
 import org.apache.spark.sql.types.{DataType, Decimal}
@@ -53,7 +53,7 @@ object ExpressionCodeGen extends Logging {
    * @return An IExpressionEvaluator generate by CodeGen which take a IndexedRecord as input
    *         param and return a Array of results for each expression.
    */
-  def doCodeGen(exprs: Seq[Expression]): IExpressionEvaluator = {
+  def doCodeGen(exprs: Seq[Expression], serializer: AvroSerializer): IExpressionEvaluator = {
     val ctx = new CodegenContext()
     // Set the input_row to null as we do not use row as the input object but Record.
     ctx.INPUT_ROW = null
@@ -65,13 +65,15 @@
       s"""
          |private Object[] references;
          |private String code;
+         |private AvroSerializer serializer;
          |
-         |public $className(Object references, String code) {
+         |public $className(Object references, String code, AvroSerializer serializer) {
          |  this.references = (Object[])references;
          |  this.code = code;
+         |  this.serializer = serializer;
          |}
          |
-         |public Object[] eval(IndexedRecord $RECORD_NAME) {
+         |public GenericRecord eval(IndexedRecord $RECORD_NAME) {
          |  ${resultVars.map(_.code).mkString("\n")}
          |  Object[] results = new Object[${resultVars.length}];
          |  ${
@@ -85,7 +87,8 @@
            """.stripMargin
         }).mkString("\n")
       }
-      return results;
+      InternalRow row = new GenericInternalRow(results);
+      return (GenericRecord) serializer.serialize(row);
          |  }
          |
          |public String getCode() {
@@ -115,7 +118,10 @@
       classOf[TaskContext].getName,
       classOf[TaskKilledException].getName,
       classOf[InputMetrics].getName,
-      classOf[IndexedRecord].getName
+      classOf[IndexedRecord].getName,
+      classOf[AvroSerializer].getName,
+      classOf[GenericRecord].getName,
+      classOf[GenericInternalRow].getName
     )
     evaluator.setImplementedInterfaces(Array(classOf[IExpressionEvaluator]))
     try {
@@ -133,8 +139,8 @@
 
     val referenceArray = ctx.references.toArray.map(_.asInstanceOf[Object])
     val expressionSql = exprs.map(_.sql).mkString(" ")
-    evaluator.getClazz.getConstructor(classOf[Object], classOf[String])
-      .newInstance(referenceArray, s"Expressions is: [$expressionSql]\nCodeBody is: {\n$codeBody\n}")
+    evaluator.getClazz.getConstructor(classOf[Object], classOf[String], classOf[AvroSerializer])
+      .newInstance(referenceArray, s"Expressions is: [$expressionSql]\nCodeBody is: {\n$codeBody\n}", serializer)
       .asInstanceOf[IExpressionEvaluator]
   }
 
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
index a43416cb2..caa3af4da 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
@@ -21,11 +21,8 @@ import java.util.{Base64, Properties}
 import java.util.concurrent.Callable
 import scala.collection.JavaConverters._
 import com.google.common.cache.CacheBuilder
-import org.apache.avro.Conversions.DecimalConversion
-import org.apache.avro.Schema.Type
-import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.Schema
 import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
-import org.apache.avro.util.Utf8
 import org.apache.hudi.DataSourceWriteOptions._
 import org.apache.hudi.avro.HoodieAvroUtils
 import org.apache.hudi.avro.HoodieAvroUtils.bytesToAvro
@@ -34,11 +31,11 @@ import org.apache.hudi.common.util.{ValidationUtils, Option => HOption}
 import org.apache.hudi.config.HoodieWriteConfig
 import org.apache.hudi.io.HoodieWriteHandle
 import org.apache.hudi.sql.IExpressionEvaluator
+import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.hudi.SerDeUtils
 import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
-import org.apache.spark.sql.types.Decimal
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types.{StructField, StructType}
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -104,12 +101,11 @@ class ExpressionPayload(record: GenericRecord,
     val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString, writeSchema)
     for ((conditionEvaluator, assignmentEvaluator) <- updateConditionAndAssignments
       if resultRecordOpt == null) {
-      val conditionVal = evaluate(conditionEvaluator, inputRecord).head.asInstanceOf[Boolean]
+      val conditionVal = evaluate(conditionEvaluator, inputRecord).get(0).asInstanceOf[Boolean]
       // If the update condition matched then execute assignment expression
       // to compute final record to update. We will return the first matched record.
       if (conditionVal) {
-        val results = evaluate(assignmentEvaluator, inputRecord)
-        val resultRecord = convertToRecord(results, writeSchema)
+        val resultRecord = evaluate(assignmentEvaluator, inputRecord)
         if (targetRecord.isEmpty ||
           needUpdatingPersistedRecord(targetRecord.get, resultRecord, properties)) {
           resultRecordOpt = HOption.of(resultRecord)
@@ -125,7 +121,7 @@ class ExpressionPayload(record: GenericRecord,
     val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
     if (deleteConditionText != null) {
       val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
-      val deleteConditionVal = evaluate(deleteCondition, inputRecord).head.asInstanceOf[Boolean]
+      val deleteConditionVal = evaluate(deleteCondition, inputRecord).get(0).asInstanceOf[Boolean]
       if (deleteConditionVal) {
         resultRecordOpt = HOption.empty()
       }
@@ -159,12 +155,12 @@ class ExpressionPayload(record: GenericRecord,
     var resultRecordOpt: HOption[IndexedRecord] = null
     for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
       if resultRecordOpt == null) {
-      val conditionVal = evaluate(conditionEvaluator, inputRecord).head.asInstanceOf[Boolean]
+      val conditionVal = evaluate(conditionEvaluator, inputRecord).get(0).asInstanceOf[Boolean]
       // If matched the insert condition then execute the assignment expressions to compute the
       // result record. We will return the first matched record.
       if (conditionVal) {
-        val results = evaluate(assignmentEvaluator, inputRecord)
-        resultRecordOpt = HOption.of(convertToRecord(results, writeSchema))
+        val resultRecord = evaluate(assignmentEvaluator, inputRecord)
+        resultRecordOpt = HOption.of(resultRecord)
       }
     }
     if (resultRecordOpt != null) {
@@ -258,7 +254,7 @@ class ExpressionPayload(record: GenericRecord,
     Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava)
   }
 
-  private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): Array[Object] = {
+  private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): GenericRecord = {
     try evaluator.eval(sqlTypedRecord) catch {
       case e: Throwable =>
         throw new RuntimeException(s"Error in execute expression: ${e.getMessage}.\n${evaluator.getCode}", e)
@@ -295,8 +291,6 @@ object ExpressionPayload {
   /**
    * Do the CodeGen for each condition and assignment expressions.We will cache it to reduce
    * the compile time for each method call.
-   * @param serializedConditionAssignments
-   * @return
    */
   def getEvaluator(
     serializedConditionAssignments: String, writeSchema: Schema): Map[IExpressionEvaluator, IExpressionEvaluator] = {
@@ -310,42 +304,18 @@ object ExpressionPayload {
         // Do the CodeGen for condition expression and assignment expression
         conditionAssignments.map {
           case (condition, assignments) =>
-            val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition))
-            val assignmentEvaluator = AvroTypeConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments), writeSchema)
+            val conditionType = StructType(Seq(StructField("_col0", condition.dataType, nullable = true)))
+            val conditionSerializer = new AvroSerializer(conditionType,
+              SchemaConverters.toAvroType(conditionType), false)
+            val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition), conditionSerializer)
+
+            val assignSqlType = SchemaConverters.toSqlType(writeSchema).dataType.asInstanceOf[StructType]
+            val assignSerializer = new AvroSerializer(assignSqlType, writeSchema, false)
+            val assignmentEvaluator = ExpressionCodeGen.doCodeGen(assignments, assignSerializer)
             conditionEvaluator -> assignmentEvaluator
         }
       }
     })
   }
-
-  /**
-   * A IExpressionEvaluator wrapped the base evaluator which convert the result of the base evaluator
-   * to the avro typed-value.
-   */
-  case class AvroTypeConvertEvaluator(baseEvaluator: IExpressionEvaluator, writeSchema: Schema) extends IExpressionEvaluator {
-    private lazy val decimalConversions = new DecimalConversion()
-
-    /**
-     * Convert to the avro typed-value.
-     * e.g. convert UTF8String -> Utf8, Dicimal -> GenericFixed.
-     */
-    override def eval(record: IndexedRecord): Array[AnyRef] = {
-      baseEvaluator.eval(record).zipWithIndex.map {
-        case (s: UTF8String, _) => new Utf8(s.toString)
-        case (d: Decimal, i) =>
-          val schema = writeSchema.getFields.get(i).schema()
-          val fixedSchema = if (schema.getType == Type.UNION) {
-            schema.getTypes.asScala.filter(s => s.getType != Type.NULL).head
-          } else {
-            schema
-          }
-          decimalConversions.toFixed(d.toJavaBigDecimal, fixedSchema
-            , LogicalTypes.decimal(d.precision, d.scale))
-        case (o, _) => o
-      }
-    }
-
-    override def getCode: String = baseEvaluator.getCode
-  }
 }
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala
index 6895ca811..2a12e9227 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala
@@ -17,165 +17,28 @@
 
 package org.apache.spark.sql.hudi.command.payload
 
-import java.math.BigDecimal
-import java.nio.ByteBuffer
-
-import org.apache.avro.Conversions.DecimalConversion
-import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
-import org.apache.avro.Schema.Type._
-import org.apache.avro.generic.{GenericFixed, IndexedRecord}
-import org.apache.avro.util.Utf8
-import org.apache.avro.{LogicalTypes, Schema}
-import org.apache.spark.sql.avro.{IncompatibleSchemaException, SchemaConverters}
+import org.apache.avro.generic.IndexedRecord
+import org.apache.avro.Schema
+import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-import scala.collection.JavaConverters._
 
 /**
- * A sql typed record which will convert the avro field to sql typed value.
- * This is referred to the org.apache.spark.sql.avro.AvroDeserializer#newWriter in spark project.
- * @param record
- */
+ * A sql typed record which converts the avro field values to sql typed values.
+ */
 class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord {
 
-  private lazy val decimalConversions = new DecimalConversion()
   private lazy val sqlType = SchemaConverters.toSqlType(getSchema).dataType.asInstanceOf[StructType]
+  private lazy val avroDeserializer = new AvroDeserializer(record.getSchema, sqlType)
+  private lazy val sqlRow = avroDeserializer.deserialize(record).asInstanceOf[InternalRow]
 
   override def put(i: Int, v: Any): Unit = {
     record.put(i, v)
   }
 
   override def get(i: Int): AnyRef = {
-    val value = record.get(i)
-    val avroFieldType = getSchema.getFields.get(i).schema()
-    val sqlFieldType = sqlType.fields(i).dataType
-    if (value == null) {
-      null
-    } else {
-      convert(avroFieldType, sqlFieldType, value)
-    }
-  }
-
-  private def convert(avroFieldType: Schema, sqlFieldType: DataType, value: AnyRef): AnyRef = {
-    (avroFieldType.getType, sqlFieldType) match {
-      case (NULL, NullType) => null
-
-      case (BOOLEAN, BooleanType) => value.asInstanceOf[Boolean].asInstanceOf[java.lang.Boolean]
-
-      case (INT, IntegerType) => value.asInstanceOf[Int].asInstanceOf[java.lang.Integer]
-
-      case (INT, DateType) => value.asInstanceOf[Int].asInstanceOf[java.lang.Integer]
-
-      case (LONG, LongType) => value.asInstanceOf[Long].asInstanceOf[java.lang.Long]
-
-      case (LONG, TimestampType) => avroFieldType.getLogicalType match {
-        case _: TimestampMillis => (value.asInstanceOf[Long] * 1000).asInstanceOf[java.lang.Long]
-        case _: TimestampMicros => value.asInstanceOf[Long].asInstanceOf[java.lang.Long]
-        case null =>
-          // For backward compatibility, if the Avro type is Long and it is not logical type,
-          // the value is processed as timestamp type with millisecond precision.
-          java.lang.Long.valueOf(value.asInstanceOf[Long] * 1000)
-        case other => throw new IncompatibleSchemaException(
-          s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.")
-      }
-
-      // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date.
-      // For backward compatibility, we still keep this conversion.
-      case (LONG, DateType) =>
-        java.lang.Integer.valueOf((value.asInstanceOf[Long] / SqlTypedRecord.MILLIS_PER_DAY).toInt)
-
-      case (FLOAT, FloatType) => value.asInstanceOf[Float].asInstanceOf[java.lang.Float]
-
-      case (DOUBLE, DoubleType) => value.asInstanceOf[Double].asInstanceOf[java.lang.Double]
-
-      case (STRING, StringType) => value match {
-        case s: String => UTF8String.fromString(s)
-        case s: Utf8 => UTF8String.fromString(s.toString)
-        case o => throw new IllegalArgumentException(s"Cannot convert $o to StringType")
-      }
-
-      case (ENUM, StringType) => value.toString
-
-      case (FIXED, BinaryType) => value.asInstanceOf[GenericFixed].bytes().clone()
-
-      case (BYTES, BinaryType) => value match {
-        case b: ByteBuffer =>
-          val bytes = new Array[Byte](b.remaining)
-          b.get(bytes)
-          bytes
-        case b: Array[Byte] => b
-        case other => throw new RuntimeException(s"$other is not a valid avro binary.")
-      }
-
-      case (FIXED, d: DecimalType) =>
-        val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroFieldType,
-          LogicalTypes.decimal(d.precision, d.scale))
-        createDecimal(bigDecimal, d.precision, d.scale)
-
-      case (BYTES, d: DecimalType) =>
-        val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroFieldType,
-          LogicalTypes.decimal(d.precision, d.scale))
-        createDecimal(bigDecimal, d.precision, d.scale)
-
-      case (RECORD, _: StructType) =>
-        throw new IllegalArgumentException(s"UnSupport StructType yet")
-
-      case (ARRAY, ArrayType(_, _)) =>
-        throw new IllegalArgumentException(s"UnSupport ARRAY type yet")
-
-      case (MAP, MapType(keyType, _, _)) if keyType == StringType =>
-        throw new IllegalArgumentException(s"UnSupport MAP type yet")
-
-      case (UNION, _) =>
-        val allTypes = avroFieldType.getTypes.asScala
-        val nonNullTypes = allTypes.filter(_.getType != NULL)
-        if (nonNullTypes.nonEmpty) {
-          if (nonNullTypes.length == 1) {
-            convert(nonNullTypes.head, sqlFieldType, value)
-          } else {
-            nonNullTypes.map(_.getType) match {
-              case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlFieldType == LongType =>
-                value match {
-                  case null => null
-                  case l: java.lang.Long => l
-                  case i: java.lang.Integer => i.longValue().asInstanceOf[java.lang.Long]
-                }
-
-              case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlFieldType == DoubleType =>
-                value match {
-                  case null => null
-                  case d: java.lang.Double => d
-                  case f: java.lang.Float => f.doubleValue().asInstanceOf[java.lang.Double]
-                }
-
-              case _ =>
-                throw new IllegalArgumentException(s"UnSupport UNION type: ${sqlFieldType}")
-            }
-          }
-        } else {
-          null
-        }
-      case _ =>
-        throw new IncompatibleSchemaException(
-          s"Cannot convert Avro to catalyst because schema " +
-          s"is not compatible (avroType = $avroFieldType, sqlType = $sqlFieldType).\n")
-    }
-  }
-
-  private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
-    if (precision <= Decimal.MAX_LONG_DIGITS) {
-      // Constructs a `Decimal` with an unscaled `Long` value if possible.
-      Decimal(decimal.unscaledValue().longValue(), precision, scale)
-    } else {
-      // Otherwise, resorts to an unscaled `BigInteger` instead.
-      Decimal(decimal, precision, scale)
-    }
+    sqlRow.get(i, sqlType(i).dataType)
   }
 
   override def getSchema: Schema = record.getSchema
 }
-
-object SqlTypedRecord {
-  val MILLIS_PER_DAY = 24 * 60 * 60 * 1000L
-}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
index 478e0eec8..827f91723 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hudi
 
 import org.apache.hudi.common.table.HoodieTableMetaClient
+import org.apache.spark.sql.Row
 
 class TestMergeIntoTable2 extends TestHoodieSqlBase {
 
@@ -172,4 +173,68 @@ class TestMergeIntoTable2 extends TestHoodieSqlBase {
       )
     }
   }
+
+  test("Test Merge With Complex Data Type") {
+    withTempDir{tmp =>
+      val tableName = generateTableName
+      spark.sql(
+        s"""
+           | create table $tableName (
+           |  id int,
+           |  name string,
+           |  s_value struct<f0: int, f1: string>,
+           |  a_value array<string>,
+           |  m_value map<string, string>,
+           |  ts long
+           | ) using hudi
+           | options (
+           |  type = 'mor',
+           |  primaryKey = 'id',
+           |  preCombineField = 'ts'
+           | )
+           | location '${tmp.getCanonicalPath}'
+       """.stripMargin)
+
+      spark.sql(
+        s"""
+           |merge into $tableName h0
+           |using (
+           |select
+           |  1 as id,
+           |  'a1' as name,
+           |  struct(1, '10') as s_value,
+           |  split('a0,a1', ',') as a_value,
+           |  map('k0', 'v0') as m_value,
+           |  1000 as ts
+           |) s0
+           |on h0.id = s0.id
+           |when not matched then insert *
+           |""".stripMargin)
+
+      checkAnswer(s"select id, name, s_value, a_value, m_value, ts from $tableName")(
+        Seq(1, "a1", Row(1, "10"), Seq("a0", "a1"), Map("k0" -> "v0"), 1000)
+      )
+      // update value
+      spark.sql(
+        s"""
+           |merge into $tableName h0
+           |using (
+           |select
+           |  1 as id,
+           |  'a1' as name,
+           |  struct(1, '12') as s_value,
+           |  split('a0,a1,a2', ',') as a_value,
+           |  map('k1', 'v1') as m_value,
+           |  1000 as ts
+           |) s0
+           |on h0.id = s0.id
+           |when matched then update set *
+           |when not matched then insert *
+           |""".stripMargin)
+      checkAnswer(s"select id, name, s_value, a_value, m_value, ts from $tableName")(
+        Seq(1, "a1", Row(1, "12"), Seq("a0", "a1", "a2"), Map("k1" -> "v1"), 1000)
+      )
+    }
+  }
+
 }
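
Reviewer note (illustration only, not part of the patch): the fix replaces the hand-written per-type conversion code deleted above with Spark's built-in Avro bridge. SqlTypedRecord now uses AvroDeserializer to turn the incoming Avro record into a Catalyst InternalRow, and the generated eval() uses AvroSerializer to turn the evaluated row back into a GenericRecord, which is what makes nested struct, array, and map columns survive MERGE INTO. The sketch below is a minimal, hypothetical round trip under the same assumptions the patch compiles against: the two-argument AvroDeserializer and three-argument AvroSerializer constructors of the Spark 2.4/3.0 era (later Spark versions changed both signatures). The object name and schema are invented for the illustration, and the file borrows the patch's own package so the Spark-internal classes are visible.

// Hypothetical standalone sketch, not part of this patch.
package org.apache.spark.sql.hudi.command.payload

import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

object AvroRoundTripSketch {

  def main(args: Array[String]): Unit = {
    // A Catalyst schema with a nested struct, analogous to s_value in the new test.
    val sqlType = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("s_value", StructType(Seq(
        StructField("f0", IntegerType, nullable = false),
        StructField("f1", StringType, nullable = false)
      )), nullable = false)
    ))
    val avroSchema = SchemaConverters.toAvroType(sqlType)

    // Build the kind of Avro record ExpressionPayload receives.
    val nested = new GenericData.Record(avroSchema.getField("s_value").schema())
    nested.put("f0", 1)
    nested.put("f1", "10")
    val avroRecord = new GenericData.Record(avroSchema)
    avroRecord.put("id", 1)
    avroRecord.put("s_value", nested)

    // Avro -> Catalyst: the conversion SqlTypedRecord now delegates to Spark,
    // replacing the deleted hand-written convert() method.
    val deserializer = new AvroDeserializer(avroSchema, sqlType)
    val row = deserializer.deserialize(avroRecord).asInstanceOf[InternalRow]

    // Catalyst -> Avro: the conversion the generated eval() now performs on its
    // results, replacing the deleted AvroTypeConvertEvaluator.
    val serializer = new AvroSerializer(sqlType, avroSchema, false)
    val back = serializer.serialize(row).asInstanceOf[GenericRecord]
    println(back) // {"id": 1, "s_value": {"f0": 1, "f1": "10"}}
  }
}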