[HUDI-2131] Exception Throw Out When MergeInto With Decimal Type Field (#3224)

pengzhiwei
2021-07-05 22:28:57 +08:00
committed by GitHub
parent e6ee7bdb51
commit 287d2dd79c
5 changed files with 85 additions and 27 deletions

View File

@@ -21,7 +21,9 @@ import java.util.{Base64, Properties}
import java.util.concurrent.Callable
import scala.collection.JavaConverters._
import com.google.common.cache.CacheBuilder
-import org.apache.avro.Schema
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.Schema.Type
+import org.apache.avro.{LogicalTypes, Schema}
import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
import org.apache.avro.util.Utf8
import org.apache.hudi.DataSourceWriteOptions._
@@ -33,9 +35,9 @@ import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.io.HoodieWriteHandle
import org.apache.hudi.sql.IExpressionEvaluator
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.hudi.SerDeUtils
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
+import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String
import scala.collection.mutable.ArrayBuffer
@@ -84,7 +86,8 @@ class ExpressionPayload(record: GenericRecord,
var resultRecordOpt: HOption[IndexedRecord] = null
// Get the Evaluator for each condition and update assignments.
-val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString)
+initWriteSchemaIfNeed(properties)
+val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString, writeSchema)
for ((conditionEvaluator, assignmentEvaluator) <- updateConditionAndAssignments
if resultRecordOpt == null) {
val conditionVal = evaluate(conditionEvaluator, joinSqlRecord).head.asInstanceOf[Boolean]
@@ -92,7 +95,6 @@ class ExpressionPayload(record: GenericRecord,
// to compute final record to update. We will return the first matched record.
if (conditionVal) {
val results = evaluate(assignmentEvaluator, joinSqlRecord)
-initWriteSchemaIfNeed(properties)
val resultRecord = convertToRecord(results, writeSchema)
if (needUpdatingPersistedRecord(targetRecord, resultRecord, properties)) {
@@ -108,7 +110,7 @@ class ExpressionPayload(record: GenericRecord,
// Process delete
val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
if (deleteConditionText != null) {
-val deleteCondition = getEvaluator(deleteConditionText.toString).head._1
+val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
val deleteConditionVal = evaluate(deleteCondition, joinSqlRecord).head.asInstanceOf[Boolean]
if (deleteConditionVal) {
resultRecordOpt = HOption.empty()
@@ -134,8 +136,9 @@ class ExpressionPayload(record: GenericRecord,
// Process insert
val sqlTypedRecord = new SqlTypedRecord(incomingRecord)
// Get the evaluator for each condition and insert assignment.
+initWriteSchemaIfNeed(properties)
val insertConditionAndAssignments =
-ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText.toString)
+ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText.toString, writeSchema)
var resultRecordOpt: HOption[IndexedRecord] = null
for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
if resultRecordOpt == null) {
@@ -144,7 +147,6 @@ class ExpressionPayload(record: GenericRecord,
// result record. We will return the first matched record.
if (conditionVal) {
val results = evaluate(assignmentEvaluator, sqlTypedRecord)
-initWriteSchemaIfNeed(properties)
resultRecordOpt = HOption.of(convertToRecord(results, writeSchema))
}
}
@@ -153,7 +155,7 @@ class ExpressionPayload(record: GenericRecord,
if (resultRecordOpt == null && isMORTable(properties)) {
val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
if (deleteConditionText != null) {
-val deleteCondition = getEvaluator(deleteConditionText.toString).head._1
+val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
val deleteConditionVal = evaluate(deleteCondition, sqlTypedRecord).head.asInstanceOf[Boolean]
if (deleteConditionVal) {
resultRecordOpt = HOption.empty()
@@ -269,19 +271,19 @@ object ExpressionPayload {
* @return
*/
def getEvaluator(
-serializedConditionAssignments: String): Map[IExpressionEvaluator, IExpressionEvaluator] = {
+serializedConditionAssignments: String, writeSchema: Schema): Map[IExpressionEvaluator, IExpressionEvaluator] = {
cache.get(serializedConditionAssignments,
new Callable[Map[IExpressionEvaluator, IExpressionEvaluator]] {
override def call(): Map[IExpressionEvaluator, IExpressionEvaluator] = {
val serializedBytes = Base64.getDecoder.decode(serializedConditionAssignments)
val conditionAssignments = SerDeUtils.toObject(serializedBytes)
-.asInstanceOf[Map[Expression, Seq[Assignment]]]
+.asInstanceOf[Map[Expression, Seq[Expression]]]
// Do the CodeGen for condition expression and assignment expression
conditionAssignments.map {
case (condition, assignments) =>
val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition))
-val assignmentEvaluator = StringConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments))
+val assignmentEvaluator = AvroTypeConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments), writeSchema)
conditionEvaluator -> assignmentEvaluator
}
}
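
The Guava cache above memoizes the compiled evaluators keyed by the serialized condition/assignment text, so code generation runs once per distinct MERGE statement rather than once per record. A minimal standalone sketch of that memoization pattern, with a plain String standing in for the compiled IExpressionEvaluator (illustrative only, not Hudi code):

import java.util.concurrent.Callable
import com.google.common.cache.{Cache, CacheBuilder}

object EvaluatorCacheSketch {
  // The serialized payload is the cache key; the Callable runs only on a miss.
  private val cache: Cache[String, String] =
    CacheBuilder.newBuilder().maximumSize(1024).build[String, String]()

  def getOrCompile(serialized: String): String =
    cache.get(serialized, new Callable[String] {
      // Stand-in for the expensive ExpressionCodeGen.doCodeGen(...) step.
      override def call(): String = s"compiled:$serialized"
    })

  def main(args: Array[String]): Unit = {
    println(getOrCompile("cond")) // compiles on first access
    println(getOrCompile("cond")) // served from the cache afterwards
  }
}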
@@ -289,17 +291,29 @@ object ExpressionPayload {
}
/**
* As the "baseEvaluator" return "UTF8String" for the string type which cannot be process by
* the Avro, The StringConvertEvaluator will convert the "UTF8String" to "Utf8".
* A IExpressionEvaluator wrapped the base evaluator which convert the result of the base evaluator
* to the avro typed-value.
*/
-case class StringConvertEvaluator(baseEvaluator: IExpressionEvaluator) extends IExpressionEvaluator {
+case class AvroTypeConvertEvaluator(baseEvaluator: IExpressionEvaluator, writeSchema: Schema) extends IExpressionEvaluator {
+private lazy val decimalConversions = new DecimalConversion()
/**
-* Convert the UTF8String to Utf8
+* Convert to the Avro-typed value,
+* e.g. UTF8String -> Utf8, Decimal -> GenericFixed.
*/
override def eval(record: IndexedRecord): Array[AnyRef] = {
-baseEvaluator.eval(record).map {
-case s: UTF8String => new Utf8(s.toString)
-case o => o
+baseEvaluator.eval(record).zipWithIndex.map {
+case (s: UTF8String, _) => new Utf8(s.toString)
+case (d: Decimal, i) =>
+val schema = writeSchema.getFields.get(i).schema()
+val fixedSchema = if (schema.getType == Type.UNION) {
+schema.getTypes.asScala.filter(s => s.getType != Type.NULL).head
+} else {
+schema
+}
+decimalConversions.toFixed(d.toJavaBigDecimal, fixedSchema,
+LogicalTypes.decimal(d.precision, d.scale))
+case (o, _) => o
}
}
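
For context on the conversion above: Spark expression evaluation yields org.apache.spark.sql.types.Decimal values, while an Avro writer for a fixed schema carrying a decimal logical type expects a GenericFixed, which is consistent with the exception this patch fixes. Note also the union handling: a nullable column appears in the write schema as a union of null and the fixed type, so the evaluator first picks the non-null branch. A self-contained sketch of the conversion itself (schema name and size are made up for illustration):

import java.math.BigDecimal
import org.apache.avro.{LogicalTypes, Schema}
import org.apache.avro.Conversions.DecimalConversion

object DecimalToFixedSketch {
  def main(args: Array[String]): Unit = {
    // decimal(5,2) stored as a 3-byte fixed: 3 bytes hold an unscaled value
    // of up to 6 decimal digits, enough for precision 5.
    val fixedSchema: Schema =
      LogicalTypes.decimal(5, 2).addToSchema(
        Schema.createFixed("value_fixed", null, "sketch", 3))

    // toFixed encodes the unscaled value big-endian into the fixed-size bytes,
    // the representation the Avro writer expects for this schema.
    val fixed = new DecimalConversion().toFixed(
      new BigDecimal("10.11"), fixedSchema, LogicalTypes.decimal(5, 2))

    println(fixed.bytes().length) // 3
  }
}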

View File

@@ -27,7 +27,6 @@ import org.apache.avro.generic.{GenericFixed, IndexedRecord}
import org.apache.avro.util.Utf8
import org.apache.avro.{LogicalTypes, Schema}
import org.apache.spark.sql.avro.{IncompatibleSchemaException, SchemaConverters}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

View File

@@ -86,4 +86,11 @@ class TestHoodieSqlBase extends FunSuite with BeforeAndAfterAll {
assertResult(errorMsg)(e.getMessage)
}
}
+
+protected def removeQuotes(value: Any): Any = {
+value match {
+case s: String => s.stripPrefix("'").stripSuffix("'")
+case _ => value
+}
+}
}
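
The helper is promoted to the shared test base so the new MergeInto test below can reuse it when comparing quoted SQL literals against query results. A quick standalone check of its behavior (a sketch, not part of the patch):

object RemoveQuotesCheck {
  // Same logic as the helper above: strip one pair of surrounding single
  // quotes from a SQL string literal; leave every other value untouched.
  def removeQuotes(value: Any): Any = value match {
    case s: String => s.stripPrefix("'").stripSuffix("'")
    case _ => value
  }

  def main(args: Array[String]): Unit = {
    assert(removeQuotes("'a1'") == "a1") // quoted literal -> bare value
    assert(removeQuotes(10) == 10)       // non-string passes through
  }
}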

View File

@@ -277,11 +277,4 @@ class TestInsertTable extends TestHoodieSqlBase {
" count: 3columns: (1,a1,10)"
)
}
-
-private def removeQuotes(value: Any): Any = {
-value match {
-case s: String => s.stripPrefix("'").stripSuffix("'")
-case _=> value
-}
-}
}

View File

@@ -20,7 +20,6 @@ package org.apache.spark.sql.hudi
import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers}
import org.apache.hudi.common.fs.FSUtils
class TestMergeIntoTable extends TestHoodieSqlBase {
test("Test MergeInto Basic") {
@@ -687,4 +686,50 @@ class TestMergeIntoTable extends TestHoodieSqlBase {
}
}
}
test("Test MereInto With All Kinds Of DataType") {
withTempDir { tmp =>
val dataAndTypes = Seq(
("string", "'a1'"),
("int", "10"),
("bigint", "10"),
("double", "10.0"),
("float", "10.0"),
("decimal(5,2)", "10.11"),
("decimal(5,0)", "10"),
("timestamp", "'2021-05-20 00:00:00'"),
("date", "'2021-05-20'")
)
dataAndTypes.foreach { case (dataType, dataValue) =>
val tableName = generateTableName
spark.sql(
s"""
|create table $tableName (
| id int,
| name string,
| value $dataType,
| ts long
|) using hudi
| location '${tmp.getCanonicalPath}/$tableName'
| options (
| primaryKey ='id',
| preCombineField = 'ts'
| )
""".stripMargin)
spark.sql(
s"""
|merge into $tableName h0
|using (
| select 1 as id, 'a1' as name, cast($dataValue as $dataType) as value, 1000 as ts
| ) s0
| on h0.id = s0.id
| when not matched then insert *
|""".stripMargin)
checkAnswer(s"select id, name, cast(value as string), ts from $tableName")(
Seq(1, "a1", removeQuotes(dataValue), 1000)
)
}
}
}
}