diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
index 8586300e6..a79ac6f1d 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
@@ -86,8 +86,10 @@ class SparkHoodieTableFileIndex(spark: SparkSession,
val nameFieldMap = generateFieldMap(schema)
if (partitionColumns.isPresent) {
- if (tableConfig.getKeyGeneratorClassName.equalsIgnoreCase(classOf[TimestampBasedKeyGenerator].getName)
- || tableConfig.getKeyGeneratorClassName.equalsIgnoreCase(classOf[TimestampBasedAvroKeyGenerator].getName)) {
+ // Note that the key generator class name could be null, so equalsIgnoreCase is invoked on the non-null class-name constants below
+ val keyGeneratorClassName = tableConfig.getKeyGeneratorClassName
+ if (classOf[TimestampBasedKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)
+ || classOf[TimestampBasedAvroKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)) {
val partitionFields = partitionColumns.get().map(column => StructField(column, StringType))
StructType(partitionFields)
} else {
diff --git a/hudi-spark-datasource/hudi-spark/pom.xml b/hudi-spark-datasource/hudi-spark/pom.xml
index 534691cf0..606f6fa89 100644
--- a/hudi-spark-datasource/hudi-spark/pom.xml
+++ b/hudi-spark-datasource/hudi-spark/pom.xml
@@ -466,6 +466,12 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.hudi</groupId>
+ <artifactId>hudi-java-client</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>

<dependency>
<groupId>org.scalatest</groupId>
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
index d4f7cede7..4896ddf07 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
@@ -17,29 +17,36 @@
package org.apache.hudi
-import java.util.Properties
+import org.apache.hadoop.conf.Configuration
import org.apache.hudi.DataSourceWriteOptions._
+import org.apache.hudi.client.HoodieJavaWriteClient
+import org.apache.hudi.client.common.HoodieJavaEngineContext
import org.apache.hudi.common.config.HoodieMetadataConfig
+import org.apache.hudi.common.engine.EngineType
+import org.apache.hudi.common.model.{HoodieRecord, HoodieTableType}
import org.apache.hudi.common.table.HoodieTableMetaClient
import org.apache.hudi.common.table.view.HoodieTableFileSystemView
-import org.apache.hudi.common.testutils.HoodieTestDataGenerator
+import org.apache.hudi.common.testutils.{HoodieTestDataGenerator, HoodieTestUtils}
+import org.apache.hudi.common.testutils.HoodieTestTable.makeNewCommitTime
import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings
import org.apache.hudi.common.util.PartitionPathEncodeUtils
+import org.apache.hudi.common.util.StringUtils.isNullOrEmpty
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.keygen.ComplexKeyGenerator
import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.{Config, TimestampType}
import org.apache.hudi.testutils.HoodieClientTestBase
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.functions.{lit, struct}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThanOrEqual, LessThan, Literal}
import org.apache.spark.sql.execution.datasources.PartitionDirectory
+import org.apache.spark.sql.functions.{lit, struct}
import org.apache.spark.sql.types.StringType
-import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.{DataFrameWriter, Row, SaveMode, SparkSession}
import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.{BeforeEach, Test}
import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
+import org.junit.jupiter.params.provider.{Arguments, CsvSource, MethodSource, ValueSource}
+import java.util.Properties
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -92,22 +99,26 @@ class TestHoodieFileIndex extends HoodieClientTestBase {
}
@ParameterizedTest
- @ValueSource(strings = Array(
- "org.apache.hudi.keygen.ComplexKeyGenerator",
- "org.apache.hudi.keygen.SimpleKeyGenerator",
- "org.apache.hudi.keygen.TimestampBasedKeyGenerator"))
+ @MethodSource(Array("keyGeneratorParameters"))
def testPartitionSchemaForBuildInKeyGenerator(keyGenerator: String): Unit = {
val records1 = dataGen.generateInsertsContainsAllPartitions("000", 100)
val inputDF1 = spark.read.json(spark.sparkContext.parallelize(recordsToStrings(records1), 2))
- inputDF1.write.format("hudi")
+ val writer: DataFrameWriter[Row] = inputDF1.write.format("hudi")
.options(commonOpts)
.option(DataSourceWriteOptions.OPERATION.key, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
- .option(DataSourceWriteOptions.KEYGENERATOR_CLASS_NAME.key, keyGenerator)
.option(Config.TIMESTAMP_TYPE_FIELD_PROP, TimestampType.DATE_STRING.name())
.option(Config.TIMESTAMP_INPUT_DATE_FORMAT_PROP, "yyyy/MM/dd")
.option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyy-MM-dd")
.mode(SaveMode.Overwrite)
- .save(basePath)
+
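+ // With no key generator supplied, save without KEYGENERATOR_CLASS_NAME so the writer falls back to its default key generator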
+ if (isNullOrEmpty(keyGenerator)) {
+ writer.save(basePath)
+ } else {
+ writer.option(DataSourceWriteOptions.KEYGENERATOR_CLASS_NAME.key, keyGenerator)
+ .save(basePath)
+ }
+
metaClient = HoodieTableMetaClient.reload(metaClient)
val fileIndex = HoodieFileIndex(spark, metaClient, None, queryOpts)
assertEquals("partition", fileIndex.partitionSchema.fields.map(_.name).mkString(","))
@@ -132,6 +143,48 @@ class TestHoodieFileIndex extends HoodieClientTestBase {
assertEquals("partition", fileIndex.partitionSchema.fields.map(_.name).mkString(","))
}
+ @Test
+ def testPartitionSchemaWithoutKeyGenerator(): Unit = {
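+ // Initialize the table without any key generator class in hoodie.properties,
+ // mimicking a table created outside the Spark datasource write path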
+ val metaClient = HoodieTestUtils.init(
+ hadoopConf, basePath, HoodieTableType.COPY_ON_WRITE, HoodieTableMetaClient.withPropertyBuilder()
+ .fromMetaClient(this.metaClient)
+ .setRecordKeyFields("_row_key")
+ .setPartitionFields("partition_path")
+ .setTableName("hoodie_test").build())
+ val props = Map(
+ "hoodie.insert.shuffle.parallelism" -> "4",
+ "hoodie.upsert.shuffle.parallelism" -> "4",
+ DataSourceWriteOptions.RECORDKEY_FIELD.key -> "_row_key",
+ DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "partition_path",
+ DataSourceWriteOptions.PRECOMBINE_FIELD.key -> "timestamp",
+ HoodieWriteConfig.TBL_NAME.key -> "hoodie_test",
+ DataSourceWriteOptions.OPERATION.key -> DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL
+ )
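+ // Write through the Java client so that no key generator class gets persisted to the table config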
+ val writeConfig = HoodieWriteConfig.newBuilder()
+ .withEngineType(EngineType.JAVA)
+ .withPath(basePath)
+ .withSchema(HoodieTestDataGenerator.TRIP_EXAMPLE_SCHEMA)
+ .withProps(props)
+ .build()
+ val context = new HoodieJavaEngineContext(new Configuration())
+ val writeClient = new HoodieJavaWriteClient(context, writeConfig)
+ val instantTime = makeNewCommitTime()
+
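+ // The asInstanceOf cast only adapts the generated records to the Java write client's generic signature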
+ val records: java.util.List[HoodieRecord[Nothing]] =
+ dataGen.generateInsertsContainsAllPartitions(instantTime, 100)
+ .asInstanceOf[java.util.List[HoodieRecord[Nothing]]]
+ writeClient.startCommitWithTime(instantTime)
+ writeClient.insert(records, instantTime)
+ metaClient.reloadActiveTimeline()
+
+ val fileIndex = HoodieFileIndex(spark, metaClient, None, queryOpts)
+ assertEquals("partition_path", fileIndex.partitionSchema.fields.map(_.name).mkString(","))
+ }
+
@ParameterizedTest
@ValueSource(booleans = Array(true, false))
def testPartitionPruneWithPartitionEncode(partitionEncode: Boolean): Unit = {
@@ -139,7 +192,7 @@ class TestHoodieFileIndex extends HoodieClientTestBase {
props.setProperty(DataSourceWriteOptions.URL_ENCODE_PARTITIONING.key, String.valueOf(partitionEncode))
initMetaClient(props)
val partitions = Array("2021/03/08", "2021/03/09", "2021/03/10", "2021/03/11", "2021/03/12")
- val newDataGen = new HoodieTestDataGenerator(partitions)
+ val newDataGen = new HoodieTestDataGenerator(partitions)
val records1 = newDataGen.generateInsertsContainsAllPartitions("000", 100)
val inputDF1 = spark.read.json(spark.sparkContext.parallelize(recordsToStrings(records1), 2))
inputDF1.write.format("hudi")
@@ -295,3 +348,15 @@ class TestHoodieFileIndex extends HoodieClientTestBase {
partitionPaths.map(getFileCountInPartitionPath).sum
}
}
+
+object TestHoodieFileIndex {
+ def keyGeneratorParameters(): java.util.stream.Stream[Arguments] = {
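+ // The null entry covers writes that do not configure an explicit key generator class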
+ java.util.stream.Stream.of(
+ Arguments.arguments(null.asInstanceOf[String]),
+ Arguments.arguments("org.apache.hudi.keygen.ComplexKeyGenerator"),
+ Arguments.arguments("org.apache.hudi.keygen.SimpleKeyGenerator"),
+ Arguments.arguments("org.apache.hudi.keygen.TimestampBasedKeyGenerator")
+ )
+ }
+}