1
0

[HUDI-4285] add ByteBuffer#rewind after ByteBuffer#get in AvroDeserializer (#5907)

* [HUDI-4285] add ByteBuffer#rewind after ByteBuffer#get in AvroDeserializer

* add ut

Co-authored-by: wangzixuan.wzxuan <wangzixuan.wzxuan@bytedance.com>
This commit is contained in:
komao
2022-06-30 20:48:50 +08:00
committed by GitHub
parent cdaaa3c4c7
commit 8547899a39
4 changed files with 62 additions and 1 deletions

View File

@@ -18,8 +18,13 @@
package org.apache.hudi
import java.nio.ByteBuffer
import java.util.Objects
import org.apache.avro.Schema
import org.apache.spark.sql.types.{DataTypes, StructType, StringType, ArrayType}
import org.apache.avro.generic.GenericData
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DataTypes, MapType, StringType, StructField, StructType}
import org.scalatest.{FunSuite, Matchers}
class TestAvroConversionUtils extends FunSuite with Matchers {
@@ -377,4 +382,54 @@ class TestAvroConversionUtils extends FunSuite with Matchers {
assert(avroSchema.equals(expectedAvroSchema))
}
test("test converter with binary") {
  // Schema with a single nullable bytes column; the union with "null" makes it optional.
  val avroSchema = new Schema.Parser().parse("{\"type\":\"record\",\"name\":\"h0_record\",\"namespace\":\"hoodie.h0\",\"fields\""
    + ":[{\"name\":\"col9\",\"type\":[\"null\",\"bytes\"],\"default\":null}]}")
  val sparkSchema = StructType(List(StructField("col9", BinaryType, nullable = true)))

  // Build an Avro record carrying a ByteBuffer payload for the binary column.
  val record = new GenericData.Record(avroSchema)
  val payload = ByteBuffer.wrap(Array[Byte](97, 48, 53))
  record.put("col9", payload)

  // Convert the same record twice; without ByteBuffer#rewind after #get the
  // second conversion would read an exhausted buffer and produce empty bytes.
  val converter = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema)
  val firstRow = converter.apply(record).get
  val secondRow = converter.apply(record).get
  internalRowCompare(firstRow, secondRow, sparkSchema)
}
/**
 * Recursively asserts structural equality of two Catalyst values according to
 * the given Spark SQL data type, descending through structs, arrays and maps.
 *
 * @param expected the expected Catalyst value (InternalRow / ArrayData / MapData / scalar)
 * @param actual   the actual Catalyst value to compare against
 * @param schema   the Spark SQL type that tells us how to interpret both values
 * @throws AssertionError on the first mismatch found
 *
 * Fix: the previous version called `.toString` / `.sameElements` on possibly-null
 * values (both when building the failure message and when both sides were null),
 * turning a comparison failure into a NullPointerException. Null handling is now
 * done inline and messages use string interpolation, which renders null safely.
 */
private def internalRowCompare(expected: Any, actual: Any, schema: DataType): Unit = {
  schema match {
    case StructType(fields) =>
      val expectedRow = expected.asInstanceOf[InternalRow]
      val actualRow = actual.asInstanceOf[InternalRow]
      // Compare field by field, recursing with each field's own type.
      fields.zipWithIndex.foreach { case (field, i) =>
        internalRowCompare(expectedRow.get(i, field.dataType), actualRow.get(i, field.dataType), field.dataType)
      }
    case ArrayType(elementType, _) =>
      val expectedArray = expected.asInstanceOf[ArrayData].toSeq[Any](elementType)
      val actualArray = actual.asInstanceOf[ArrayData].toSeq[Any](elementType)
      if (expectedArray.size != actualArray.size) {
        // Report the sizes instead of throwing a bare AssertionError.
        throw new AssertionError(s"array size mismatch: expected ${expectedArray.size}, actual ${actualArray.size}")
      } else {
        expectedArray.zip(actualArray).foreach { case (e1, e2) => internalRowCompare(e1, e2, elementType) }
      }
    case MapType(keyType, valueType, _) =>
      // Compare key arrays and value arrays independently; MapData exposes them
      // as parallel ArrayData, so reuse the ArrayType branch for each.
      val expectedKeyArray = expected.asInstanceOf[MapData].keyArray()
      val expectedValueArray = expected.asInstanceOf[MapData].valueArray()
      val actualKeyArray = actual.asInstanceOf[MapData].keyArray()
      val actualValueArray = actual.asInstanceOf[MapData].valueArray()
      internalRowCompare(expectedKeyArray, actualKeyArray, ArrayType(keyType))
      internalRowCompare(expectedValueArray, actualValueArray, ArrayType(valueType))
    case StringType =>
      // Strings (UTF8String vs java String) are compared via toString; equal
      // only when both are null or both render to the same text.
      val same = (expected == null && actual == null) ||
        (expected != null && actual != null && expected.toString.equals(actual.toString))
      if (!same) {
        throw new AssertionError(s"$expected is not equals $actual")
      }
    case BinaryType =>
      // Byte arrays need element-wise comparison; reference equality is useless.
      val same = (expected == null && actual == null) ||
        (expected != null && actual != null &&
          expected.asInstanceOf[Array[Byte]].sameElements(actual.asInstanceOf[Array[Byte]]))
      if (!same) {
        throw new AssertionError(s"$expected is not equals $actual")
      }
    case _ =>
      // Objects.equals is already null-safe for the remaining scalar types.
      if (!Objects.equals(expected, actual)) {
        throw new AssertionError(s"$expected is not equals $actual")
      }
  }
}
/**
 * Returns true when exactly one of the two values is null, i.e. a null/non-null
 * mismatch. Returns false when both are null or both are non-null.
 *
 * Fix: the second disjunct previously duplicated the first
 * (`(left == null && right != null)` appeared twice), so a non-null `left`
 * paired with a null `right` was never reported as a mismatch.
 */
private def checkNull(left: Any, right: Any): Boolean = {
  (left == null && right != null) || (left != null && right == null)
}
}

View File

@@ -146,6 +146,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
case b: ByteBuffer =>
val bytes = new Array[Byte](b.remaining)
b.get(bytes)
// Do not forget to reset the position
b.rewind()
bytes
case b: Array[Byte] => b
case other => throw new RuntimeException(s"$other is not a valid avro binary.")

View File

@@ -167,6 +167,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema,
case b: ByteBuffer =>
val bytes = new Array[Byte](b.remaining)
b.get(bytes)
// Do not forget to reset the position
b.rewind()
bytes
case b: Array[Byte] => b
case other => throw new RuntimeException(s"$other is not a valid avro binary.")

View File

@@ -181,6 +181,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema,
case b: ByteBuffer =>
val bytes = new Array[Byte](b.remaining)
b.get(bytes)
// Do not forget to reset the position
b.rewind()
bytes
case b: Array[Byte] => b
case other =>