diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
index 8586300e6..a79ac6f1d 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
@@ -86,8 +86,10 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
     val nameFieldMap = generateFieldMap(schema)
 
     if (partitionColumns.isPresent) {
-      if (tableConfig.getKeyGeneratorClassName.equalsIgnoreCase(classOf[TimestampBasedKeyGenerator].getName)
-        || tableConfig.getKeyGeneratorClassName.equalsIgnoreCase(classOf[TimestampBasedAvroKeyGenerator].getName)) {
+      // Note that key generator class name could be null
+      val keyGeneratorClassName = tableConfig.getKeyGeneratorClassName
+      if (classOf[TimestampBasedKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)
+        || classOf[TimestampBasedAvroKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)) {
         val partitionFields = partitionColumns.get().map(column => StructField(column, StringType))
         StructType(partitionFields)
       } else {
diff --git a/hudi-spark-datasource/hudi-spark/pom.xml b/hudi-spark-datasource/hudi-spark/pom.xml
index 534691cf0..606f6fa89 100644
--- a/hudi-spark-datasource/hudi-spark/pom.xml
+++ b/hudi-spark-datasource/hudi-spark/pom.xml
@@ -466,6 +466,12 @@
       <type>test-jar</type>
      <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.hudi</groupId>
+      <artifactId>hudi-java-client</artifactId>
+      <version>${project.version}</version>
+      <scope>test</scope>
+    </dependency>
 
     <dependency>
       <groupId>org.scalatest</groupId>
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
index d4f7cede7..4896ddf07 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
@@ -17,29 +17,36 @@
 
 package org.apache.hudi
 
-import java.util.Properties
+import org.apache.hadoop.conf.Configuration
 import org.apache.hudi.DataSourceWriteOptions._
+import org.apache.hudi.client.HoodieJavaWriteClient
+import org.apache.hudi.client.common.HoodieJavaEngineContext
 import org.apache.hudi.common.config.HoodieMetadataConfig
+import org.apache.hudi.common.engine.EngineType
+import org.apache.hudi.common.model.{HoodieRecord, HoodieTableType}
 import org.apache.hudi.common.table.HoodieTableMetaClient
 import org.apache.hudi.common.table.view.HoodieTableFileSystemView
-import org.apache.hudi.common.testutils.HoodieTestDataGenerator
+import org.apache.hudi.common.testutils.{HoodieTestDataGenerator, HoodieTestUtils}
+import org.apache.hudi.common.testutils.HoodieTestTable.makeNewCommitTime
 import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings
 import org.apache.hudi.common.util.PartitionPathEncodeUtils
+import org.apache.hudi.common.util.StringUtils.isNullOrEmpty
 import org.apache.hudi.config.HoodieWriteConfig
 import org.apache.hudi.keygen.ComplexKeyGenerator
 import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.{Config, TimestampType}
 import org.apache.hudi.testutils.HoodieClientTestBase
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.functions.{lit, struct}
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThanOrEqual, LessThan, Literal}
 import org.apache.spark.sql.execution.datasources.PartitionDirectory
+import org.apache.spark.sql.functions.{lit, struct}
 import org.apache.spark.sql.types.StringType
-import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.{DataFrameWriter, Row, SaveMode, SparkSession}
 import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.{BeforeEach, Test}
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
+import org.junit.jupiter.params.provider.{Arguments, CsvSource, MethodSource, ValueSource}
+import java.util.Properties
 
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
@@ -92,22 +99,25 @@ class TestHoodieFileIndex extends HoodieClientTestBase {
   }
 
   @ParameterizedTest
-  @ValueSource(strings = Array(
-    "org.apache.hudi.keygen.ComplexKeyGenerator",
-    "org.apache.hudi.keygen.SimpleKeyGenerator",
-    "org.apache.hudi.keygen.TimestampBasedKeyGenerator"))
+  @MethodSource(Array("keyGeneratorParameters"))
   def testPartitionSchemaForBuildInKeyGenerator(keyGenerator: String): Unit = {
     val records1 = dataGen.generateInsertsContainsAllPartitions("000", 100)
     val inputDF1 = spark.read.json(spark.sparkContext.parallelize(recordsToStrings(records1), 2))
-    inputDF1.write.format("hudi")
+    val writer: DataFrameWriter[Row] = inputDF1.write.format("hudi")
       .options(commonOpts)
       .option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
-      .option(DataSourceWriteOptions.KEYGENERATOR_CLASS_NAME.key, keyGenerator)
       .option(Config.TIMESTAMP_TYPE_FIELD_PROP, TimestampType.DATE_STRING.name())
       .option(Config.TIMESTAMP_INPUT_DATE_FORMAT_PROP, "yyyy/MM/dd")
      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyy-MM-dd")
       .mode(SaveMode.Overwrite)
-      .save(basePath)
+
+    if (isNullOrEmpty(keyGenerator)) {
+      writer.save(basePath)
+    } else {
+      writer.option(DataSourceWriteOptions.KEYGENERATOR_CLASS_NAME.key, keyGenerator)
+        .save(basePath)
+    }
+
     metaClient = HoodieTableMetaClient.reload(metaClient)
     val fileIndex = HoodieFileIndex(spark, metaClient, None, queryOpts)
     assertEquals("partition", fileIndex.partitionSchema.fields.map(_.name).mkString(","))
@@ -132,6 +142,44 @@ class TestHoodieFileIndex extends HoodieClientTestBase {
     assertEquals("partition", fileIndex.partitionSchema.fields.map(_.name).mkString(","))
   }
 
+  @Test
+  def testPartitionSchemaWithoutKeyGenerator(): Unit = {
+    val metaClient = HoodieTestUtils.init(
+      hadoopConf, basePath, HoodieTableType.COPY_ON_WRITE, HoodieTableMetaClient.withPropertyBuilder()
+        .fromMetaClient(this.metaClient)
+        .setRecordKeyFields("_row_key")
+        .setPartitionFields("partition_path")
+        .setTableName("hoodie_test").build())
+    val props = Map(
+      "hoodie.insert.shuffle.parallelism" -> "4",
+      "hoodie.upsert.shuffle.parallelism" -> "4",
+      DataSourceWriteOptions.RECORDKEY_FIELD.key -> "_row_key",
+      DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "partition_path",
+      DataSourceWriteOptions.PRECOMBINE_FIELD.key -> "timestamp",
+      HoodieWriteConfig.TBL_NAME.key -> "hoodie_test",
+      DataSourceWriteOptions.OPERATION.key -> DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL
+    )
+    val writeConfig = HoodieWriteConfig.newBuilder()
+      .withEngineType(EngineType.JAVA)
+      .withPath(basePath)
+      .withSchema(HoodieTestDataGenerator.TRIP_EXAMPLE_SCHEMA)
+      .withProps(props)
+      .build()
+    val context = new HoodieJavaEngineContext(new Configuration())
+    val writeClient = new HoodieJavaWriteClient(context, writeConfig)
+    val instantTime = makeNewCommitTime()
+
+    val records: java.util.List[HoodieRecord[Nothing]] =
+      dataGen.generateInsertsContainsAllPartitions(instantTime, 100)
+        .asInstanceOf[java.util.List[HoodieRecord[Nothing]]]
+    writeClient.startCommitWithTime(instantTime)
+    writeClient.insert(records, instantTime)
+    metaClient.reloadActiveTimeline()
+
+    val fileIndex = HoodieFileIndex(spark, metaClient, None, queryOpts)
+    assertEquals("partition_path", fileIndex.partitionSchema.fields.map(_.name).mkString(","))
+  }
+
   @ParameterizedTest
   @ValueSource(booleans = Array(true, false))
   def testPartitionPruneWithPartitionEncode(partitionEncode: Boolean): Unit = {
@@ -139,7 +187,7 @@
     props.setProperty(DataSourceWriteOptions.URL_ENCODE_PARTITIONING.key, String.valueOf(partitionEncode))
     initMetaClient(props)
     val partitions = Array("2021/03/08", "2021/03/09", "2021/03/10", "2021/03/11", "2021/03/12")
-    val newDataGen =  new HoodieTestDataGenerator(partitions)
+    val newDataGen = new HoodieTestDataGenerator(partitions)
     val records1 = newDataGen.generateInsertsContainsAllPartitions("000", 100)
     val inputDF1 = spark.read.json(spark.sparkContext.parallelize(recordsToStrings(records1), 2))
     inputDF1.write.format("hudi")
@@ -295,3 +343,14 @@ class TestHoodieFileIndex extends HoodieClientTestBase {
     partitionPaths.map(getFileCountInPartitionPath).sum
   }
 }
+
+object TestHoodieFileIndex {
+  def keyGeneratorParameters(): java.util.stream.Stream[Arguments] = {
+    java.util.stream.Stream.of(
+      Arguments.arguments(null.asInstanceOf[String]),
+      Arguments.arguments("org.apache.hudi.keygen.ComplexKeyGenerator"),
+      Arguments.arguments("org.apache.hudi.keygen.SimpleKeyGenerator"),
+      Arguments.arguments("org.apache.hudi.keygen.TimestampBasedKeyGenerator")
+    )
+  }
+}
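
Note on the SparkHoodieTableFileIndex hunk above: tables written outside the Spark datasource path (for example through the Java client, as exercised by the new test) may never persist hoodie.table.keygenerator.class, so tableConfig.getKeyGeneratorClassName can return null, and calling equalsIgnoreCase on that null receiver is what threw the NullPointerException. Putting the known-non-null class-name literal on the left makes the comparison null-safe, since String.equalsIgnoreCase returns false for a null argument. A minimal standalone sketch of the pattern (storedKeyGenerator is a hypothetical stand-in for the table config lookup, not part of the patch):

    object NullSafeComparisonSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for TableConfig#getKeyGeneratorClassName on a table whose
        // key generator class was never written to hoodie.properties.
        val storedKeyGenerator: String = null
        val expected = "org.apache.hudi.keygen.TimestampBasedKeyGenerator"

        // storedKeyGenerator.equalsIgnoreCase(expected) would throw an NPE here;
        // with the literal as the receiver the check is null-safe.
        println(expected.equalsIgnoreCase(storedKeyGenerator)) // prints: false
      }
    }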