diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/util/DataTypeUtils.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/util/DataTypeUtils.java
index b934f5f6e..bf800536e 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/util/DataTypeUtils.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/util/DataTypeUtils.java
@@ -18,12 +18,16 @@ package org.apache.hudi.util;
 
+import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.ByteType$;
 import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.DoubleType$;
 import org.apache.spark.sql.types.FloatType$;
 import org.apache.spark.sql.types.IntegerType$;
 import org.apache.spark.sql.types.LongType$;
+import org.apache.spark.sql.types.MapType;
 import org.apache.spark.sql.types.ShortType$;
 import org.apache.spark.sql.types.StringType$;
 import org.apache.spark.sql.types.StructField;
@@ -119,4 +123,26 @@ public class DataTypeUtils {
   private static <T> HashSet<T> newHashSet(T... ts) {
     return new HashSet<>(Arrays.asList(ts));
   }
+
+  /**
+   * Checks whether the given Spark type contains a DecimalType whose precision is less than Decimal.MAX_LONG_DIGITS().
+   *
+   * @param sparkType the Spark schema to inspect.
+   * @return true if such a DecimalType is found.
+   */
+  public static boolean foundSmallPrecisionDecimalType(DataType sparkType) {
+    if (sparkType instanceof StructType) {
+      StructField[] fields = ((StructType) sparkType).fields();
+      return Arrays.stream(fields).anyMatch(f -> foundSmallPrecisionDecimalType(f.dataType()));
+    } else if (sparkType instanceof MapType) {
+      MapType map = (MapType) sparkType;
+      return foundSmallPrecisionDecimalType(map.keyType()) || foundSmallPrecisionDecimalType(map.valueType());
+    } else if (sparkType instanceof ArrayType) {
+      return foundSmallPrecisionDecimalType(((ArrayType) sparkType).elementType());
+    } else if (sparkType instanceof DecimalType) {
+      DecimalType decimalType = (DecimalType) sparkType;
+      return decimalType.precision() < Decimal.MAX_LONG_DIGITS();
+    }
+    return false;
+  }
 }
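
Reviewer note (not part of the patch): a minimal sketch of how the new helper walks a nested schema, assuming this patch is applied. The class name and schema below are hypothetical and only illustrate that the recursion descends through struct, map, and array types before reaching a DecimalType.

    // Hypothetical example; not part of the diff.
    import org.apache.hudi.util.DataTypeUtils;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;

    public class FoundSmallDecimalExample {
      public static void main(String[] args) {
        // decimal(8, 4) hidden inside a map value: precision 8 < Decimal.MAX_LONG_DIGITS() (18)
        StructType schema = new StructType()
            .add("id", DataTypes.LongType)
            .add("prices", DataTypes.createMapType(DataTypes.StringType, DataTypes.createDecimalType(8, 4)));
        System.out.println(DataTypeUtils.foundSmallPrecisionDecimalType(schema)); // true

        // decimal(38, 10) is above the threshold, so nothing is flagged
        StructType wide = new StructType().add("amount", DataTypes.createDecimalType(38, 10));
        System.out.println(DataTypeUtils.foundSmallPrecisionDecimalType(wide)); // false
      }
    }
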
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/java/org/apache/hudi/DataSourceUtils.java b/hudi-spark-datasource/hudi-spark-common/src/main/java/org/apache/hudi/DataSourceUtils.java
index b98417ef2..a2d210e2c 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/java/org/apache/hudi/DataSourceUtils.java
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/java/org/apache/hudi/DataSourceUtils.java
@@ -46,12 +46,14 @@ import org.apache.hudi.hive.HiveSyncConfig;
 import org.apache.hudi.hive.SlashEncodedDayPartitionValueExtractor;
 import org.apache.hudi.index.HoodieIndex.IndexType;
 import org.apache.hudi.table.BulkInsertPartitioner;
+import org.apache.hudi.util.DataTypeUtils;
 import org.apache.log4j.LogManager;
 import org.apache.log4j.Logger;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.StructType;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -309,4 +311,15 @@ public class DataSourceUtils {
         DataSourceWriteOptions.HIVE_SUPPORT_TIMESTAMP_TYPE().defaultValue()));
     return hiveSyncConfig;
   }
+
+  // By default, ParquetWriteSupport writes a DecimalType to parquet as int32/int64 when the precision of the decimal type is less than Decimal.MAX_LONG_DIGITS(),
+  // but AvroParquetReader, which is used by HoodieParquetReader, cannot read int32/int64 back as DecimalType.
+  // Check whether the given schema contains such a DecimalType and, if it does, auto set the value of "hoodie.parquet.writeLegacyFormat.enabled".
+  public static void mayBeOverwriteParquetWriteLegacyFormatProp(Map<String, String> properties, StructType schema) {
+    if (DataTypeUtils.foundSmallPrecisionDecimalType(schema)
+        && !Boolean.parseBoolean(properties.getOrDefault("hoodie.parquet.writeLegacyFormat.enabled", "false"))) {
+      properties.put("hoodie.parquet.writeLegacyFormat.enabled", "true");
+      LOG.warn("Small precision decimal type found in current schema, auto set the value of hoodie.parquet.writeLegacyFormat.enabled to true");
+    }
+  }
 }
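
Reviewer note (not part of the patch): a short sketch of the overwrite rule above, assuming this patch is applied. The property is only forced to "true" when a small-precision decimal is present and the current value is "false"; a value the user already set to "true" is left untouched. The class name is hypothetical.

    // Hypothetical example; not part of the diff.
    import java.util.HashMap;
    import java.util.Map;
    import org.apache.hudi.DataSourceUtils;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;

    public class LegacyFormatPropExample {
      public static void main(String[] args) {
        Map<String, String> props = new HashMap<>();
        props.put("hoodie.parquet.writeLegacyFormat.enabled", "false");
        // decimal(10, 2): precision 10 < Decimal.MAX_LONG_DIGITS() (18)
        StructType schema = new StructType().add("d1", DataTypes.createDecimalType(10, 2));
        DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp(props, schema);
        System.out.println(props.get("hoodie.parquet.writeLegacyFormat.enabled")); // true (and a warning is logged)
      }
    }
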
diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDataSourceUtils.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDataSourceUtils.java
index 6353aa216..ae89e8af7 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDataSourceUtils.java
+++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/TestDataSourceUtils.java
@@ -40,10 +40,16 @@ import org.apache.avro.generic.GenericRecord;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.DecimalType$;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.StructType$;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
 import org.junit.jupiter.params.provider.ValueSource;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
@@ -52,7 +58,12 @@ import org.mockito.junit.jupiter.MockitoExtension;
 
 import java.math.BigDecimal;
 import java.time.LocalDate;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
+import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
 import static org.apache.hudi.common.model.HoodieFileFormat.PARQUET;
 import static org.apache.hudi.hive.ddl.HiveSyncMode.HMS;
 import static org.hamcrest.CoreMatchers.containsString;
@@ -274,4 +285,33 @@ public class TestDataSourceUtils {
       return false;
     }
   }
+
+  @ParameterizedTest
+  @CsvSource({"true, false", "true, true", "false, true", "false, false"})
+  public void testAutoModifyParquetWriteLegacyFormatParameter(boolean smallDecimal, boolean defaultWriteValue) {
+    // create test StructType
+    List<StructField> structFields = new ArrayList<>();
+    if (smallDecimal) {
+      structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(10, 2), false, Metadata.empty()));
+    } else {
+      structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(38, 10), false, Metadata.empty()));
+    }
+    StructType structType = StructType$.MODULE$.apply(structFields);
+    // create write options
+    Map<String, String> options = new HashMap<>();
+    options.put("hoodie.parquet.writeLegacyFormat.enabled", String.valueOf(defaultWriteValue));
+
+    // start test
+    mayBeOverwriteParquetWriteLegacyFormatProp(options, structType);
+
+    // check result
+    boolean res = Boolean.parseBoolean(options.get("hoodie.parquet.writeLegacyFormat.enabled"));
+    if (smallDecimal) {
+      // should auto modify "hoodie.parquet.writeLegacyFormat.enabled" to "true".
+      assertEquals(true, res);
+    } else {
+      // should not modify the value of "hoodie.parquet.writeLegacyFormat.enabled".
+      assertEquals(defaultWriteValue, res);
+    }
+  }
 }
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
index 663493438..a42a1ced1 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
@@ -723,4 +723,29 @@
     val result = spark.sql("select * from tmptable limit 1").collect()(0)
     result.schema.contains(new StructField("partition", StringType, true))
   }
+
+  @Test
+  def testWriteSmallPrecisionDecimalTable(): Unit = {
+    val records1 = recordsToStrings(dataGen.generateInserts("001", 5)).toList
+    val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2))
+      .withColumn("shortDecimal", lit(new java.math.BigDecimal("2090.0000"))) // creates decimalType(8, 4)
+    inputDF1.write.format("org.apache.hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.BULK_INSERT_OPERATION_OPT_VAL)
+      .mode(SaveMode.Overwrite)
+      .save(basePath)
+
+    // update the value of shortDecimal
+    val inputDF2 = inputDF1.withColumn("shortDecimal", lit(new java.math.BigDecimal("3090.0000")))
+    inputDF2.write.format("org.apache.hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
+      .mode(SaveMode.Append)
+      .save(basePath)
+    val readResult = spark.read.format("hudi").load(basePath)
+    assert(readResult.count() == 5)
+    // compare the test result
+    assertEquals(inputDF2.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","),
+      readResult.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","))
+  }
 }
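
Reviewer note (not part of the patch): in the Scala test above, the decimal(8, 4) column type falls out of the literal itself, since Spark derives a decimal literal's type from the BigDecimal's precision and scale. A quick, hypothetical check:

    // Hypothetical example; not part of the diff.
    public class DecimalLiteralCheck {
      public static void main(String[] args) {
        java.math.BigDecimal d = new java.math.BigDecimal("2090.0000");
        // 8 significant digits, 4 after the point, so lit(d) is typed decimal(8, 4)
        System.out.println(d.precision() + ", " + d.scale()); // prints: 8, 4
      }
    }
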
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/java/org/apache/hudi/internal/DefaultSource.java b/hudi-spark-datasource/hudi-spark2/src/main/java/org/apache/hudi/internal/DefaultSource.java
index addbc899d..e607b2fdc 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/java/org/apache/hudi/internal/DefaultSource.java
+++ b/hudi-spark-datasource/hudi-spark2/src/main/java/org/apache/hudi/internal/DefaultSource.java
@@ -33,8 +33,11 @@ import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
 import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
 import org.apache.spark.sql.types.StructType;
 
+import java.util.Map;
 import java.util.Optional;
 
+import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
+
 /**
  * DataSource V2 implementation for managing internal write logic. Only called internally.
  */
@@ -64,8 +67,11 @@ public class DefaultSource extends BaseDefaultSource implements DataSourceV2,
     String tblName = options.get(HoodieWriteConfig.TBL_NAME.key()).get();
     boolean populateMetaFields = options.getBoolean(HoodieTableConfig.POPULATE_META_FIELDS.key(),
         Boolean.parseBoolean(HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
+    Map<String, String> properties = options.asMap();
+    // Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
+    mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
     // 1st arg to createHoodieConfig is not really required to be set. but passing it anyways.
-    HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, options.asMap());
+    HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, properties);
     boolean arePartitionRecordsSorted = HoodieInternalConfig.getBulkInsertIsPartitionRecordsSorted(
         options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).isPresent()
             ? options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).get() : null);
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/java/org/apache/hudi/spark3/internal/DefaultSource.java b/hudi-spark-datasource/hudi-spark3/src/main/java/org/apache/hudi/spark3/internal/DefaultSource.java
index eda8faead..63c09e05f 100644
--- a/hudi-spark-datasource/hudi-spark3/src/main/java/org/apache/hudi/spark3/internal/DefaultSource.java
+++ b/hudi-spark-datasource/hudi-spark3/src/main/java/org/apache/hudi/spark3/internal/DefaultSource.java
@@ -33,6 +33,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 import java.util.Map;
 
+import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
+
 /**
  * DataSource V2 implementation for managing internal write logic. Only called internally.
  * This class is only compatible with datasource V2 API in Spark 3.
@@ -53,6 +55,8 @@ public class DefaultSource extends BaseDefaultSource implements TableProvider {
         HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
     boolean arePartitionRecordsSorted = Boolean.parseBoolean(properties.getOrDefault(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED,
         Boolean.toString(HoodieInternalConfig.DEFAULT_BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED)));
+    // Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
+    mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
     // 1st arg to createHoodieConfig is not really required to be set. but passing it anyways.
     HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(properties.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()), path, tblName, properties);
     return new HoodieDataSourceInternalTable(instantTime, config, schema, getSparkSession(),