1
0

[HUDI-1526] Translate the api partitionBy in spark datasource to hoodie.datasource.write.partitionpath.field (#2431)

This commit is contained in:
teeyog
2021-02-11 01:07:54 +08:00
committed by GitHub
parent a2f85d90de
commit 26da4f5462
3 changed files with 208 additions and 9 deletions

View File

@@ -25,12 +25,16 @@ import org.apache.hudi.common.table.timeline.HoodieInstant
import org.apache.hudi.common.testutils.HoodieTestDataGenerator
import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.keygen._
import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.Config
import org.apache.hudi.testutils.HoodieClientTestBase
import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers}
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.{DataTypes, DateType, IntegerType, StringType, StructField, StructType, TimestampType}
import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
import org.apache.spark.sql.functions.{col, concat, lit, udf}
import org.apache.spark.sql.types._
import org.joda.time.DateTime
import org.joda.time.format.DateTimeFormat
import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue, fail}
import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
@@ -428,4 +432,151 @@ class TestCOWDataSource extends HoodieClientTestBase {
assertTrue(HoodieDataSourceHelpers.hasNewCommits(fs, basePath, "000"))
}
private def getDataFrameWriter(keyGenerator: String): DataFrameWriter[Row] = {
val records = recordsToStrings(dataGen.generateInserts("000", 100)).toList
val inputDF = spark.read.json(spark.sparkContext.parallelize(records, 2))
inputDF.write.format("hudi")
.options(commonOpts)
.option(DataSourceWriteOptions.KEYGENERATOR_CLASS_OPT_KEY, keyGenerator)
.mode(SaveMode.Overwrite)
}
@Test def testSparkPartitonByWithCustomKeyGenerator(): Unit = {
// Without fieldType, the default is SIMPLE
var writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
writer.partitionBy("current_ts")
.save(basePath)
var recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= col("current_ts").cast("string")).count() == 0)
// Specify fieldType as TIMESTAMP
writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
writer.partitionBy("current_ts:TIMESTAMP")
.option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
.option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
.save(basePath)
recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*/*")
val udf_date_format = udf((data: Long) => new DateTime(data).toString(DateTimeFormat.forPattern("yyyyMMdd")))
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= udf_date_format(col("current_ts"))).count() == 0)
// Mixed fieldType
writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
writer.partitionBy("driver", "rider:SIMPLE", "current_ts:TIMESTAMP")
.option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
.option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
.save(basePath)
recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*/*/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!=
concat(col("driver"), lit("/"), col("rider"), lit("/"), udf_date_format(col("current_ts")))).count() == 0)
// Test invalid partitionKeyType
writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
writer = writer.partitionBy("current_ts:DUMMY")
.option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
.option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
try {
writer.save(basePath)
fail("should fail when invalid PartitionKeyType is provided!")
} catch {
case e: Exception =>
assertTrue(e.getMessage.contains("No enum constant org.apache.hudi.keygen.CustomAvroKeyGenerator.PartitionKeyType.DUMMY"))
}
}
@Test def testSparkPartitonByWithSimpleKeyGenerator() {
// Use the `driver` field as the partition key
var writer = getDataFrameWriter(classOf[SimpleKeyGenerator].getName)
writer.partitionBy("driver")
.save(basePath)
var recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= col("driver")).count() == 0)
// Use the `driver,rider` field as the partition key, If no such field exists, the default value `default` is used
writer = getDataFrameWriter(classOf[SimpleKeyGenerator].getName)
writer.partitionBy("driver", "rider")
.save(basePath)
recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= lit("default")).count() == 0)
}
@Test def testSparkPartitonByWithComplexKeyGenerator() {
// Use the `driver` field as the partition key
var writer = getDataFrameWriter(classOf[ComplexKeyGenerator].getName)
writer.partitionBy("driver")
.save(basePath)
var recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= col("driver")).count() == 0)
// Use the `driver`,`rider` field as the partition key
writer = getDataFrameWriter(classOf[ComplexKeyGenerator].getName)
writer.partitionBy("driver", "rider")
.save(basePath)
recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= concat(col("driver"), lit("/"), col("rider"))).count() == 0)
}
@Test def testSparkPartitonByWithTimestampBasedKeyGenerator() {
val writer = getDataFrameWriter(classOf[TimestampBasedKeyGenerator].getName)
writer.partitionBy("current_ts")
.option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
.option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
.save(basePath)
val recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*/*")
val udf_date_format = udf((data: Long) => new DateTime(data).toString(DateTimeFormat.forPattern("yyyyMMdd")))
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= udf_date_format(col("current_ts"))).count() == 0)
}
@Test def testSparkPartitonByWithGlobalDeleteKeyGenerator() {
val writer = getDataFrameWriter(classOf[GlobalDeleteKeyGenerator].getName)
writer.partitionBy("driver")
.save(basePath)
val recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= lit("")).count() == 0)
}
@Test def testSparkPartitonByWithNonpartitionedKeyGenerator() {
// Empty string column
var writer = getDataFrameWriter(classOf[NonpartitionedKeyGenerator].getName)
writer.partitionBy("")
.save(basePath)
var recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= lit("")).count() == 0)
// Non-existent column
writer = getDataFrameWriter(classOf[NonpartitionedKeyGenerator].getName)
writer.partitionBy("abc")
.save(basePath)
recordsReadDF = spark.read.format("org.apache.hudi")
.load(basePath + "/*")
assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= lit("")).count() == 0)
}
}