[HUDI-2232] [SQL] MERGE INTO fails with table having nested struct (#3379)
This commit is contained in:
@@ -17,18 +17,18 @@
|
|||||||
|
|
||||||
package org.apache.hudi.sql;
|
package org.apache.hudi.sql;
|
||||||
|
|
||||||
|
import org.apache.avro.generic.GenericRecord;
|
||||||
import org.apache.avro.generic.IndexedRecord;
|
import org.apache.avro.generic.IndexedRecord;
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* A interface for CodeGen to execute expressions on the record
|
* A interface for CodeGen to execute expressions on the record.
|
||||||
* and return the results with a array for each expression.
|
|
||||||
*/
|
*/
|
||||||
public interface IExpressionEvaluator {
|
public interface IExpressionEvaluator {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Evaluate the result of the expressions based on the record.
|
* Evaluate the result of the expressions based on the record.
|
||||||
*/
|
*/
|
||||||
Object[] eval(IndexedRecord record);
|
GenericRecord eval(IndexedRecord record);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the code of the expressions. This is used for debug.
|
* Get the code of the expressions. This is used for debug.
|
||||||
|
|||||||
@@ -18,15 +18,15 @@
|
|||||||
package org.apache.spark.sql.hudi.command.payload
|
package org.apache.spark.sql.hudi.command.payload
|
||||||
|
|
||||||
import java.util.UUID
|
import java.util.UUID
|
||||||
|
import org.apache.avro.generic.{GenericRecord, IndexedRecord}
|
||||||
import org.apache.avro.generic.IndexedRecord
|
|
||||||
import org.apache.hudi.sql.IExpressionEvaluator
|
import org.apache.hudi.sql.IExpressionEvaluator
|
||||||
import org.apache.spark.executor.InputMetrics
|
import org.apache.spark.executor.InputMetrics
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
|
import org.apache.spark.sql.avro.AvroSerializer
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
|
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow}
|
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow}
|
||||||
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
|
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
|
||||||
import org.apache.spark.sql.hudi.command.payload.ExpressionCodeGen.RECORD_NAME
|
import org.apache.spark.sql.hudi.command.payload.ExpressionCodeGen.RECORD_NAME
|
||||||
import org.apache.spark.sql.types.{DataType, Decimal}
|
import org.apache.spark.sql.types.{DataType, Decimal}
|
||||||
@@ -53,7 +53,7 @@ object ExpressionCodeGen extends Logging {
|
|||||||
* @return An IExpressionEvaluator generate by CodeGen which take a IndexedRecord as input
|
* @return An IExpressionEvaluator generate by CodeGen which take a IndexedRecord as input
|
||||||
* param and return a Array of results for each expression.
|
* param and return a Array of results for each expression.
|
||||||
*/
|
*/
|
||||||
def doCodeGen(exprs: Seq[Expression]): IExpressionEvaluator = {
|
def doCodeGen(exprs: Seq[Expression], serializer: AvroSerializer): IExpressionEvaluator = {
|
||||||
val ctx = new CodegenContext()
|
val ctx = new CodegenContext()
|
||||||
// Set the input_row to null as we do not use row as the input object but Record.
|
// Set the input_row to null as we do not use row as the input object but Record.
|
||||||
ctx.INPUT_ROW = null
|
ctx.INPUT_ROW = null
|
||||||
@@ -65,13 +65,15 @@ object ExpressionCodeGen extends Logging {
|
|||||||
s"""
|
s"""
|
||||||
|private Object[] references;
|
|private Object[] references;
|
||||||
|private String code;
|
|private String code;
|
||||||
|
|private AvroSerializer serializer;
|
||||||
|
|
|
|
||||||
|public $className(Object references, String code) {
|
|public $className(Object references, String code, AvroSerializer serializer) {
|
||||||
| this.references = (Object[])references;
|
| this.references = (Object[])references;
|
||||||
| this.code = code;
|
| this.code = code;
|
||||||
|
| this.serializer = serializer;
|
||||||
|}
|
|}
|
||||||
|
|
|
|
||||||
|public Object[] eval(IndexedRecord $RECORD_NAME) {
|
|public GenericRecord eval(IndexedRecord $RECORD_NAME) {
|
||||||
| ${resultVars.map(_.code).mkString("\n")}
|
| ${resultVars.map(_.code).mkString("\n")}
|
||||||
| Object[] results = new Object[${resultVars.length}];
|
| Object[] results = new Object[${resultVars.length}];
|
||||||
| ${
|
| ${
|
||||||
@@ -85,7 +87,8 @@ object ExpressionCodeGen extends Logging {
|
|||||||
""".stripMargin
|
""".stripMargin
|
||||||
}).mkString("\n")
|
}).mkString("\n")
|
||||||
}
|
}
|
||||||
return results;
|
InternalRow row = new GenericInternalRow(results);
|
||||||
|
return (GenericRecord) serializer.serialize(row);
|
||||||
| }
|
| }
|
||||||
|
|
|
|
||||||
|public String getCode() {
|
|public String getCode() {
|
||||||
@@ -115,7 +118,10 @@ object ExpressionCodeGen extends Logging {
|
|||||||
classOf[TaskContext].getName,
|
classOf[TaskContext].getName,
|
||||||
classOf[TaskKilledException].getName,
|
classOf[TaskKilledException].getName,
|
||||||
classOf[InputMetrics].getName,
|
classOf[InputMetrics].getName,
|
||||||
classOf[IndexedRecord].getName
|
classOf[IndexedRecord].getName,
|
||||||
|
classOf[AvroSerializer].getName,
|
||||||
|
classOf[GenericRecord].getName,
|
||||||
|
classOf[GenericInternalRow].getName
|
||||||
)
|
)
|
||||||
evaluator.setImplementedInterfaces(Array(classOf[IExpressionEvaluator]))
|
evaluator.setImplementedInterfaces(Array(classOf[IExpressionEvaluator]))
|
||||||
try {
|
try {
|
||||||
@@ -133,8 +139,8 @@ object ExpressionCodeGen extends Logging {
|
|||||||
val referenceArray = ctx.references.toArray.map(_.asInstanceOf[Object])
|
val referenceArray = ctx.references.toArray.map(_.asInstanceOf[Object])
|
||||||
val expressionSql = exprs.map(_.sql).mkString(" ")
|
val expressionSql = exprs.map(_.sql).mkString(" ")
|
||||||
|
|
||||||
evaluator.getClazz.getConstructor(classOf[Object], classOf[String])
|
evaluator.getClazz.getConstructor(classOf[Object], classOf[String], classOf[AvroSerializer])
|
||||||
.newInstance(referenceArray, s"Expressions is: [$expressionSql]\nCodeBody is: {\n$codeBody\n}")
|
.newInstance(referenceArray, s"Expressions is: [$expressionSql]\nCodeBody is: {\n$codeBody\n}", serializer)
|
||||||
.asInstanceOf[IExpressionEvaluator]
|
.asInstanceOf[IExpressionEvaluator]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,11 +21,8 @@ import java.util.{Base64, Properties}
|
|||||||
import java.util.concurrent.Callable
|
import java.util.concurrent.Callable
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
import com.google.common.cache.CacheBuilder
|
import com.google.common.cache.CacheBuilder
|
||||||
import org.apache.avro.Conversions.DecimalConversion
|
import org.apache.avro.Schema
|
||||||
import org.apache.avro.Schema.Type
|
|
||||||
import org.apache.avro.{LogicalTypes, Schema}
|
|
||||||
import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
|
import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
|
||||||
import org.apache.avro.util.Utf8
|
|
||||||
import org.apache.hudi.DataSourceWriteOptions._
|
import org.apache.hudi.DataSourceWriteOptions._
|
||||||
import org.apache.hudi.avro.HoodieAvroUtils
|
import org.apache.hudi.avro.HoodieAvroUtils
|
||||||
import org.apache.hudi.avro.HoodieAvroUtils.bytesToAvro
|
import org.apache.hudi.avro.HoodieAvroUtils.bytesToAvro
|
||||||
@@ -34,11 +31,11 @@ import org.apache.hudi.common.util.{ValidationUtils, Option => HOption}
|
|||||||
import org.apache.hudi.config.HoodieWriteConfig
|
import org.apache.hudi.config.HoodieWriteConfig
|
||||||
import org.apache.hudi.io.HoodieWriteHandle
|
import org.apache.hudi.io.HoodieWriteHandle
|
||||||
import org.apache.hudi.sql.IExpressionEvaluator
|
import org.apache.hudi.sql.IExpressionEvaluator
|
||||||
|
import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters}
|
||||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||||
import org.apache.spark.sql.hudi.SerDeUtils
|
import org.apache.spark.sql.hudi.SerDeUtils
|
||||||
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
|
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
|
||||||
import org.apache.spark.sql.types.Decimal
|
import org.apache.spark.sql.types.{StructField, StructType}
|
||||||
import org.apache.spark.unsafe.types.UTF8String
|
|
||||||
|
|
||||||
import scala.collection.mutable.ArrayBuffer
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
@@ -104,12 +101,11 @@ class ExpressionPayload(record: GenericRecord,
|
|||||||
val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString, writeSchema)
|
val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString, writeSchema)
|
||||||
for ((conditionEvaluator, assignmentEvaluator) <- updateConditionAndAssignments
|
for ((conditionEvaluator, assignmentEvaluator) <- updateConditionAndAssignments
|
||||||
if resultRecordOpt == null) {
|
if resultRecordOpt == null) {
|
||||||
val conditionVal = evaluate(conditionEvaluator, inputRecord).head.asInstanceOf[Boolean]
|
val conditionVal = evaluate(conditionEvaluator, inputRecord).get(0).asInstanceOf[Boolean]
|
||||||
// If the update condition matched then execute assignment expression
|
// If the update condition matched then execute assignment expression
|
||||||
// to compute final record to update. We will return the first matched record.
|
// to compute final record to update. We will return the first matched record.
|
||||||
if (conditionVal) {
|
if (conditionVal) {
|
||||||
val results = evaluate(assignmentEvaluator, inputRecord)
|
val resultRecord = evaluate(assignmentEvaluator, inputRecord)
|
||||||
val resultRecord = convertToRecord(results, writeSchema)
|
|
||||||
|
|
||||||
if (targetRecord.isEmpty || needUpdatingPersistedRecord(targetRecord.get, resultRecord, properties)) {
|
if (targetRecord.isEmpty || needUpdatingPersistedRecord(targetRecord.get, resultRecord, properties)) {
|
||||||
resultRecordOpt = HOption.of(resultRecord)
|
resultRecordOpt = HOption.of(resultRecord)
|
||||||
@@ -125,7 +121,7 @@ class ExpressionPayload(record: GenericRecord,
|
|||||||
val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
|
val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
|
||||||
if (deleteConditionText != null) {
|
if (deleteConditionText != null) {
|
||||||
val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
|
val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
|
||||||
val deleteConditionVal = evaluate(deleteCondition, inputRecord).head.asInstanceOf[Boolean]
|
val deleteConditionVal = evaluate(deleteCondition, inputRecord).get(0).asInstanceOf[Boolean]
|
||||||
if (deleteConditionVal) {
|
if (deleteConditionVal) {
|
||||||
resultRecordOpt = HOption.empty()
|
resultRecordOpt = HOption.empty()
|
||||||
}
|
}
|
||||||
@@ -159,12 +155,12 @@ class ExpressionPayload(record: GenericRecord,
|
|||||||
var resultRecordOpt: HOption[IndexedRecord] = null
|
var resultRecordOpt: HOption[IndexedRecord] = null
|
||||||
for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
|
for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
|
||||||
if resultRecordOpt == null) {
|
if resultRecordOpt == null) {
|
||||||
val conditionVal = evaluate(conditionEvaluator, inputRecord).head.asInstanceOf[Boolean]
|
val conditionVal = evaluate(conditionEvaluator, inputRecord).get(0).asInstanceOf[Boolean]
|
||||||
// If matched the insert condition then execute the assignment expressions to compute the
|
// If matched the insert condition then execute the assignment expressions to compute the
|
||||||
// result record. We will return the first matched record.
|
// result record. We will return the first matched record.
|
||||||
if (conditionVal) {
|
if (conditionVal) {
|
||||||
val results = evaluate(assignmentEvaluator, inputRecord)
|
val resultRecord = evaluate(assignmentEvaluator, inputRecord)
|
||||||
resultRecordOpt = HOption.of(convertToRecord(results, writeSchema))
|
resultRecordOpt = HOption.of(resultRecord)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (resultRecordOpt != null) {
|
if (resultRecordOpt != null) {
|
||||||
@@ -258,7 +254,7 @@ class ExpressionPayload(record: GenericRecord,
|
|||||||
Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava)
|
Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): Array[Object] = {
|
private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): GenericRecord = {
|
||||||
try evaluator.eval(sqlTypedRecord) catch {
|
try evaluator.eval(sqlTypedRecord) catch {
|
||||||
case e: Throwable =>
|
case e: Throwable =>
|
||||||
throw new RuntimeException(s"Error in execute expression: ${e.getMessage}.\n${evaluator.getCode}", e)
|
throw new RuntimeException(s"Error in execute expression: ${e.getMessage}.\n${evaluator.getCode}", e)
|
||||||
@@ -295,8 +291,6 @@ object ExpressionPayload {
|
|||||||
/**
|
/**
|
||||||
* Do the CodeGen for each condition and assignment expressions.We will cache it to reduce
|
* Do the CodeGen for each condition and assignment expressions.We will cache it to reduce
|
||||||
* the compile time for each method call.
|
* the compile time for each method call.
|
||||||
* @param serializedConditionAssignments
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
def getEvaluator(
|
def getEvaluator(
|
||||||
serializedConditionAssignments: String, writeSchema: Schema): Map[IExpressionEvaluator, IExpressionEvaluator] = {
|
serializedConditionAssignments: String, writeSchema: Schema): Map[IExpressionEvaluator, IExpressionEvaluator] = {
|
||||||
@@ -310,42 +304,18 @@ object ExpressionPayload {
|
|||||||
// Do the CodeGen for condition expression and assignment expression
|
// Do the CodeGen for condition expression and assignment expression
|
||||||
conditionAssignments.map {
|
conditionAssignments.map {
|
||||||
case (condition, assignments) =>
|
case (condition, assignments) =>
|
||||||
val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition))
|
val conditionType = StructType(Seq(StructField("_col0", condition.dataType, nullable = true)))
|
||||||
val assignmentEvaluator = AvroTypeConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments), writeSchema)
|
val conditionSerializer = new AvroSerializer(conditionType,
|
||||||
|
SchemaConverters.toAvroType(conditionType), false)
|
||||||
|
val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition), conditionSerializer)
|
||||||
|
|
||||||
|
val assignSqlType = SchemaConverters.toSqlType(writeSchema).dataType.asInstanceOf[StructType]
|
||||||
|
val assignSerializer = new AvroSerializer(assignSqlType, writeSchema, false)
|
||||||
|
val assignmentEvaluator = ExpressionCodeGen.doCodeGen(assignments, assignSerializer)
|
||||||
conditionEvaluator -> assignmentEvaluator
|
conditionEvaluator -> assignmentEvaluator
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* A IExpressionEvaluator wrapped the base evaluator which convert the result of the base evaluator
|
|
||||||
* to the avro typed-value.
|
|
||||||
*/
|
|
||||||
case class AvroTypeConvertEvaluator(baseEvaluator: IExpressionEvaluator, writeSchema: Schema) extends IExpressionEvaluator {
|
|
||||||
private lazy val decimalConversions = new DecimalConversion()
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert to the avro typed-value.
|
|
||||||
* e.g. convert UTF8String -> Utf8, Dicimal -> GenericFixed.
|
|
||||||
*/
|
|
||||||
override def eval(record: IndexedRecord): Array[AnyRef] = {
|
|
||||||
baseEvaluator.eval(record).zipWithIndex.map {
|
|
||||||
case (s: UTF8String, _) => new Utf8(s.toString)
|
|
||||||
case (d: Decimal, i) =>
|
|
||||||
val schema = writeSchema.getFields.get(i).schema()
|
|
||||||
val fixedSchema = if (schema.getType == Type.UNION) {
|
|
||||||
schema.getTypes.asScala.filter(s => s.getType != Type.NULL).head
|
|
||||||
} else {
|
|
||||||
schema
|
|
||||||
}
|
|
||||||
decimalConversions.toFixed(d.toJavaBigDecimal, fixedSchema
|
|
||||||
, LogicalTypes.decimal(d.precision, d.scale))
|
|
||||||
case (o, _) => o
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override def getCode: String = baseEvaluator.getCode
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,165 +17,28 @@
|
|||||||
|
|
||||||
package org.apache.spark.sql.hudi.command.payload
|
package org.apache.spark.sql.hudi.command.payload
|
||||||
|
|
||||||
import java.math.BigDecimal
|
import org.apache.avro.generic.IndexedRecord
|
||||||
import java.nio.ByteBuffer
|
import org.apache.avro.Schema
|
||||||
|
import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters}
|
||||||
import org.apache.avro.Conversions.DecimalConversion
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
|
|
||||||
import org.apache.avro.Schema.Type._
|
|
||||||
import org.apache.avro.generic.{GenericFixed, IndexedRecord}
|
|
||||||
import org.apache.avro.util.Utf8
|
|
||||||
import org.apache.avro.{LogicalTypes, Schema}
|
|
||||||
import org.apache.spark.sql.avro.{IncompatibleSchemaException, SchemaConverters}
|
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.unsafe.types.UTF8String
|
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A sql typed record which will convert the avro field to sql typed value.
|
* A sql typed record which will convert the avro field to sql typed value.
|
||||||
* This is referred to the org.apache.spark.sql.avro.AvroDeserializer#newWriter in spark project.
|
*/
|
||||||
* @param record
|
|
||||||
*/
|
|
||||||
class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord {
|
class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord {
|
||||||
|
|
||||||
private lazy val decimalConversions = new DecimalConversion()
|
|
||||||
private lazy val sqlType = SchemaConverters.toSqlType(getSchema).dataType.asInstanceOf[StructType]
|
private lazy val sqlType = SchemaConverters.toSqlType(getSchema).dataType.asInstanceOf[StructType]
|
||||||
|
private lazy val avroDeserializer = new AvroDeserializer(record.getSchema, sqlType)
|
||||||
|
private lazy val sqlRow = avroDeserializer.deserialize(record).asInstanceOf[InternalRow]
|
||||||
|
|
||||||
override def put(i: Int, v: Any): Unit = {
|
override def put(i: Int, v: Any): Unit = {
|
||||||
record.put(i, v)
|
record.put(i, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def get(i: Int): AnyRef = {
|
override def get(i: Int): AnyRef = {
|
||||||
val value = record.get(i)
|
sqlRow.get(i, sqlType(i).dataType)
|
||||||
val avroFieldType = getSchema.getFields.get(i).schema()
|
|
||||||
val sqlFieldType = sqlType.fields(i).dataType
|
|
||||||
if (value == null) {
|
|
||||||
null
|
|
||||||
} else {
|
|
||||||
convert(avroFieldType, sqlFieldType, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def convert(avroFieldType: Schema, sqlFieldType: DataType, value: AnyRef): AnyRef = {
|
|
||||||
(avroFieldType.getType, sqlFieldType) match {
|
|
||||||
case (NULL, NullType) => null
|
|
||||||
|
|
||||||
case (BOOLEAN, BooleanType) => value.asInstanceOf[Boolean].asInstanceOf[java.lang.Boolean]
|
|
||||||
|
|
||||||
case (INT, IntegerType) => value.asInstanceOf[Int].asInstanceOf[java.lang.Integer]
|
|
||||||
|
|
||||||
case (INT, DateType) => value.asInstanceOf[Int].asInstanceOf[java.lang.Integer]
|
|
||||||
|
|
||||||
case (LONG, LongType) => value.asInstanceOf[Long].asInstanceOf[java.lang.Long]
|
|
||||||
|
|
||||||
case (LONG, TimestampType) => avroFieldType.getLogicalType match {
|
|
||||||
case _: TimestampMillis => (value.asInstanceOf[Long] * 1000).asInstanceOf[java.lang.Long]
|
|
||||||
case _: TimestampMicros => value.asInstanceOf[Long].asInstanceOf[java.lang.Long]
|
|
||||||
case null =>
|
|
||||||
// For backward compatibility, if the Avro type is Long and it is not logical type,
|
|
||||||
// the value is processed as timestamp type with millisecond precision.
|
|
||||||
java.lang.Long.valueOf(value.asInstanceOf[Long] * 1000)
|
|
||||||
case other => throw new IncompatibleSchemaException(
|
|
||||||
s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date.
|
|
||||||
// For backward compatibility, we still keep this conversion.
|
|
||||||
case (LONG, DateType) =>
|
|
||||||
java.lang.Integer.valueOf((value.asInstanceOf[Long] / SqlTypedRecord.MILLIS_PER_DAY).toInt)
|
|
||||||
|
|
||||||
case (FLOAT, FloatType) => value.asInstanceOf[Float].asInstanceOf[java.lang.Float]
|
|
||||||
|
|
||||||
case (DOUBLE, DoubleType) => value.asInstanceOf[Double].asInstanceOf[java.lang.Double]
|
|
||||||
|
|
||||||
case (STRING, StringType) => value match {
|
|
||||||
case s: String => UTF8String.fromString(s)
|
|
||||||
case s: Utf8 => UTF8String.fromString(s.toString)
|
|
||||||
case o => throw new IllegalArgumentException(s"Cannot convert $o to StringType")
|
|
||||||
}
|
|
||||||
|
|
||||||
case (ENUM, StringType) => value.toString
|
|
||||||
|
|
||||||
case (FIXED, BinaryType) => value.asInstanceOf[GenericFixed].bytes().clone()
|
|
||||||
|
|
||||||
case (BYTES, BinaryType) => value match {
|
|
||||||
case b: ByteBuffer =>
|
|
||||||
val bytes = new Array[Byte](b.remaining)
|
|
||||||
b.get(bytes)
|
|
||||||
bytes
|
|
||||||
case b: Array[Byte] => b
|
|
||||||
case other => throw new RuntimeException(s"$other is not a valid avro binary.")
|
|
||||||
}
|
|
||||||
|
|
||||||
case (FIXED, d: DecimalType) =>
|
|
||||||
val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroFieldType,
|
|
||||||
LogicalTypes.decimal(d.precision, d.scale))
|
|
||||||
createDecimal(bigDecimal, d.precision, d.scale)
|
|
||||||
|
|
||||||
case (BYTES, d: DecimalType) =>
|
|
||||||
val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroFieldType,
|
|
||||||
LogicalTypes.decimal(d.precision, d.scale))
|
|
||||||
createDecimal(bigDecimal, d.precision, d.scale)
|
|
||||||
|
|
||||||
case (RECORD, _: StructType) =>
|
|
||||||
throw new IllegalArgumentException(s"UnSupport StructType yet")
|
|
||||||
|
|
||||||
case (ARRAY, ArrayType(_, _)) =>
|
|
||||||
throw new IllegalArgumentException(s"UnSupport ARRAY type yet")
|
|
||||||
|
|
||||||
case (MAP, MapType(keyType, _, _)) if keyType == StringType =>
|
|
||||||
throw new IllegalArgumentException(s"UnSupport MAP type yet")
|
|
||||||
|
|
||||||
case (UNION, _) =>
|
|
||||||
val allTypes = avroFieldType.getTypes.asScala
|
|
||||||
val nonNullTypes = allTypes.filter(_.getType != NULL)
|
|
||||||
if (nonNullTypes.nonEmpty) {
|
|
||||||
if (nonNullTypes.length == 1) {
|
|
||||||
convert(nonNullTypes.head, sqlFieldType, value)
|
|
||||||
} else {
|
|
||||||
nonNullTypes.map(_.getType) match {
|
|
||||||
case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlFieldType == LongType =>
|
|
||||||
value match {
|
|
||||||
case null => null
|
|
||||||
case l: java.lang.Long => l
|
|
||||||
case i: java.lang.Integer => i.longValue().asInstanceOf[java.lang.Long]
|
|
||||||
}
|
|
||||||
|
|
||||||
case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlFieldType == DoubleType =>
|
|
||||||
value match {
|
|
||||||
case null => null
|
|
||||||
case d: java.lang.Double => d
|
|
||||||
case f: java.lang.Float => f.doubleValue().asInstanceOf[java.lang.Double]
|
|
||||||
}
|
|
||||||
|
|
||||||
case _ =>
|
|
||||||
throw new IllegalArgumentException(s"UnSupport UNION type: ${sqlFieldType}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
case _ =>
|
|
||||||
throw new IncompatibleSchemaException(
|
|
||||||
s"Cannot convert Avro to catalyst because schema " +
|
|
||||||
s"is not compatible (avroType = $avroFieldType, sqlType = $sqlFieldType).\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
|
|
||||||
if (precision <= Decimal.MAX_LONG_DIGITS) {
|
|
||||||
// Constructs a `Decimal` with an unscaled `Long` value if possible.
|
|
||||||
Decimal(decimal.unscaledValue().longValue(), precision, scale)
|
|
||||||
} else {
|
|
||||||
// Otherwise, resorts to an unscaled `BigInteger` instead.
|
|
||||||
Decimal(decimal, precision, scale)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override def getSchema: Schema = record.getSchema
|
override def getSchema: Schema = record.getSchema
|
||||||
}
|
}
|
||||||
|
|
||||||
object SqlTypedRecord {
|
|
||||||
val MILLIS_PER_DAY = 24 * 60 * 60 * 1000L
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
package org.apache.spark.sql.hudi
|
package org.apache.spark.sql.hudi
|
||||||
|
|
||||||
import org.apache.hudi.common.table.HoodieTableMetaClient
|
import org.apache.hudi.common.table.HoodieTableMetaClient
|
||||||
|
import org.apache.spark.sql.Row
|
||||||
|
|
||||||
class TestMergeIntoTable2 extends TestHoodieSqlBase {
|
class TestMergeIntoTable2 extends TestHoodieSqlBase {
|
||||||
|
|
||||||
@@ -172,4 +173,68 @@ class TestMergeIntoTable2 extends TestHoodieSqlBase {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Test Merge With Complex Data Type") {
|
||||||
|
withTempDir{tmp =>
|
||||||
|
val tableName = generateTableName
|
||||||
|
spark.sql(
|
||||||
|
s"""
|
||||||
|
| create table $tableName (
|
||||||
|
| id int,
|
||||||
|
| name string,
|
||||||
|
| s_value struct<f0: int, f1: string>,
|
||||||
|
| a_value array<string>,
|
||||||
|
| m_value map<string, string>,
|
||||||
|
| ts long
|
||||||
|
| ) using hudi
|
||||||
|
| options (
|
||||||
|
| type = 'mor',
|
||||||
|
| primaryKey = 'id',
|
||||||
|
| preCombineField = 'ts'
|
||||||
|
| )
|
||||||
|
| location '${tmp.getCanonicalPath}'
|
||||||
|
""".stripMargin)
|
||||||
|
|
||||||
|
spark.sql(
|
||||||
|
s"""
|
||||||
|
|merge into $tableName h0
|
||||||
|
|using (
|
||||||
|
|select
|
||||||
|
| 1 as id,
|
||||||
|
| 'a1' as name,
|
||||||
|
| struct(1, '10') as s_value,
|
||||||
|
| split('a0,a1', ',') as a_value,
|
||||||
|
| map('k0', 'v0') as m_value,
|
||||||
|
| 1000 as ts
|
||||||
|
|) s0
|
||||||
|
|on h0.id = s0.id
|
||||||
|
|when not matched then insert *
|
||||||
|
|""".stripMargin)
|
||||||
|
|
||||||
|
checkAnswer(s"select id, name, s_value, a_value, m_value, ts from $tableName")(
|
||||||
|
Seq(1, "a1", Row(1, "10"), Seq("a0", "a1"), Map("k0" -> "v0"), 1000)
|
||||||
|
)
|
||||||
|
// update value
|
||||||
|
spark.sql(
|
||||||
|
s"""
|
||||||
|
|merge into $tableName h0
|
||||||
|
|using (
|
||||||
|
|select
|
||||||
|
| 1 as id,
|
||||||
|
| 'a1' as name,
|
||||||
|
| struct(1, '12') as s_value,
|
||||||
|
| split('a0,a1,a2', ',') as a_value,
|
||||||
|
| map('k1', 'v1') as m_value,
|
||||||
|
| 1000 as ts
|
||||||
|
|) s0
|
||||||
|
|on h0.id = s0.id
|
||||||
|
|when matched then update set *
|
||||||
|
|when not matched then insert *
|
||||||
|
|""".stripMargin)
|
||||||
|
checkAnswer(s"select id, name, s_value, a_value, m_value, ts from $tableName")(
|
||||||
|
Seq(1, "a1", Row(1, "12"), Seq("a0", "a1", "a2"), Map("k1" -> "v1"), 1000)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user