354 lines
14 KiB
Scala
354 lines
14 KiB
Scala
/*
|
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
* contributor license agreements. See the NOTICE file distributed with
|
|
* this work for additional information regarding copyright ownership.
|
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
* (the "License"); you may not use this file except in compliance with
|
|
* the License. You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package com.uber.hoodie
|
|
|
|
import java.nio.ByteBuffer
|
|
import java.sql.{Date, Timestamp}
|
|
import java.util
|
|
|
|
import com.databricks.spark.avro.SchemaConverters
|
|
import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
|
|
import org.apache.avro.Schema.Type._
|
|
import org.apache.avro.generic.GenericData.{Fixed, Record}
|
|
import org.apache.avro.generic.{GenericData, GenericRecord}
|
|
import org.apache.avro.{Schema, SchemaBuilder}
|
|
import org.apache.spark.rdd.RDD
|
|
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
|
import org.apache.spark.sql.catalyst.expressions.GenericRow
|
|
import org.apache.spark.sql.types._
|
|
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
|
|
|
|
import scala.collection.JavaConverters._
|
|
|
|
|
|
object AvroConversionUtils {
|
|
|
|
/**
 * Converts every row of a DataFrame into an Avro GenericRecord.
 *
 * The Catalyst-to-Avro converter is built once per non-empty partition, from the
 * DataFrame's own schema, and then applied to each row of that partition.
 *
 * @param df              DataFrame whose rows are converted.
 * @param structName      Name given to the top-level Avro record.
 * @param recordNamespace Namespace given to the top-level Avro record.
 * @return an RDD holding one GenericRecord per input row.
 */
def createRdd(df: DataFrame, structName: String, recordNamespace: String): RDD[GenericRecord] = {
  val schema = df.schema
  val rowEncoder = RowEncoder.apply(schema).resolveAndBind()
  // Decode the internal rows back into external Row objects before converting them.
  val rows = df.queryExecution.toRdd.map(rowEncoder.fromRow)
  rows.mapPartitions { partition =>
    if (partition.hasNext) {
      // One converter per partition; it is reused for every row in it.
      val toAvro = createConverterToAvro(schema, structName, recordNamespace)
      partition.map(row => toAvro(row).asInstanceOf[GenericRecord])
    } else {
      Iterator.empty
    }
  }
}
|
|
|
|
/**
 * Builds a DataFrame from an RDD of Avro GenericRecords.
 *
 * The Avro schema is passed as a JSON string and parsed again inside every non-empty
 * partition, so only the string is referenced from the partition function.
 *
 * @param rdd       Records to convert.
 * @param schemaStr JSON text of the Avro schema describing the records.
 * @param ss        Session used to create the DataFrame.
 * @return a DataFrame with one Row per record, or `ss.emptyDataFrame` when `rdd` is empty.
 */
def createDataFrame(rdd: RDD[GenericRecord], schemaStr: String, ss: SparkSession): Dataset[Row] = {
  if (rdd.isEmpty()) {
    // NOTE(review): the empty result does not carry the columns derived from schemaStr;
    // kept as-is because callers may rely on it - confirm before changing.
    ss.emptyDataFrame
  } else {
    // Schema.Parser replaces the deprecated static Schema.parse(String).
    val structType = convertAvroSchemaToStructType(new Schema.Parser().parse(schemaStr))
    val rows: RDD[Row] = rdd.mapPartitions { records =>
      if (records.isEmpty) Iterator.empty
      else {
        // A fresh parser per partition: build the schema and converter once, reuse per record.
        val schema = new Schema.Parser().parse(schemaStr)
        val dataType = convertAvroSchemaToStructType(schema)
        val convertor = createConverterToRow(schema, dataType)
        records.map { x => convertor(x).asInstanceOf[Row] }
      }
    }
    ss.createDataFrame(rows, structType)
  }
}
|
|
|
|
/**
 * Computes the record namespace to use for a nested element.
 *
 * @param elementDataType        Catalyst type of the nested element.
 * @param currentRecordNamespace Namespace of the enclosing record.
 * @param elementName            Name of the nested element.
 * @return `currentRecordNamespace.elementName` when the element is a struct,
 *         otherwise `currentRecordNamespace` unchanged.
 */
def getNewRecordNamespace(elementDataType: DataType,
                          currentRecordNamespace: String,
                          elementName: String): String = {
  // Only struct elements get a namespace of their own.
  val isStruct = elementDataType.isInstanceOf[StructType]
  if (isStruct) s"$currentRecordNamespace.$elementName" else currentRecordNamespace
}
|
|
|
|
/**
 * NOTE : This part of code is copied from com.databricks.spark.avro.SchemaConverters.scala (133:310) (spark-avro)
 *
 * Returns a converter function to convert row in avro format to GenericRow of catalyst.
 *
 * The converter is assembled recursively from the (catalyst type, avro type) pair at each
 * level of the schema; a null input value is converted to null at every level.
 *
 * @param sourceAvroSchema Source schema before conversion, inferred from avro file or passed in
 *                         by user.
 * @param targetSqlType Target catalyst sql type after the conversion.
 * @return returns a converter function to convert row in avro format to GenericRow of catalyst.
 * @throws IncompatibleSchemaException if a non-nullable catalyst field is missing from the avro
 *                                     record schema, or an avro type cannot be mapped onto the
 *                                     requested catalyst type.
 */
def createConverterToRow(sourceAvroSchema: Schema,
                         targetSqlType: DataType): AnyRef => AnyRef = {

  // `path` is the chain of field names leading to the current schema node; it is only
  // used to build error messages.
  def createConverter(avroSchema: Schema,
                      sqlType: DataType, path: List[String]): AnyRef => AnyRef = {
    val avroType = avroSchema.getType
    (sqlType, avroType) match {
      // Avro strings are in Utf8, so we have to call toString on them
      case (StringType, STRING) | (StringType, ENUM) =>
        (item: AnyRef) => if (item == null) null else item.toString
      case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) |
           (FloatType, FLOAT) | (LongType, LONG) =>
        identity
      // Byte arrays are reused by avro, so we have to make a copy of them.
      case (BinaryType, FIXED) =>
        (item: AnyRef) =>
          if (item == null) {
            null
          } else {
            item.asInstanceOf[Fixed].bytes().clone()
          }
      case (BinaryType, BYTES) =>
        (item: AnyRef) =>
          if (item == null) {
            null
          } else {
            // Read through a duplicate so the source buffer's position is left untouched;
            // otherwise converting the same record a second time would yield no bytes.
            val byteBuffer = item.asInstanceOf[ByteBuffer].duplicate()
            val bytes = new Array[Byte](byteBuffer.remaining)
            byteBuffer.get(bytes)
            bytes
          }

      case (struct: StructType, RECORD) =>
        val length = struct.fields.length
        val converters = new Array[AnyRef => AnyRef](length)
        val avroFieldIndexes = new Array[Int](length)
        var i = 0
        while (i < length) {
          val sqlField = struct.fields(i)
          val avroField = avroSchema.getField(sqlField.name)
          if (avroField != null) {
            val converter = createConverter(avroField.schema(), sqlField.dataType,
              path :+ sqlField.name)
            converters(i) = converter
            avroFieldIndexes(i) = avroField.pos()
          } else if (!sqlField.nullable) {
            throw new IncompatibleSchemaException(
              s"Cannot find non-nullable field ${sqlField.name} at path ${path.mkString(".")} " +
                "in Avro schema\n" +
                s"Source Avro schema: $sourceAvroSchema.\n" +
                s"Target Catalyst type: $targetSqlType")
          }
          // A nullable field missing from the avro schema keeps a null converter slot.
          i += 1
        }

        (item: AnyRef) => {
          if (item == null) {
            null
          } else {
            val record = item.asInstanceOf[GenericRecord]

            val result = new Array[Any](length)
            var i = 0
            while (i < converters.length) {
              // Slots without a converter stay null in the resulting row.
              if (converters(i) != null) {
                val converter = converters(i)
                result(i) = converter(record.get(avroFieldIndexes(i)))
              }
              i += 1
            }
            new GenericRow(result)
          }
        }
      case (arrayType: ArrayType, ARRAY) =>
        val elementConverter = createConverter(avroSchema.getElementType, arrayType.elementType,
          path)
        val allowsNull = arrayType.containsNull
        (item: AnyRef) => {
          if (item == null) {
            null
          } else {
            item.asInstanceOf[java.lang.Iterable[AnyRef]].asScala.map { element =>
              if (element == null && !allowsNull) {
                throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " +
                  "allowed to be null")
              } else {
                elementConverter(element)
              }
            }
          }
        }
      case (mapType: MapType, MAP) if mapType.keyType == StringType =>
        val valueConverter = createConverter(avroSchema.getValueType, mapType.valueType, path)
        val allowsNull = mapType.valueContainsNull
        (item: AnyRef) => {
          if (item == null) {
            null
          } else {
            item.asInstanceOf[java.util.Map[AnyRef, AnyRef]].asScala.map { x =>
              if (x._2 == null && !allowsNull) {
                throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " +
                  "allowed to be null")
              } else {
                (x._1.toString, valueConverter(x._2))
              }
            }.toMap
          }
        }
      case (sqlType, UNION) =>
        if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
          // Drop the NULL branch and convert whatever remains of the union.
          val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
          if (remainingUnionTypes.size == 1) {
            createConverter(remainingUnionTypes.head, sqlType, path)
          } else {
            createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path)
          }
        } else avroSchema.getTypes.asScala.map(_.getType) match {
          case Seq(_) => createConverter(avroSchema.getTypes.get(0), sqlType, path)
          case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlType == LongType =>
            // An INT/LONG union is widened to a long.
            (item: AnyRef) => {
              item match {
                case null => null
                case l: java.lang.Long => l
                case i: java.lang.Integer => java.lang.Long.valueOf(i.longValue())
              }
            }
          case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlType == DoubleType =>
            // A FLOAT/DOUBLE union is widened to a double.
            (item: AnyRef) => {
              item match {
                case null => null
                case d: java.lang.Double => d
                case f: java.lang.Float => java.lang.Double.valueOf(f.doubleValue())
              }
            }
          case other =>
            sqlType match {
              // Any other union maps onto a struct with one field per union branch.
              case t: StructType if t.fields.length == avroSchema.getTypes.size =>
                val fieldConverters = t.fields.zip(avroSchema.getTypes.asScala).map {
                  case (field, schema) =>
                    createConverter(schema, field.dataType, path :+ field.name)
                }

                (item: AnyRef) => if (item == null) {
                  null
                } else {
                  // Only the slot of the branch the value resolves to is filled in.
                  val i = GenericData.get().resolveUnion(avroSchema, item)
                  val converted = new Array[Any](fieldConverters.length)
                  converted(i) = fieldConverters(i)(item)
                  new GenericRow(converted)
                }
              case _ => throw new IncompatibleSchemaException(
                s"Cannot convert Avro schema to catalyst type because schema at path " +
                  s"${path.mkString(".")} is not compatible " +
                  s"(avroType = $other, sqlType = $sqlType). \n" +
                  s"Source Avro schema: $sourceAvroSchema.\n" +
                  s"Target Catalyst type: $targetSqlType")
            }
        }
      case (unsupportedSqlType, unsupportedAvroType) =>
        // The matched tuple is (sqlType, avroType); label each value accordingly.
        throw new IncompatibleSchemaException(
          s"Cannot convert Avro schema to catalyst type because schema at path " +
            s"${path.mkString(".")} is not compatible " +
            s"(avroType = $unsupportedAvroType, sqlType = $unsupportedSqlType). \n" +
            s"Source Avro schema: $sourceAvroSchema.\n" +
            s"Target Catalyst type: $targetSqlType")
    }
  }
  createConverter(sourceAvroSchema, targetSqlType, List.empty[String])
}
|
|
|
|
/**
 * Returns a converter function that turns a catalyst value of the given type into the
 * value stored in an avro record. A null input is converted to null for every type.
 *
 * Conversions performed:
 *  - binary is wrapped in a ByteBuffer;
 *  - int, long, float, double, string and boolean are passed through unchanged;
 *  - byte and short are widened to int;
 *  - decimal is rendered with toString;
 *  - timestamp and date are replaced by the result of getTime;
 *  - arrays become a java.util.ArrayList and string-keyed maps a java.util.HashMap,
 *    with elements/values converted recursively;
 *  - structs become an avro Record whose fields are converted recursively.
 *
 * Any data type not listed above (for example a map whose key type is not string) is not
 * matched and fails with a scala.MatchError.
 *
 * @param dataType        Catalyst type of the values to convert.
 * @param structName      Name used for the avro record when `dataType` is a struct.
 * @param recordNamespace Namespace used for the avro record when `dataType` is a struct.
 * @return a function converting one catalyst value into its avro counterpart.
 */
def createConverterToAvro(dataType: DataType,
                          structName: String,
                          recordNamespace: String): Any => Any = {
  dataType match {
    case BinaryType => (item: Any) =>
      item match {
        case null => null
        case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
      }
    case IntegerType | LongType |
         FloatType | DoubleType | StringType | BooleanType => identity
    case ByteType => (item: Any) =>
      if (item == null) null else item.asInstanceOf[Byte].intValue
    case ShortType => (item: Any) =>
      if (item == null) null else item.asInstanceOf[Short].intValue
    case _: DecimalType => (item: Any) => if (item == null) null else item.toString
    case TimestampType => (item: Any) =>
      if (item == null) null else item.asInstanceOf[Timestamp].getTime
    case DateType => (item: Any) =>
      if (item == null) null else item.asInstanceOf[Date].getTime
    case ArrayType(elementType, _) =>
      val elementConverter = createConverterToAvro(
        elementType,
        structName,
        getNewRecordNamespace(elementType, recordNamespace, structName))
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val sourceArray = item.asInstanceOf[Seq[Any]]
          val targetList = new util.ArrayList[Any](sourceArray.size)
          // Traverse once instead of indexing, which would be quadratic on a linear Seq.
          sourceArray.foreach(element => targetList.add(elementConverter(element)))
          targetList
        }
      }
    case MapType(StringType, valueType, _) =>
      val valueConverter = createConverterToAvro(
        valueType,
        structName,
        getNewRecordNamespace(valueType, recordNamespace, structName))
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val javaMap = new util.HashMap[String, Any]()
          item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
            javaMap.put(key, valueConverter(value))
          }
          javaMap
        }
      }
    case structType: StructType =>
      val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
      val schema: Schema = SchemaConverters.convertStructToAvro(
        structType, builder, recordNamespace)
      // Field names and converters are fixed for the struct; compute them once, not per row.
      val fieldNames = structType.fieldNames
      val fieldConverters = structType.fields.map(field =>
        createConverterToAvro(
          field.dataType,
          field.name,
          getNewRecordNamespace(field.dataType, recordNamespace, field.name)))
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val row = item.asInstanceOf[Row]
          val record = new Record(schema)
          // Read each value by position rather than copying the whole row first.
          var idx = 0
          while (idx < fieldConverters.length) {
            record.put(fieldNames(idx), fieldConverters(idx)(row.get(idx)))
            idx += 1
          }
          record
        }
      }
  }
}
|
|
|
|
/**
 * Derives the avro schema of a catalyst struct type.
 *
 * @param structType      Catalyst struct to convert.
 * @param structName      Name of the resulting avro record.
 * @param recordNamespace Namespace of the resulting avro record.
 * @return the avro record schema produced by SchemaConverters.convertStructToAvro.
 */
def convertStructTypeToAvroSchema(structType: StructType,
                                  structName: String,
                                  recordNamespace: String): Schema = {
  val avroSchema: Schema = SchemaConverters.convertStructToAvro(
    structType,
    SchemaBuilder.record(structName).namespace(recordNamespace),
    recordNamespace)
  avroSchema
}
|
|
|
|
/**
 * Derives the catalyst struct type of an avro schema.
 *
 * @param avroSchema Avro schema to convert.
 * @return the data type reported by SchemaConverters.toSqlType, cast to StructType
 *         (the cast fails if the converted type is not a struct).
 */
def convertAvroSchemaToStructType(avroSchema: Schema): StructType = {
  val sqlType = SchemaConverters.toSqlType(avroSchema)
  sqlType.dataType.asInstanceOf[StructType]
}
|
|
}
|