/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.uber.hoodie

import java.nio.ByteBuffer
import java.sql.{Date, Timestamp}
import java.util

import com.databricks.spark.avro.SchemaConverters
import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
import org.apache.avro.Schema.Type._
import org.apache.avro.generic.GenericData.{Fixed, Record}
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

import scala.collection.JavaConverters._

object AvroConversionUtils {

  def createRdd(df: DataFrame, structName: String, recordNamespace: String): RDD[GenericRecord] = {
    val dataType = df.schema
    val encoder = RowEncoder.apply(dataType).resolveAndBind()
    df.queryExecution.toRdd
      .map(encoder.fromRow)
      .mapPartitions { records =>
        if (records.isEmpty) {
          Iterator.empty
        } else {
          val converter = createConverterToAvro(dataType, structName, recordNamespace)
          records.map { x => converter(x).asInstanceOf[GenericRecord] }
        }
      }
  }

  def createDataFrame(rdd: RDD[GenericRecord], schemaStr: String, ss: SparkSession): Dataset[Row] = {
    if (rdd.isEmpty()) {
      ss.emptyDataFrame
    } else {
      ss.createDataFrame(rdd.mapPartitions { records =>
        if (records.isEmpty) {
          Iterator.empty
        } else {
          val schema = Schema.parse(schemaStr)
          val dataType = convertAvroSchemaToStructType(schema)
          val converter = createConverterToRow(schema, dataType)
          records.map { x => converter(x).asInstanceOf[Row] }
        }
      }, convertAvroSchemaToStructType(Schema.parse(schemaStr)))
    }
  }

  def getNewRecordNamespace(elementDataType: DataType,
                            currentRecordNamespace: String,
                            elementName: String): String = {
    elementDataType match {
      case StructType(_) => s"$currentRecordNamespace.$elementName"
      case _ => currentRecordNamespace
    }
  }
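  // A minimal round-trip sketch of how createRdd and createDataFrame compose. The struct
  // name, namespace, and DataFrame below are illustrative assumptions, not part of this
  // utility:
  //
  //   val df: DataFrame = spark.range(1).selectExpr("id", "cast(id as string) as name")
  //   val avroRdd: RDD[GenericRecord] =
  //     AvroConversionUtils.createRdd(df, "trip", "com.uber.hoodie")
  //   val schemaStr = AvroConversionUtils
  //     .convertStructTypeToAvroSchema(df.schema, "trip", "com.uber.hoodie").toString
  //   val roundTripped: Dataset[Row] =
  //     AvroConversionUtils.createDataFrame(avroRdd, schemaStr, spark)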
  /**
    * NOTE : This part of code is copied from com.databricks.spark.avro.SchemaConverters.scala
    * (133:310) (spark-avro)
    *
    * Returns a converter function to convert a row in avro format to a GenericRow of catalyst.
    *
    * @param sourceAvroSchema Source schema before conversion, inferred from the avro file or
    *                         passed in by the user.
    * @param targetSqlType    Target catalyst sql type after the conversion.
    * @return returns a converter function to convert row in avro format to GenericRow of catalyst.
    */
  def createConverterToRow(sourceAvroSchema: Schema,
                           targetSqlType: DataType): AnyRef => AnyRef = {

    def createConverter(avroSchema: Schema,
                        sqlType: DataType,
                        path: List[String]): AnyRef => AnyRef = {
      val avroType = avroSchema.getType
      (sqlType, avroType) match {
        // Avro strings are in Utf8, so we have to call toString on them
        case (StringType, STRING) | (StringType, ENUM) =>
          (item: AnyRef) => if (item == null) null else item.toString
        case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) |
             (FloatType, FLOAT) | (LongType, LONG) =>
          identity
        // Byte arrays are reused by avro, so we have to make a copy of them.
        case (BinaryType, FIXED) =>
          (item: AnyRef) =>
            if (item == null) {
              null
            } else {
              item.asInstanceOf[Fixed].bytes().clone()
            }
        case (BinaryType, BYTES) =>
          (item: AnyRef) =>
            if (item == null) {
              null
            } else {
              val byteBuffer = item.asInstanceOf[ByteBuffer]
              val bytes = new Array[Byte](byteBuffer.remaining)
              byteBuffer.get(bytes)
              bytes
            }
        case (struct: StructType, RECORD) =>
          val length = struct.fields.length
          val converters = new Array[AnyRef => AnyRef](length)
          val avroFieldIndexes = new Array[Int](length)
          var i = 0
          while (i < length) {
            val sqlField = struct.fields(i)
            val avroField = avroSchema.getField(sqlField.name)
            if (avroField != null) {
              val converter =
                createConverter(avroField.schema(), sqlField.dataType, path :+ sqlField.name)
              converters(i) = converter
              avroFieldIndexes(i) = avroField.pos()
            } else if (!sqlField.nullable) {
              throw new IncompatibleSchemaException(
                s"Cannot find non-nullable field ${sqlField.name} at path ${path.mkString(".")} " +
                  "in Avro schema\n" +
                  s"Source Avro schema: $sourceAvroSchema.\n" +
                  s"Target Catalyst type: $targetSqlType")
            }
            i += 1
          }

          (item: AnyRef) => {
            if (item == null) {
              null
            } else {
              val record = item.asInstanceOf[GenericRecord]
              val result = new Array[Any](length)
              var i = 0
              while (i < converters.length) {
                if (converters(i) != null) {
                  val converter = converters(i)
                  result(i) = converter(record.get(avroFieldIndexes(i)))
                }
                i += 1
              }
              new GenericRow(result)
            }
          }
        case (arrayType: ArrayType, ARRAY) =>
          val elementConverter =
            createConverter(avroSchema.getElementType, arrayType.elementType, path)
          val allowsNull = arrayType.containsNull
          (item: AnyRef) => {
            if (item == null) {
              null
            } else {
              item.asInstanceOf[java.lang.Iterable[AnyRef]].asScala.map { element =>
                if (element == null && !allowsNull) {
                  throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " +
                    "allowed to be null")
                } else {
                  elementConverter(element)
                }
              }
            }
          }
        case (mapType: MapType, MAP) if mapType.keyType == StringType =>
          val valueConverter = createConverter(avroSchema.getValueType, mapType.valueType, path)
          val allowsNull = mapType.valueContainsNull
          (item: AnyRef) => {
            if (item == null) {
              null
            } else {
              item.asInstanceOf[java.util.Map[AnyRef, AnyRef]].asScala.map { x =>
                if (x._2 == null && !allowsNull) {
                  throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " +
                    "allowed to be null")
                } else {
                  (x._1.toString, valueConverter(x._2))
                }
              }.toMap
            }
          }
        case (sqlType, UNION) =>
          if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
            // A nullable union is unwrapped: drop NULL and convert the remaining type(s).
            val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
            if (remainingUnionTypes.size == 1) {
              createConverter(remainingUnionTypes.head, sqlType, path)
            } else {
              createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path)
            }
          } else avroSchema.getTypes.asScala.map(_.getType) match {
            case Seq(t1) =>
              createConverter(avroSchema.getTypes.get(0), sqlType, path)
            case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlType == LongType =>
              (item: AnyRef) => {
                item match {
                  case null => null
                  case l: java.lang.Long => l
                  case i: java.lang.Integer => new java.lang.Long(i.longValue())
                }
              }
            case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlType == DoubleType =>
              (item: AnyRef) => {
                item match {
                  case null => null
                  case d: java.lang.Double => d
                  case f: java.lang.Float => new java.lang.Double(f.doubleValue())
                }
              }
            case other =>
              sqlType match {
                case t: StructType if t.fields.length == avroSchema.getTypes.size =>
                  val fieldConverters = t.fields.zip(avroSchema.getTypes.asScala).map {
                    case (field, schema) =>
                      createConverter(schema, field.dataType, path :+ field.name)
                  }

                  (item: AnyRef) =>
                    if (item == null) {
                      null
                    } else {
                      val i = GenericData.get().resolveUnion(avroSchema, item)
                      val converted = new Array[Any](fieldConverters.length)
                      converted(i) = fieldConverters(i)(item)
                      new GenericRow(converted)
                    }
                case _ =>
                  throw new IncompatibleSchemaException(
                    s"Cannot convert Avro schema to catalyst type because schema at path " +
                      s"${path.mkString(".")} is not compatible " +
                      s"(avroType = $other, sqlType = $sqlType). \n" +
                      s"Source Avro schema: $sourceAvroSchema.\n" +
                      s"Target Catalyst type: $targetSqlType")
              }
          }
        case (left, right) =>
          // The tuple is (sqlType, avroType), so `left` is the Catalyst type and `right`
          // is the Avro type.
          throw new IncompatibleSchemaException(
            s"Cannot convert Avro schema to catalyst type because schema at path " +
              s"${path.mkString(".")} is not compatible (avroType = $right, sqlType = $left). \n" +
              s"Source Avro schema: $sourceAvroSchema.\n" +
              s"Target Catalyst type: $targetSqlType")
      }
    }

    createConverter(sourceAvroSchema, targetSqlType, List.empty[String])
  }
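  // For reference, a sketch of the union handling above on a hypothetical schema: a nullable
  // ["null", "long"] union is unwrapped to a plain LONG converter, while a ["int", "long"]
  // union is widened to LongType, so either branch lands in a single Catalyst column:
  //
  //   val schema = new Schema.Parser().parse(
  //     """{"type": "record", "name": "r", "fields": [
  //       |  {"name": "count", "type": ["null", "long"], "default": null}
  //       |]}""".stripMargin)
  //   val converter = AvroConversionUtils.createConverterToRow(
  //     schema, AvroConversionUtils.convertAvroSchemaToStructType(schema))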
  def createConverterToAvro(dataType: DataType,
                            structName: String,
                            recordNamespace: String): Any => Any = {
    dataType match {
      case BinaryType =>
        (item: Any) =>
          item match {
            case null => null
            case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
          }
      case IntegerType | LongType | FloatType | DoubleType | StringType | BooleanType =>
        identity
      case ByteType =>
        (item: Any) => if (item == null) null else item.asInstanceOf[Byte].intValue
      case ShortType =>
        (item: Any) => if (item == null) null else item.asInstanceOf[Short].intValue
      case _: DecimalType =>
        (item: Any) => if (item == null) null else item.toString
      case TimestampType =>
        (item: Any) => if (item == null) null else item.asInstanceOf[Timestamp].getTime
      case DateType =>
        (item: Any) => if (item == null) null else item.asInstanceOf[Date].getTime
      case ArrayType(elementType, _) =>
        val elementConverter = createConverterToAvro(
          elementType,
          structName,
          getNewRecordNamespace(elementType, recordNamespace, structName))
        (item: Any) => {
          if (item == null) {
            null
          } else {
            val sourceArray = item.asInstanceOf[Seq[Any]]
            val sourceArraySize = sourceArray.size
            val targetList = new util.ArrayList[Any](sourceArraySize)
            var idx = 0
            while (idx < sourceArraySize) {
              targetList.add(elementConverter(sourceArray(idx)))
              idx += 1
            }
            targetList
          }
        }
      case MapType(StringType, valueType, _) =>
        val valueConverter = createConverterToAvro(
          valueType,
          structName,
          getNewRecordNamespace(valueType, recordNamespace, structName))
        (item: Any) => {
          if (item == null) {
            null
          } else {
            val javaMap = new util.HashMap[String, Any]()
            item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
              javaMap.put(key, valueConverter(value))
            }
            javaMap
          }
        }
      case structType: StructType =>
        val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
        val schema: Schema = SchemaConverters.convertStructToAvro(
          structType, builder, recordNamespace)
        val fieldConverters = structType.fields.map(field =>
          createConverterToAvro(
            field.dataType,
            field.name,
            getNewRecordNamespace(field.dataType, recordNamespace, field.name)))
        (item: Any) => {
          if (item == null) {
            null
          } else {
            val record = new Record(schema)
            val convertersIterator = fieldConverters.iterator
            val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator
            val rowIterator = item.asInstanceOf[Row].toSeq.iterator
            while (convertersIterator.hasNext) {
              val converter = convertersIterator.next()
              record.put(fieldNamesIterator.next(), converter(rowIterator.next()))
            }
            record
          }
        }
    }
  }
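  // Namespace nesting sketch (names are illustrative): for a row schema with a nested struct
  // field "address", the nested Avro record is emitted under "<recordNamespace>.address",
  // which keeps record names unique when the same struct appears at different depths:
  //
  //   val nested = StructType(Seq(StructField("city", StringType)))
  //   AvroConversionUtils.getNewRecordNamespace(nested, "com.uber.hoodie", "address")
  //   // => "com.uber.hoodie.address"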
  def convertStructTypeToAvroSchema(structType: StructType,
                                    structName: String,
                                    recordNamespace: String): Schema = {
    val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
    SchemaConverters.convertStructToAvro(structType, builder, recordNamespace)
  }

  def convertAvroSchemaToStructType(avroSchema: Schema): StructType = {
    SchemaConverters.toSqlType(avroSchema).dataType.asInstanceOf[StructType]
  }
}
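// Schema round-trip sketch for the two converters above (struct name and namespace are
// illustrative assumptions); converting a StructType to Avro and back should recover the
// original struct:
//
//   val struct = StructType(Seq(StructField("id", LongType, nullable = false)))
//   val avro = AvroConversionUtils.convertStructTypeToAvroSchema(struct, "rec", "com.uber.hoodie")
//   val back = AvroConversionUtils.convertAvroSchemaToStructType(avro)  // == struct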