[HUDI-2958] Automatically set spark.sql.parquet.writelegacyformat, when using bulkinsert to insert data which contains decimalType (#4253)
This commit is contained in:
@@ -18,12 +18,16 @@
|
||||
|
||||
package org.apache.hudi.util;
|
||||
|
||||
import org.apache.spark.sql.types.ArrayType;
|
||||
import org.apache.spark.sql.types.ByteType$;
|
||||
import org.apache.spark.sql.types.DataType;
|
||||
import org.apache.spark.sql.types.Decimal;
|
||||
import org.apache.spark.sql.types.DecimalType;
|
||||
import org.apache.spark.sql.types.DoubleType$;
|
||||
import org.apache.spark.sql.types.FloatType$;
|
||||
import org.apache.spark.sql.types.IntegerType$;
|
||||
import org.apache.spark.sql.types.LongType$;
|
||||
import org.apache.spark.sql.types.MapType;
|
||||
import org.apache.spark.sql.types.ShortType$;
|
||||
import org.apache.spark.sql.types.StringType$;
|
||||
import org.apache.spark.sql.types.StructField;
|
||||
@@ -119,4 +123,26 @@ public class DataTypeUtils {
|
||||
private static <T> HashSet<T> newHashSet(T... ts) {
|
||||
return new HashSet<>(Arrays.asList(ts));
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to find current sparktype whether contains that DecimalType which's scale < Decimal.MAX_LONG_DIGITS().
|
||||
*
|
||||
* @param sparkType spark schema.
|
||||
* @return found result.
|
||||
*/
|
||||
public static boolean foundSmallPrecisionDecimalType(DataType sparkType) {
|
||||
if (sparkType instanceof StructType) {
|
||||
StructField[] fields = ((StructType) sparkType).fields();
|
||||
return Arrays.stream(fields).anyMatch(f -> foundSmallPrecisionDecimalType(f.dataType()));
|
||||
} else if (sparkType instanceof MapType) {
|
||||
MapType map = (MapType) sparkType;
|
||||
return foundSmallPrecisionDecimalType(map.keyType()) || foundSmallPrecisionDecimalType(map.valueType());
|
||||
} else if (sparkType instanceof ArrayType) {
|
||||
return foundSmallPrecisionDecimalType(((ArrayType) sparkType).elementType());
|
||||
} else if (sparkType instanceof DecimalType) {
|
||||
DecimalType decimalType = (DecimalType) sparkType;
|
||||
return decimalType.precision() < Decimal.MAX_LONG_DIGITS();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,12 +46,14 @@ import org.apache.hudi.hive.HiveSyncConfig;
|
||||
import org.apache.hudi.hive.SlashEncodedDayPartitionValueExtractor;
|
||||
import org.apache.hudi.index.HoodieIndex.IndexType;
|
||||
import org.apache.hudi.table.BulkInsertPartitioner;
|
||||
import org.apache.hudi.util.DataTypeUtils;
|
||||
import org.apache.log4j.LogManager;
|
||||
import org.apache.log4j.Logger;
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
@@ -309,4 +311,15 @@ public class DataSourceUtils {
|
||||
DataSourceWriteOptions.HIVE_SUPPORT_TIMESTAMP_TYPE().defaultValue()));
|
||||
return hiveSyncConfig;
|
||||
}
|
||||
|
||||
// Now by default ParquetWriteSupport will write DecimalType to parquet as int32/int64 when the scale of decimalType < Decimal.MAX_LONG_DIGITS(),
|
||||
// but AvroParquetReader which used by HoodieParquetReader cannot support read int32/int64 as DecimalType.
|
||||
// try to find current schema whether contains that DecimalType, and auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
|
||||
public static void mayBeOverwriteParquetWriteLegacyFormatProp(Map<String, String> properties, StructType schema) {
|
||||
if (DataTypeUtils.foundSmallPrecisionDecimalType(schema)
|
||||
&& !Boolean.parseBoolean(properties.getOrDefault("hoodie.parquet.writeLegacyFormat.enabled", "false"))) {
|
||||
properties.put("hoodie.parquet.writeLegacyFormat.enabled", "true");
|
||||
LOG.warn("Small Decimal Type found in current schema, auto set the value of hoodie.parquet.writeLegacyFormat.enabled to true");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,10 +40,16 @@ import org.apache.avro.generic.GenericRecord;
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.types.DecimalType$;
|
||||
import org.apache.spark.sql.types.Metadata;
|
||||
import org.apache.spark.sql.types.StructField;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
import org.apache.spark.sql.types.StructType$;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Captor;
|
||||
@@ -52,7 +58,12 @@ import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
|
||||
import static org.apache.hudi.common.model.HoodieFileFormat.PARQUET;
|
||||
import static org.apache.hudi.hive.ddl.HiveSyncMode.HMS;
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
@@ -274,4 +285,33 @@ public class TestDataSourceUtils {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource({"true, false", "true, true", "false, true", "false, false"})
|
||||
public void testAutoModifyParquetWriteLegacyFormatParameter(boolean smallDecimal, boolean defaultWriteValue) {
|
||||
// create test StructType
|
||||
List<StructField> structFields = new ArrayList<>();
|
||||
if (smallDecimal) {
|
||||
structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(10, 2), false, Metadata.empty()));
|
||||
} else {
|
||||
structFields.add(StructField.apply("d1", DecimalType$.MODULE$.apply(38, 10), false, Metadata.empty()));
|
||||
}
|
||||
StructType structType = StructType$.MODULE$.apply(structFields);
|
||||
// create write options
|
||||
Map<String, String> options = new HashMap<>();
|
||||
options.put("hoodie.parquet.writeLegacyFormat.enabled", String.valueOf(defaultWriteValue));
|
||||
|
||||
// start test
|
||||
mayBeOverwriteParquetWriteLegacyFormatProp(options, structType);
|
||||
|
||||
// check result
|
||||
boolean res = Boolean.parseBoolean(options.get("hoodie.parquet.writeLegacyFormat.enabled"));
|
||||
if (smallDecimal) {
|
||||
// should auto modify "hoodie.parquet.writeLegacyFormat.enabled" = "true".
|
||||
assertEquals(true, res);
|
||||
} else {
|
||||
// should not modify the value of "hoodie.parquet.writeLegacyFormat.enabled".
|
||||
assertEquals(defaultWriteValue, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -723,4 +723,29 @@ class TestCOWDataSource extends HoodieClientTestBase {
|
||||
val result = spark.sql("select * from tmptable limit 1").collect()(0)
|
||||
result.schema.contains(new StructField("partition", StringType, true))
|
||||
}
|
||||
|
||||
@Test
|
||||
def testWriteSmallPrecisionDecimalTable(): Unit = {
|
||||
val records1 = recordsToStrings(dataGen.generateInserts("001", 5)).toList
|
||||
val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2))
|
||||
.withColumn("shortDecimal", lit(new java.math.BigDecimal(s"2090.0000"))) // create decimalType(8, 4)
|
||||
inputDF1.write.format("org.apache.hudi")
|
||||
.options(commonOpts)
|
||||
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.BULK_INSERT_OPERATION_OPT_VAL)
|
||||
.mode(SaveMode.Overwrite)
|
||||
.save(basePath)
|
||||
|
||||
// update the value of shortDecimal
|
||||
val inputDF2 = inputDF1.withColumn("shortDecimal", lit(new java.math.BigDecimal(s"3090.0000")))
|
||||
inputDF2.write.format("org.apache.hudi")
|
||||
.options(commonOpts)
|
||||
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
|
||||
.mode(SaveMode.Append)
|
||||
.save(basePath)
|
||||
val readResult = spark.read.format("hudi").load(basePath)
|
||||
assert(readResult.count() == 5)
|
||||
// compare the test result
|
||||
assertEquals(inputDF2.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","),
|
||||
readResult.sort("_row_key").select("shortDecimal").collect().map(_.getDecimal(0).toPlainString).mkString(","))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,8 +33,11 @@ import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
|
||||
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
|
||||
|
||||
/**
|
||||
* DataSource V2 implementation for managing internal write logic. Only called internally.
|
||||
*/
|
||||
@@ -64,8 +67,11 @@ public class DefaultSource extends BaseDefaultSource implements DataSourceV2,
|
||||
String tblName = options.get(HoodieWriteConfig.TBL_NAME.key()).get();
|
||||
boolean populateMetaFields = options.getBoolean(HoodieTableConfig.POPULATE_META_FIELDS.key(),
|
||||
Boolean.parseBoolean(HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
|
||||
Map<String, String> properties = options.asMap();
|
||||
// Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
|
||||
mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
|
||||
// 1st arg to createHoodieConfig is not really required to be set. but passing it anyways.
|
||||
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, options.asMap());
|
||||
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(options.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()).get(), path, tblName, properties);
|
||||
boolean arePartitionRecordsSorted = HoodieInternalConfig.getBulkInsertIsPartitionRecordsSorted(
|
||||
options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).isPresent()
|
||||
? options.get(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED).get() : null);
|
||||
|
||||
@@ -33,6 +33,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.apache.hudi.DataSourceUtils.mayBeOverwriteParquetWriteLegacyFormatProp;
|
||||
|
||||
/**
|
||||
* DataSource V2 implementation for managing internal write logic. Only called internally.
|
||||
* This class is only compatible with datasource V2 API in Spark 3.
|
||||
@@ -53,6 +55,8 @@ public class DefaultSource extends BaseDefaultSource implements TableProvider {
|
||||
HoodieTableConfig.POPULATE_META_FIELDS.defaultValue()));
|
||||
boolean arePartitionRecordsSorted = Boolean.parseBoolean(properties.getOrDefault(HoodieInternalConfig.BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED,
|
||||
Boolean.toString(HoodieInternalConfig.DEFAULT_BULKINSERT_ARE_PARTITIONER_RECORDS_SORTED)));
|
||||
// Auto set the value of "hoodie.parquet.writeLegacyFormat.enabled"
|
||||
mayBeOverwriteParquetWriteLegacyFormatProp(properties, schema);
|
||||
// 1st arg to createHoodieConfig is not really required to be set. but passing it anyways.
|
||||
HoodieWriteConfig config = DataSourceUtils.createHoodieConfig(properties.get(HoodieWriteConfig.AVRO_SCHEMA_STRING.key()), path, tblName, properties);
|
||||
return new HoodieDataSourceInternalTable(instantTime, config, schema, getSparkSession(),
|
||||
|
||||
Reference in New Issue
Block a user