From 8547899a39168c8399d32ee7ee22f35dfe3f7c84 Mon Sep 17 00:00:00 2001 From: komao Date: Thu, 30 Jun 2022 20:48:50 +0800 Subject: [PATCH] =?UTF-8?q?[HUDI-4285]=20add=20ByteBuffer#rewind=20after?= =?UTF-8?q?=20ByteBuffer#get=20in=20AvroDeseria=E2=80=A6=20(#5907)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [HUDI-4285] add ByteBuffer#rewind after ByteBuffer#get in AvroDeserializer * add ut Co-authored-by: wangzixuan.wzxuan --- .../apache/hudi/TestAvroConversionUtils.scala | 57 ++++++++++++++++++- .../spark/sql/avro/AvroDeserializer.scala | 2 + .../spark/sql/avro/AvroDeserializer.scala | 2 + .../spark/sql/avro/AvroDeserializer.scala | 2 + 4 files changed, 62 insertions(+), 1 deletion(-) diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala index bacd44753..16df1f869 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala @@ -18,8 +18,13 @@ package org.apache.hudi +import java.nio.ByteBuffer +import java.util.Objects import org.apache.avro.Schema -import org.apache.spark.sql.types.{DataTypes, StructType, StringType, ArrayType} +import org.apache.avro.generic.GenericData +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DataTypes, MapType, StringType, StructField, StructType} import org.scalatest.{FunSuite, Matchers} class TestAvroConversionUtils extends FunSuite with Matchers { @@ -377,4 +382,54 @@ class TestAvroConversionUtils extends FunSuite with Matchers { assert(avroSchema.equals(expectedAvroSchema)) } + + test("test converter with binary") { + val avroSchema = new Schema.Parser().parse("{\"type\":\"record\",\"name\":\"h0_record\",\"namespace\":\"hoodie.h0\",\"fields\"" + + ":[{\"name\":\"col9\",\"type\":[\"null\",\"bytes\"],\"default\":null}]}") + val sparkSchema = StructType(List(StructField("col9", BinaryType, nullable = true))) + // create a test record with avroSchema + val avroRecord = new GenericData.Record(avroSchema) + val bb = ByteBuffer.wrap(Array[Byte](97, 48, 53)) + avroRecord.put("col9", bb) + val row1 = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema).apply(avroRecord).get + val row2 = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema).apply(avroRecord).get + internalRowCompare(row1, row2, sparkSchema) + } + + private def internalRowCompare(expected: Any, actual: Any, schema: DataType): Unit = { + schema match { + case StructType(fields) => + val expectedRow = expected.asInstanceOf[InternalRow] + val actualRow = actual.asInstanceOf[InternalRow] + fields.zipWithIndex.foreach { case (field, i) => internalRowCompare(expectedRow.get(i, field.dataType), actualRow.get(i, field.dataType), field.dataType) } + case ArrayType(elementType, _) => + val expectedArray = expected.asInstanceOf[ArrayData].toSeq[Any](elementType) + val actualArray = actual.asInstanceOf[ArrayData].toSeq[Any](elementType) + if (expectedArray.size != actualArray.size) { + throw new AssertionError() + } else { + expectedArray.zip(actualArray).foreach { case (e1, e2) => internalRowCompare(e1, e2, elementType) } + } + case MapType(keyType, valueType, _) => + val expectedKeyArray = expected.asInstanceOf[MapData].keyArray() + val expectedValueArray = expected.asInstanceOf[MapData].valueArray() + val actualKeyArray = actual.asInstanceOf[MapData].keyArray() + val actualValueArray = actual.asInstanceOf[MapData].valueArray() + internalRowCompare(expectedKeyArray, actualKeyArray, ArrayType(keyType)) + internalRowCompare(expectedValueArray, actualValueArray, ArrayType(valueType)) + case StringType => if (checkNull(expected, actual) || !expected.toString.equals(actual.toString)) { + throw new AssertionError(String.format("%s is not equals %s", expected.toString, actual.toString)) + } + case BinaryType => if (checkNull(expected, actual) || !expected.asInstanceOf[Array[Byte]].sameElements(actual.asInstanceOf[Array[Byte]])) { + throw new AssertionError(String.format("%s is not equals %s", expected.toString, actual.toString)) + } + case _ => if (!Objects.equals(expected, actual)) { + throw new AssertionError(String.format("%s is not equals %s", expected.toString, actual.toString)) + } + } + } + + private def checkNull(left: Any, right: Any): Boolean = { + (left == null && right != null) || (left == null && right != null) + } } diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 2e0946f1e..385577dd3 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -146,6 +146,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { case b: ByteBuffer => val bytes = new Array[Byte](b.remaining) b.get(bytes) + // Do not forget to reset the position + b.rewind() bytes case b: Array[Byte] => b case other => throw new RuntimeException(s"$other is not a valid avro binary.") diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 717df0f40..5fb6d907b 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -167,6 +167,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema, case b: ByteBuffer => val bytes = new Array[Byte](b.remaining) b.get(bytes) + // Do not forget to reset the position + b.rewind() bytes case b: Array[Byte] => b case other => throw new RuntimeException(s"$other is not a valid avro binary.") diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index ef9b59092..0b6093307 100644 --- a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -181,6 +181,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema, case b: ByteBuffer => val bytes = new Array[Byte](b.remaining) b.get(bytes) + // Do not forget to reset the position + b.rewind() bytes case b: Array[Byte] => b case other =>