1
0

[HUDI-2958] Automatically set spark.sql.parquet.writelegacyformat, when using bulkinsert to insert data which contains decimalType (#4253)

This commit is contained in:
xiarixiaoyao
2021-12-17 21:58:02 +08:00
committed by GitHub
parent e4cfb421c0
commit 9246b16492
6 changed files with 115 additions and 1 deletions

View File

@@ -18,12 +18,16 @@
package org.apache.hudi.util;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.ShortType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
@@ -119,4 +123,26 @@ public class DataTypeUtils {
private static <T> HashSet<T> newHashSet(T... ts) {
return new HashSet<>(Arrays.asList(ts));
}
/**
* Try to find current sparktype whether contains that DecimalType which's scale < Decimal.MAX_LONG_DIGITS().
*
* @param sparkType spark schema.
* @return found result.
*/
public static boolean foundSmallPrecisionDecimalType(DataType sparkType) {
if (sparkType instanceof StructType) {
StructField[] fields = ((StructType) sparkType).fields();
return Arrays.stream(fields).anyMatch(f -> foundSmallPrecisionDecimalType(f.dataType()));
} else if (sparkType instanceof MapType) {
MapType map = (MapType) sparkType;
return foundSmallPrecisionDecimalType(map.keyType()) || foundSmallPrecisionDecimalType(map.valueType());
} else if (sparkType instanceof ArrayType) {
return foundSmallPrecisionDecimalType(((ArrayType) sparkType).elementType());
} else if (sparkType instanceof DecimalType) {
DecimalType decimalType = (DecimalType) sparkType;
return decimalType.precision() < Decimal.MAX_LONG_DIGITS();
}
return false;
}
}

View File

@@ -46,12 +46,14 @@ import org.apache.hudi.hive.HiveSyncConfig;
import org.apache.hudi.hive.SlashEncodedDayPartitionValueExtractor;
import org.apache.hudi.index.HoodieIndex.IndexType;
import org.apache.hudi.table.BulkInsertPartitioner;
import org.apache.hudi.util.DataTypeUtils;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import java.io.IOException;
import java.util.ArrayList;
@@ -309,4 +311,15 @@ public class DataSourceUtils {
DataSourceWriteOptions.HIVE_SUPPORT_TIMESTAMP_TYPE().defaultValue()));
return hiveSyncConfig;
}
// Now by default ParquetWriteSupport will write DecimalType to parquet as int32/int64 when the scale of decimalType < Decimal.MAX_LONG_DIGITS(),
// but AvroParquetReader which used by HoodieParquetReader cannot support read int32/int64 as DecimalType.
// try to find current schema whether contains that DecimalType, and auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
public static void mayBeOverwriteParquetWriteLegacyFormatProp(Map<String, String> properties, StructType schema) {
if (DataTypeUtils.foundSmallPrecisionDecimalType(schema)
&& !Boolean.parseBoolean(properties.getOrDefault("hoodie.parquet.writeLegacyFormat.enabled", "false"))) {
properties.put("hoodie.parquet.writeLegacyFormat.enabled", "true");
LOG.warn("Small Decimal Type found in current schema, auto set the value of hoodie.parquet.writeLegacyFormat.enabled to true");
}
}
}

View File

@@ -40,10 +40,16 @@ import org.apache.avro.generic.GenericRecord;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DecimalType$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
@@ -52,7 +58,12 @@ import org.mockito.junit.jupiter.MockitoExtension;
import java.math.BigDecimal;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
import static org.apache.hudi.common.model.HoodieFileFormat.PARQUET;
import static org.apache.hudi.hive.ddl.HiveSyncMode.HMS;
import static org.hamcrest.CoreMatchers.containsString;
@@ -274,4 +285,33 @@ public class TestDataSourceUtils {
return false;
}
}
@ParameterizedTest
@CsvSource({"true, false", "true, true", "false, true", "false, false"})
public void testAutoModifyParquetWriteLegacyFormatParameter(boolean smallDecimal, boolean defaultWriteValue) {
// create test StructType
List<StructField> structFields = new ArrayList<>();
if (smallDecimal) {
structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(10, 2), false, Metadata.empty()));
} else {
structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(38, 10), false, Metadata.empty()));
}
StructType structType = StructType$.MODULE$.apply(structFields);
// create write options
Map<String, String> options = new HashMap<>();
options.put("hoodie.parquet.writeLegacyFormat.enabled", String.valueOf(defaultWriteValue));
// start test
mayBeOverwriteParquetWriteLegacyFormatProp(options, structType);
// check result
boolean res = Boolean.parseBoolean(options.get("hoodie.parquet.writeLegacyFormat.enabled"));
if (smallDecimal) {
// should auto modify "hoodie.parquet.writeLegacyFormat.enabled" = "true".
assertEquals(true, res);
} else {
// should not modify the value of "hoodie.parquet.writeLegacyFormat.enabled".
assertEquals(defaultWriteValue, res);
}
}
}

View File

@@ -723,4 +723,29 @@ class TestCOWDataSource extends HoodieClientTestBase {
val result = spark.sql("select * from tmptable limit 1").collect()(0)
result.schema.contains(new StructField("partition", StringType, true))
}
@Test
def testWriteSmallPrecisionDecimalTable(): Unit = {
val records1 = recordsToStrings(dataGen.generateInserts("001", 5)).toList
val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2))
.withColumn("shortDecimal", lit(new java.math.BigDecimal(s"2090.0000"))) // create decimalType(8, 4)
inputDF1.write.format("org.apache.hudi")
.options(commonOpts)
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.BULK_INSERT_OPERATION_OPT_VAL)
.mode(SaveMode.Overwrite)
.save(basePath)
// update the value of shortDecimal
val inputDF2 = inputDF1.withColumn("shortDecimal", lit(new java.math.BigDecimal(s"3090.0000")))
inputDF2.write.format("org.apache.hudi")
.options(commonOpts)
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
.mode(SaveMode.Append)
.save(basePath)
val readResult = spark.read.format("hudi").load(basePath)
assert(readResult.count() == 5)
// compare the test result
assertEquals(inputDF2.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","),
readResult.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","))
}
}

View File

@@ -33,8 +33,11 @@ import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
import org.apache.spark.sql.types.StructType;
import java.util.Map;
import java.util.Optional;
import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
/**
* DataSource V2 implementation for managing internal write logic. Only called internally.
*/
@@ -64,8 +67,11 @@ public class DefaultSource extends BaseDefaultSource implements DataSourceV2,
String tblName = options.get(HoodieWriteConfig.TBL_NAME.key()).get();
boolean populateMetaFields = options.getBoolean(HoodieTableConfig.POPULATE_META_FIELDS.key(),
Boolean.parseBoolean(HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
Map<String, String> properties = options.asMap();
// Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
// 1st arg to createHoodieConfig is not really required to be set. but passing it anyways.
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, options.asMap());
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, properties);
boolean arePartitionRecordsSorted = HoodieInternalConfig.getBulkInsertIsPartitionRecordsSorted(
options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).isPresent()
? options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).get() : null);

View File

@@ -33,6 +33,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import java.util.Map;
import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
/**
* DataSource V2 implementation for managing internal write logic. Only called internally.
* This class is only compatible with datasource V2 API in Spark 3.
@@ -53,6 +55,8 @@ public class DefaultSource extends BaseDefaultSource implements TableProvider {
HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
boolean arePartitionRecordsSorted = Boolean.parseBoolean(properties.getOrDefault(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED,
Boolean.toString(HoodieInternalConfig.DEFAULT_BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED)));
// Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
// 1st arg to createHoodieConfig is not really required to be set. but passing it anyways.
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(properties.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()), path, tblName, properties);
return new HoodieDataSourceInternalTable(instantTime, config, schema, getSparkSession(),