[HUDI-2131] Exception Throw Out When MergeInto With Decimal Type Field (#3224)
This commit is contained in:
@@ -21,7 +21,9 @@ import java.util.{Base64, Properties}
|
||||
import java.util.concurrent.Callable
|
||||
import scala.collection.JavaConverters._
|
||||
import com.google.common.cache.CacheBuilder
|
||||
import org.apache.avro.Schema
|
||||
import org.apache.avro.Conversions.DecimalConversion
|
||||
import org.apache.avro.Schema.Type
|
||||
import org.apache.avro.{LogicalTypes, Schema}
|
||||
import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
|
||||
import org.apache.avro.util.Utf8
|
||||
import org.apache.hudi.DataSourceWriteOptions._
|
||||
@@ -33,9 +35,9 @@ import org.apache.hudi.config.HoodieWriteConfig
|
||||
import org.apache.hudi.io.HoodieWriteHandle
|
||||
import org.apache.hudi.sql.IExpressionEvaluator
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Assignment
|
||||
import org.apache.spark.sql.hudi.SerDeUtils
|
||||
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
|
||||
import org.apache.spark.sql.types.Decimal
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
@@ -84,7 +86,8 @@ class ExpressionPayload(record: GenericRecord,
|
||||
var resultRecordOpt: HOption[IndexedRecord] = null
|
||||
|
||||
// Get the Evaluator for each condition and update assignments.
|
||||
val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString)
|
||||
initWriteSchemaIfNeed(properties)
|
||||
val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString, writeSchema)
|
||||
for ((conditionEvaluator, assignmentEvaluator) <- updateConditionAndAssignments
|
||||
if resultRecordOpt == null) {
|
||||
val conditionVal = evaluate(conditionEvaluator, joinSqlRecord).head.asInstanceOf[Boolean]
|
||||
@@ -92,7 +95,6 @@ class ExpressionPayload(record: GenericRecord,
|
||||
// to compute final record to update. We will return the first matched record.
|
||||
if (conditionVal) {
|
||||
val results = evaluate(assignmentEvaluator, joinSqlRecord)
|
||||
initWriteSchemaIfNeed(properties)
|
||||
val resultRecord = convertToRecord(results, writeSchema)
|
||||
|
||||
if (needUpdatingPersistedRecord(targetRecord, resultRecord, properties)) {
|
||||
@@ -108,7 +110,7 @@ class ExpressionPayload(record: GenericRecord,
|
||||
// Process delete
|
||||
val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
|
||||
if (deleteConditionText != null) {
|
||||
val deleteCondition = getEvaluator(deleteConditionText.toString).head._1
|
||||
val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
|
||||
val deleteConditionVal = evaluate(deleteCondition, joinSqlRecord).head.asInstanceOf[Boolean]
|
||||
if (deleteConditionVal) {
|
||||
resultRecordOpt = HOption.empty()
|
||||
@@ -134,8 +136,9 @@ class ExpressionPayload(record: GenericRecord,
|
||||
// Process insert
|
||||
val sqlTypedRecord = new SqlTypedRecord(incomingRecord)
|
||||
// Get the evaluator for each condition and insert assignment.
|
||||
initWriteSchemaIfNeed(properties)
|
||||
val insertConditionAndAssignments =
|
||||
ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText.toString)
|
||||
ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText.toString, writeSchema)
|
||||
var resultRecordOpt: HOption[IndexedRecord] = null
|
||||
for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
|
||||
if resultRecordOpt == null) {
|
||||
@@ -144,7 +147,6 @@ class ExpressionPayload(record: GenericRecord,
|
||||
// result record. We will return the first matched record.
|
||||
if (conditionVal) {
|
||||
val results = evaluate(assignmentEvaluator, sqlTypedRecord)
|
||||
initWriteSchemaIfNeed(properties)
|
||||
resultRecordOpt = HOption.of(convertToRecord(results, writeSchema))
|
||||
}
|
||||
}
|
||||
@@ -153,7 +155,7 @@ class ExpressionPayload(record: GenericRecord,
|
||||
if (resultRecordOpt == null && isMORTable(properties)) {
|
||||
val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
|
||||
if (deleteConditionText != null) {
|
||||
val deleteCondition = getEvaluator(deleteConditionText.toString).head._1
|
||||
val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
|
||||
val deleteConditionVal = evaluate(deleteCondition, sqlTypedRecord).head.asInstanceOf[Boolean]
|
||||
if (deleteConditionVal) {
|
||||
resultRecordOpt = HOption.empty()
|
||||
@@ -269,19 +271,19 @@ object ExpressionPayload {
|
||||
* @return
|
||||
*/
|
||||
def getEvaluator(
|
||||
serializedConditionAssignments: String): Map[IExpressionEvaluator, IExpressionEvaluator] = {
|
||||
serializedConditionAssignments: String, writeSchema: Schema): Map[IExpressionEvaluator, IExpressionEvaluator] = {
|
||||
cache.get(serializedConditionAssignments,
|
||||
new Callable[Map[IExpressionEvaluator, IExpressionEvaluator]] {
|
||||
|
||||
override def call(): Map[IExpressionEvaluator, IExpressionEvaluator] = {
|
||||
val serializedBytes = Base64.getDecoder.decode(serializedConditionAssignments)
|
||||
val conditionAssignments = SerDeUtils.toObject(serializedBytes)
|
||||
.asInstanceOf[Map[Expression, Seq[Assignment]]]
|
||||
.asInstanceOf[Map[Expression, Seq[Expression]]]
|
||||
// Do the CodeGen for condition expression and assignment expression
|
||||
conditionAssignments.map {
|
||||
case (condition, assignments) =>
|
||||
val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition))
|
||||
val assignmentEvaluator = StringConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments))
|
||||
val assignmentEvaluator = AvroTypeConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments), writeSchema)
|
||||
conditionEvaluator -> assignmentEvaluator
|
||||
}
|
||||
}
|
||||
@@ -289,17 +291,29 @@ object ExpressionPayload {
|
||||
}
|
||||
|
||||
/**
|
||||
* As the "baseEvaluator" returns "UTF8String" for the string type, which cannot be processed by
|
||||
* the Avro, The StringConvertEvaluator will convert the "UTF8String" to "Utf8".
|
||||
* An IExpressionEvaluator wrapping the base evaluator, which converts the result of the base evaluator
|
||||
* to the avro typed-value.
|
||||
*/
|
||||
case class StringConvertEvaluator(baseEvaluator: IExpressionEvaluator) extends IExpressionEvaluator {
|
||||
case class AvroTypeConvertEvaluator(baseEvaluator: IExpressionEvaluator, writeSchema: Schema) extends IExpressionEvaluator {
|
||||
private lazy val decimalConversions = new DecimalConversion()
|
||||
|
||||
/**
|
||||
* Convert the UTF8String to Utf8
|
||||
* Convert to the avro typed-value.
|
||||
* e.g. convert UTF8String -> Utf8, Decimal -> GenericFixed.
|
||||
*/
|
||||
override def eval(record: IndexedRecord): Array[AnyRef] = {
|
||||
baseEvaluator.eval(record).map {
|
||||
case s: UTF8String => new Utf8(s.toString)
|
||||
case o => o
|
||||
baseEvaluator.eval(record).zipWithIndex.map {
|
||||
case (s: UTF8String, _) => new Utf8(s.toString)
|
||||
case (d: Decimal, i) =>
|
||||
val schema = writeSchema.getFields.get(i).schema()
|
||||
val fixedSchema = if (schema.getType == Type.UNION) {
|
||||
schema.getTypes.asScala.filter(s => s.getType != Type.NULL).head
|
||||
} else {
|
||||
schema
|
||||
}
|
||||
decimalConversions.toFixed(d.toJavaBigDecimal, fixedSchema
|
||||
, LogicalTypes.decimal(d.precision, d.scale))
|
||||
case (o, _) => o
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ import org.apache.avro.generic.{GenericFixed, IndexedRecord}
|
||||
import org.apache.avro.util.Utf8
|
||||
import org.apache.avro.{LogicalTypes, Schema}
|
||||
import org.apache.spark.sql.avro.{IncompatibleSchemaException, SchemaConverters}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
||||
Reference in New Issue
Block a user