1
0

[HUDI-2232] [SQL] MERGE INTO fails with table having nested struct (#3379)

This commit is contained in:
pengzhiwei
2021-08-04 18:20:29 +08:00
committed by GitHub
parent b8b9d6db83
commit 5574e092fb
5 changed files with 111 additions and 207 deletions

View File

@@ -17,18 +17,18 @@
package org.apache.hudi.sql;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.IndexedRecord;
/***
* An interface for CodeGen to execute expressions on the record
* and return the results with an array for each expression.
* An interface for CodeGen to execute expressions on the record.
*/
public interface IExpressionEvaluator {
/**
* Evaluate the result of the expressions based on the record.
*/
Object[] eval(IndexedRecord record);
GenericRecord eval(IndexedRecord record);
/**
* Get the code of the expressions. This is used for debugging.

View File

@@ -18,15 +18,15 @@
package org.apache.spark.sql.hudi.command.payload
import java.util.UUID
import org.apache.avro.generic.IndexedRecord
import org.apache.avro.generic.{GenericRecord, IndexedRecord}
import org.apache.hudi.sql.IExpressionEvaluator
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.AvroSerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.hudi.command.payload.ExpressionCodeGen.RECORD_NAME
import org.apache.spark.sql.types.{DataType, Decimal}
@@ -53,7 +53,7 @@ object ExpressionCodeGen extends Logging {
* @return An IExpressionEvaluator generated by CodeGen which takes an IndexedRecord as input
* param and returns an Array of results for each expression.
*/
def doCodeGen(exprs: Seq[Expression]): IExpressionEvaluator = {
def doCodeGen(exprs: Seq[Expression], serializer: AvroSerializer): IExpressionEvaluator = {
val ctx = new CodegenContext()
// Set the input_row to null as we do not use row as the input object but Record.
ctx.INPUT_ROW = null
@@ -65,13 +65,15 @@ object ExpressionCodeGen extends Logging {
s"""
|private Object[] references;
|private String code;
|private AvroSerializer serializer;
|
|public $className(Object references, String code) {
|public $className(Object references, String code, AvroSerializer serializer) {
| this.references = (Object[])references;
| this.code = code;
| this.serializer = serializer;
|}
|
|public Object[] eval(IndexedRecord $RECORD_NAME) {
|public GenericRecord eval(IndexedRecord $RECORD_NAME) {
| ${resultVars.map(_.code).mkString("\n")}
| Object[] results = new Object[${resultVars.length}];
| ${
@@ -85,7 +87,8 @@ object ExpressionCodeGen extends Logging {
""".stripMargin
}).mkString("\n")
}
return results;
InternalRow row = new GenericInternalRow(results);
return (GenericRecord) serializer.serialize(row);
| }
|
|public String getCode() {
@@ -115,7 +118,10 @@ object ExpressionCodeGen extends Logging {
classOf[TaskContext].getName,
classOf[TaskKilledException].getName,
classOf[InputMetrics].getName,
classOf[IndexedRecord].getName
classOf[IndexedRecord].getName,
classOf[AvroSerializer].getName,
classOf[GenericRecord].getName,
classOf[GenericInternalRow].getName
)
evaluator.setImplementedInterfaces(Array(classOf[IExpressionEvaluator]))
try {
@@ -133,8 +139,8 @@ object ExpressionCodeGen extends Logging {
val referenceArray = ctx.references.toArray.map(_.asInstanceOf[Object])
val expressionSql = exprs.map(_.sql).mkString(" ")
evaluator.getClazz.getConstructor(classOf[Object], classOf[String])
.newInstance(referenceArray, s"Expressions is: [$expressionSql]\nCodeBody is: {\n$codeBody\n}")
evaluator.getClazz.getConstructor(classOf[Object], classOf[String], classOf[AvroSerializer])
.newInstance(referenceArray, s"Expressions is: [$expressionSql]\nCodeBody is: {\n$codeBody\n}", serializer)
.asInstanceOf[IExpressionEvaluator]
}

View File

@@ -21,11 +21,8 @@ import java.util.{Base64, Properties}
import java.util.concurrent.Callable
import scala.collection.JavaConverters._
import com.google.common.cache.CacheBuilder
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.Schema.Type
import org.apache.avro.{LogicalTypes, Schema}
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
import org.apache.avro.util.Utf8
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.avro.HoodieAvroUtils
import org.apache.hudi.avro.HoodieAvroUtils.bytesToAvro
@@ -34,11 +31,11 @@ import org.apache.hudi.common.util.{ValidationUtils, Option => HOption}
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.io.HoodieWriteHandle
import org.apache.hudi.sql.IExpressionEvaluator
import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.hudi.SerDeUtils
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.types.{StructField, StructType}
import scala.collection.mutable.ArrayBuffer
@@ -104,12 +101,11 @@ class ExpressionPayload(record: GenericRecord,
val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString, writeSchema)
for ((conditionEvaluator, assignmentEvaluator) <- updateConditionAndAssignments
if resultRecordOpt == null) {
val conditionVal = evaluate(conditionEvaluator, inputRecord).head.asInstanceOf[Boolean]
val conditionVal = evaluate(conditionEvaluator, inputRecord).get(0).asInstanceOf[Boolean]
// If the update condition matched then execute assignment expression
// to compute final record to update. We will return the first matched record.
if (conditionVal) {
val results = evaluate(assignmentEvaluator, inputRecord)
val resultRecord = convertToRecord(results, writeSchema)
val resultRecord = evaluate(assignmentEvaluator, inputRecord)
if (targetRecord.isEmpty || needUpdatingPersistedRecord(targetRecord.get, resultRecord, properties)) {
resultRecordOpt = HOption.of(resultRecord)
@@ -125,7 +121,7 @@ class ExpressionPayload(record: GenericRecord,
val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
if (deleteConditionText != null) {
val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
val deleteConditionVal = evaluate(deleteCondition, inputRecord).head.asInstanceOf[Boolean]
val deleteConditionVal = evaluate(deleteCondition, inputRecord).get(0).asInstanceOf[Boolean]
if (deleteConditionVal) {
resultRecordOpt = HOption.empty()
}
@@ -159,12 +155,12 @@ class ExpressionPayload(record: GenericRecord,
var resultRecordOpt: HOption[IndexedRecord] = null
for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
if resultRecordOpt == null) {
val conditionVal = evaluate(conditionEvaluator, inputRecord).head.asInstanceOf[Boolean]
val conditionVal = evaluate(conditionEvaluator, inputRecord).get(0).asInstanceOf[Boolean]
// If matched the insert condition then execute the assignment expressions to compute the
// result record. We will return the first matched record.
if (conditionVal) {
val results = evaluate(assignmentEvaluator, inputRecord)
resultRecordOpt = HOption.of(convertToRecord(results, writeSchema))
val resultRecord = evaluate(assignmentEvaluator, inputRecord)
resultRecordOpt = HOption.of(resultRecord)
}
}
if (resultRecordOpt != null) {
@@ -258,7 +254,7 @@ class ExpressionPayload(record: GenericRecord,
Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava)
}
private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): Array[Object] = {
private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): GenericRecord = {
try evaluator.eval(sqlTypedRecord) catch {
case e: Throwable =>
throw new RuntimeException(s"Error in execute expression: ${e.getMessage}.\n${evaluator.getCode}", e)
@@ -295,8 +291,6 @@ object ExpressionPayload {
/**
* Do the CodeGen for each condition and assignment expression. We will cache the result to reduce
* the compile time for each method call.
* @param serializedConditionAssignments
* @return
*/
def getEvaluator(
serializedConditionAssignments: String, writeSchema: Schema): Map[IExpressionEvaluator, IExpressionEvaluator] = {
@@ -310,42 +304,18 @@ object ExpressionPayload {
// Do the CodeGen for condition expression and assignment expression
conditionAssignments.map {
case (condition, assignments) =>
val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition))
val assignmentEvaluator = AvroTypeConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments), writeSchema)
val conditionType = StructType(Seq(StructField("_col0", condition.dataType, nullable = true)))
val conditionSerializer = new AvroSerializer(conditionType,
SchemaConverters.toAvroType(conditionType), false)
val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition), conditionSerializer)
val assignSqlType = SchemaConverters.toSqlType(writeSchema).dataType.asInstanceOf[StructType]
val assignSerializer = new AvroSerializer(assignSqlType, writeSchema, false)
val assignmentEvaluator = ExpressionCodeGen.doCodeGen(assignments, assignSerializer)
conditionEvaluator -> assignmentEvaluator
}
}
})
}
/**
* An IExpressionEvaluator wrapping the base evaluator, which converts the result of the base
* evaluator to the Avro typed value.
*/
case class AvroTypeConvertEvaluator(baseEvaluator: IExpressionEvaluator, writeSchema: Schema) extends IExpressionEvaluator {
private lazy val decimalConversions = new DecimalConversion()
/**
* Convert to the Avro typed value.
* e.g. convert UTF8String -> Utf8, Decimal -> GenericFixed.
*/
override def eval(record: IndexedRecord): Array[AnyRef] = {
baseEvaluator.eval(record).zipWithIndex.map {
case (s: UTF8String, _) => new Utf8(s.toString)
case (d: Decimal, i) =>
val schema = writeSchema.getFields.get(i).schema()
val fixedSchema = if (schema.getType == Type.UNION) {
schema.getTypes.asScala.filter(s => s.getType != Type.NULL).head
} else {
schema
}
decimalConversions.toFixed(d.toJavaBigDecimal, fixedSchema
, LogicalTypes.decimal(d.precision, d.scale))
case (o, _) => o
}
}
override def getCode: String = baseEvaluator.getCode
}
}

View File

@@ -17,165 +17,28 @@
package org.apache.spark.sql.hudi.command.payload
import java.math.BigDecimal
import java.nio.ByteBuffer
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._
import org.apache.avro.generic.{GenericFixed, IndexedRecord}
import org.apache.avro.util.Utf8
import org.apache.avro.{LogicalTypes, Schema}
import org.apache.spark.sql.avro.{IncompatibleSchemaException, SchemaConverters}
import org.apache.avro.generic.IndexedRecord
import org.apache.avro.Schema
import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import scala.collection.JavaConverters._
/**
* A sql typed record which will convert the avro field to sql typed value.
* This refers to org.apache.spark.sql.avro.AvroDeserializer#newWriter in the Spark project.
* @param record
*/
* A sql typed record which will convert the avro field to sql typed value.
*/
class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord {
private lazy val decimalConversions = new DecimalConversion()
private lazy val sqlType = SchemaConverters.toSqlType(getSchema).dataType.asInstanceOf[StructType]
private lazy val avroDeserializer = new AvroDeserializer(record.getSchema, sqlType)
private lazy val sqlRow = avroDeserializer.deserialize(record).asInstanceOf[InternalRow]
override def put(i: Int, v: Any): Unit = {
record.put(i, v)
}
override def get(i: Int): AnyRef = {
val value = record.get(i)
val avroFieldType = getSchema.getFields.get(i).schema()
val sqlFieldType = sqlType.fields(i).dataType
if (value == null) {
null
} else {
convert(avroFieldType, sqlFieldType, value)
}
}
private def convert(avroFieldType: Schema, sqlFieldType: DataType, value: AnyRef): AnyRef = {
(avroFieldType.getType, sqlFieldType) match {
case (NULL, NullType) => null
case (BOOLEAN, BooleanType) => value.asInstanceOf[Boolean].asInstanceOf[java.lang.Boolean]
case (INT, IntegerType) => value.asInstanceOf[Int].asInstanceOf[java.lang.Integer]
case (INT, DateType) => value.asInstanceOf[Int].asInstanceOf[java.lang.Integer]
case (LONG, LongType) => value.asInstanceOf[Long].asInstanceOf[java.lang.Long]
case (LONG, TimestampType) => avroFieldType.getLogicalType match {
case _: TimestampMillis => (value.asInstanceOf[Long] * 1000).asInstanceOf[java.lang.Long]
case _: TimestampMicros => value.asInstanceOf[Long].asInstanceOf[java.lang.Long]
case null =>
// For backward compatibility, if the Avro type is Long and it is not logical type,
// the value is processed as timestamp type with millisecond precision.
java.lang.Long.valueOf(value.asInstanceOf[Long] * 1000)
case other => throw new IncompatibleSchemaException(
s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.")
}
// Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date.
// For backward compatibility, we still keep this conversion.
case (LONG, DateType) =>
java.lang.Integer.valueOf((value.asInstanceOf[Long] / SqlTypedRecord.MILLIS_PER_DAY).toInt)
case (FLOAT, FloatType) => value.asInstanceOf[Float].asInstanceOf[java.lang.Float]
case (DOUBLE, DoubleType) => value.asInstanceOf[Double].asInstanceOf[java.lang.Double]
case (STRING, StringType) => value match {
case s: String => UTF8String.fromString(s)
case s: Utf8 => UTF8String.fromString(s.toString)
case o => throw new IllegalArgumentException(s"Cannot convert $o to StringType")
}
case (ENUM, StringType) => value.toString
case (FIXED, BinaryType) => value.asInstanceOf[GenericFixed].bytes().clone()
case (BYTES, BinaryType) => value match {
case b: ByteBuffer =>
val bytes = new Array[Byte](b.remaining)
b.get(bytes)
bytes
case b: Array[Byte] => b
case other => throw new RuntimeException(s"$other is not a valid avro binary.")
}
case (FIXED, d: DecimalType) =>
val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroFieldType,
LogicalTypes.decimal(d.precision, d.scale))
createDecimal(bigDecimal, d.precision, d.scale)
case (BYTES, d: DecimalType) =>
val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroFieldType,
LogicalTypes.decimal(d.precision, d.scale))
createDecimal(bigDecimal, d.precision, d.scale)
case (RECORD, _: StructType) =>
throw new IllegalArgumentException(s"UnSupport StructType yet")
case (ARRAY, ArrayType(_, _)) =>
throw new IllegalArgumentException(s"UnSupport ARRAY type yet")
case (MAP, MapType(keyType, _, _)) if keyType == StringType =>
throw new IllegalArgumentException(s"UnSupport MAP type yet")
case (UNION, _) =>
val allTypes = avroFieldType.getTypes.asScala
val nonNullTypes = allTypes.filter(_.getType != NULL)
if (nonNullTypes.nonEmpty) {
if (nonNullTypes.length == 1) {
convert(nonNullTypes.head, sqlFieldType, value)
} else {
nonNullTypes.map(_.getType) match {
case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlFieldType == LongType =>
value match {
case null => null
case l: java.lang.Long => l
case i: java.lang.Integer => i.longValue().asInstanceOf[java.lang.Long]
}
case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlFieldType == DoubleType =>
value match {
case null => null
case d: java.lang.Double => d
case f: java.lang.Float => f.doubleValue().asInstanceOf[java.lang.Double]
}
case _ =>
throw new IllegalArgumentException(s"UnSupport UNION type: ${sqlFieldType}")
}
}
} else {
null
}
case _ =>
throw new IncompatibleSchemaException(
s"Cannot convert Avro to catalyst because schema " +
s"is not compatible (avroType = $avroFieldType, sqlType = $sqlFieldType).\n")
}
}
private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
if (precision <= Decimal.MAX_LONG_DIGITS) {
// Constructs a `Decimal` with an unscaled `Long` value if possible.
Decimal(decimal.unscaledValue().longValue(), precision, scale)
} else {
// Otherwise, resorts to an unscaled `BigInteger` instead.
Decimal(decimal, precision, scale)
}
sqlRow.get(i, sqlType(i).dataType)
}
override def getSchema: Schema = record.getSchema
}
object SqlTypedRecord {
val MILLIS_PER_DAY = 24 * 60 * 60 * 1000L
}

View File

@@ -18,6 +18,7 @@
package org.apache.spark.sql.hudi
import org.apache.hudi.common.table.HoodieTableMetaClient
import org.apache.spark.sql.Row
class TestMergeIntoTable2 extends TestHoodieSqlBase {
@@ -172,4 +173,68 @@ class TestMergeIntoTable2 extends TestHoodieSqlBase {
)
}
}
test("Test Merge With Complex Data Type") {
withTempDir{tmp =>
val tableName = generateTableName
spark.sql(
s"""
| create table $tableName (
| id int,
| name string,
| s_value struct<f0: int, f1: string>,
| a_value array<string>,
| m_value map<string, string>,
| ts long
| ) using hudi
| options (
| type = 'mor',
| primaryKey = 'id',
| preCombineField = 'ts'
| )
| location '${tmp.getCanonicalPath}'
""".stripMargin)
spark.sql(
s"""
|merge into $tableName h0
|using (
|select
| 1 as id,
| 'a1' as name,
| struct(1, '10') as s_value,
| split('a0,a1', ',') as a_value,
| map('k0', 'v0') as m_value,
| 1000 as ts
|) s0
|on h0.id = s0.id
|when not matched then insert *
|""".stripMargin)
checkAnswer(s"select id, name, s_value, a_value, m_value, ts from $tableName")(
Seq(1, "a1", Row(1, "10"), Seq("a0", "a1"), Map("k0" -> "v0"), 1000)
)
// update value
spark.sql(
s"""
|merge into $tableName h0
|using (
|select
| 1 as id,
| 'a1' as name,
| struct(1, '12') as s_value,
| split('a0,a1,a2', ',') as a_value,
| map('k1', 'v1') as m_value,
| 1000 as ts
|) s0
|on h0.id = s0.id
|when matched then update set *
|when not matched then insert *
|""".stripMargin)
checkAnswer(s"select id, name, s_value, a_value, m_value, ts from $tableName")(
Seq(1, "a1", Row(1, "12"), Seq("a0", "a1", "a2"), Map("k1" -> "v1"), 1000)
)
}
}
}