[HUDI-863] get decimal properties from derived spark DataType (#1596)
This commit is contained in:
@@ -268,8 +268,7 @@ object AvroConversionHelper {
|
|||||||
createConverter(sourceAvroSchema, targetSqlType, List.empty[String])
|
createConverter(sourceAvroSchema, targetSqlType, List.empty[String])
|
||||||
}
|
}
|
||||||
|
|
||||||
def createConverterToAvro(avroSchema: Schema,
|
def createConverterToAvro(dataType: DataType,
|
||||||
dataType: DataType,
|
|
||||||
structName: String,
|
structName: String,
|
||||||
recordNamespace: String): Any => Any = {
|
recordNamespace: String): Any => Any = {
|
||||||
dataType match {
|
dataType match {
|
||||||
@@ -284,13 +283,15 @@ object AvroConversionHelper {
|
|||||||
if (item == null) null else item.asInstanceOf[Byte].intValue
|
if (item == null) null else item.asInstanceOf[Byte].intValue
|
||||||
case ShortType => (item: Any) =>
|
case ShortType => (item: Any) =>
|
||||||
if (item == null) null else item.asInstanceOf[Short].intValue
|
if (item == null) null else item.asInstanceOf[Short].intValue
|
||||||
case dec: DecimalType => (item: Any) =>
|
case dec: DecimalType =>
|
||||||
|
val schema = SchemaConverters.toAvroType(dec, nullable = false, structName, recordNamespace)
|
||||||
|
(item: Any) => {
|
||||||
Option(item).map { _ =>
|
Option(item).map { _ =>
|
||||||
val bigDecimalValue = item.asInstanceOf[java.math.BigDecimal]
|
val bigDecimalValue = item.asInstanceOf[java.math.BigDecimal]
|
||||||
val decimalConversions = new DecimalConversion()
|
val decimalConversions = new DecimalConversion()
|
||||||
decimalConversions.toFixed(bigDecimalValue, avroSchema.getField(structName).schema().getTypes.get(0),
|
decimalConversions.toFixed(bigDecimalValue, schema, LogicalTypes.decimal(dec.precision, dec.scale))
|
||||||
LogicalTypes.decimal(dec.precision, dec.scale))
|
|
||||||
}.orNull
|
}.orNull
|
||||||
|
}
|
||||||
case TimestampType => (item: Any) =>
|
case TimestampType => (item: Any) =>
|
||||||
// Convert time to microseconds since spark-avro by default converts TimestampType to
|
// Convert time to microseconds since spark-avro by default converts TimestampType to
|
||||||
// Avro Logical TimestampMicros
|
// Avro Logical TimestampMicros
|
||||||
@@ -299,7 +300,6 @@ object AvroConversionHelper {
|
|||||||
Option(item).map(_.asInstanceOf[Date].toLocalDate.toEpochDay.toInt).orNull
|
Option(item).map(_.asInstanceOf[Date].toLocalDate.toEpochDay.toInt).orNull
|
||||||
case ArrayType(elementType, _) =>
|
case ArrayType(elementType, _) =>
|
||||||
val elementConverter = createConverterToAvro(
|
val elementConverter = createConverterToAvro(
|
||||||
avroSchema,
|
|
||||||
elementType,
|
elementType,
|
||||||
structName,
|
structName,
|
||||||
recordNamespace)
|
recordNamespace)
|
||||||
@@ -320,7 +320,6 @@ object AvroConversionHelper {
|
|||||||
}
|
}
|
||||||
case MapType(StringType, valueType, _) =>
|
case MapType(StringType, valueType, _) =>
|
||||||
val valueConverter = createConverterToAvro(
|
val valueConverter = createConverterToAvro(
|
||||||
avroSchema,
|
|
||||||
valueType,
|
valueType,
|
||||||
structName,
|
structName,
|
||||||
recordNamespace)
|
recordNamespace)
|
||||||
@@ -340,7 +339,6 @@ object AvroConversionHelper {
|
|||||||
val childNameSpace = if (recordNamespace != "") s"$recordNamespace.$structName" else structName
|
val childNameSpace = if (recordNamespace != "") s"$recordNamespace.$structName" else structName
|
||||||
val fieldConverters = structType.fields.map(field =>
|
val fieldConverters = structType.fields.map(field =>
|
||||||
createConverterToAvro(
|
createConverterToAvro(
|
||||||
avroSchema,
|
|
||||||
field.dataType,
|
field.dataType,
|
||||||
field.name,
|
field.name,
|
||||||
childNameSpace))
|
childNameSpace))
|
||||||
|
|||||||
@@ -38,14 +38,12 @@ object AvroConversionUtils {
|
|||||||
: RDD[GenericRecord] = {
|
: RDD[GenericRecord] = {
|
||||||
// Use the Avro schema to derive the StructType which has the correct nullability information
|
// Use the Avro schema to derive the StructType which has the correct nullability information
|
||||||
val dataType = SchemaConverters.toSqlType(avroSchema).dataType.asInstanceOf[StructType]
|
val dataType = SchemaConverters.toSqlType(avroSchema).dataType.asInstanceOf[StructType]
|
||||||
val avroSchemaAsJsonString = avroSchema.toString
|
|
||||||
val encoder = RowEncoder.apply(dataType).resolveAndBind()
|
val encoder = RowEncoder.apply(dataType).resolveAndBind()
|
||||||
df.queryExecution.toRdd.map(encoder.fromRow)
|
df.queryExecution.toRdd.map(encoder.fromRow)
|
||||||
.mapPartitions { records =>
|
.mapPartitions { records =>
|
||||||
if (records.isEmpty) Iterator.empty
|
if (records.isEmpty) Iterator.empty
|
||||||
else {
|
else {
|
||||||
val avroSchema = new Schema.Parser().parse(avroSchemaAsJsonString)
|
val convertor = AvroConversionHelper.createConverterToAvro(dataType, structName, recordNamespace)
|
||||||
val convertor = AvroConversionHelper.createConverterToAvro(avroSchema, dataType, structName, recordNamespace)
|
|
||||||
records.map { x => convertor(x).asInstanceOf[GenericRecord] }
|
records.map { x => convertor(x).asInstanceOf[GenericRecord] }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user