diff --git a/hudi-client/src/main/java/org/apache/hudi/config/HoodieWriteConfig.java b/hudi-client/src/main/java/org/apache/hudi/config/HoodieWriteConfig.java index 5ac87da2e..50af72569 100644 --- a/hudi-client/src/main/java/org/apache/hudi/config/HoodieWriteConfig.java +++ b/hudi-client/src/main/java/org/apache/hudi/config/HoodieWriteConfig.java @@ -57,6 +57,7 @@ public class HoodieWriteConfig extends DefaultHoodieConfig { private static final String DEFAULT_PARALLELISM = "1500"; private static final String INSERT_PARALLELISM = "hoodie.insert.shuffle.parallelism"; private static final String BULKINSERT_PARALLELISM = "hoodie.bulkinsert.shuffle.parallelism"; + private static final String BULKINSERT_USER_DEFINED_PARTITIONER_CLASS = "hoodie.bulkinsert.user.defined.partitioner.class"; private static final String UPSERT_PARALLELISM = "hoodie.upsert.shuffle.parallelism"; private static final String DELETE_PARALLELISM = "hoodie.delete.shuffle.parallelism"; private static final String DEFAULT_ROLLBACK_PARALLELISM = "100"; @@ -157,6 +158,10 @@ public class HoodieWriteConfig extends DefaultHoodieConfig { return Integer.parseInt(props.getProperty(BULKINSERT_PARALLELISM)); } + public String getUserDefinedBulkInsertPartitionerClass() { + return props.getProperty(BULKINSERT_USER_DEFINED_PARTITIONER_CLASS); + } + public int getInsertShuffleParallelism() { return Integer.parseInt(props.getProperty(INSERT_PARALLELISM)); } @@ -603,6 +608,11 @@ public class HoodieWriteConfig extends DefaultHoodieConfig { return this; } + public Builder withUserDefinedBulkInsertPartitionerClass(String className) { + props.setProperty(BULKINSERT_USER_DEFINED_PARTITIONER_CLASS, className); + return this; + } + public Builder withParallelism(int insertShuffleParallelism, int upsertShuffleParallelism) { props.setProperty(INSERT_PARALLELISM, String.valueOf(insertShuffleParallelism)); props.setProperty(UPSERT_PARALLELISM, String.valueOf(upsertShuffleParallelism)); diff --git 
a/hudi-spark/src/main/java/org/apache/hudi/DataSourceUtils.java b/hudi-spark/src/main/java/org/apache/hudi/DataSourceUtils.java index 7a4caac05..34f2ef297 100644 --- a/hudi-spark/src/main/java/org/apache/hudi/DataSourceUtils.java +++ b/hudi-spark/src/main/java/org/apache/hudi/DataSourceUtils.java @@ -25,7 +25,9 @@ import org.apache.hudi.common.config.TypedProperties; import org.apache.hudi.common.model.HoodieKey; import org.apache.hudi.common.model.HoodieRecord; import org.apache.hudi.common.model.HoodieRecordPayload; +import org.apache.hudi.common.util.Option; import org.apache.hudi.common.util.ReflectionUtils; +import org.apache.hudi.common.util.StringUtils; import org.apache.hudi.config.HoodieCompactionConfig; import org.apache.hudi.config.HoodieIndexConfig; import org.apache.hudi.config.HoodieWriteConfig; @@ -36,6 +38,7 @@ import org.apache.hudi.hive.HiveSyncConfig; import org.apache.hudi.hive.SlashEncodedDayPartitionValueExtractor; import org.apache.hudi.index.HoodieIndex; import org.apache.hudi.keygen.KeyGenerator; +import org.apache.hudi.table.UserDefinedBulkInsertPartitioner; import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; @@ -152,6 +155,24 @@ public class DataSourceUtils { } } + /** + * Create a UserDefinedBulkInsertPartitioner class via reflection, + *
+ * if the class name of UserDefinedBulkInsertPartitioner is configured through the HoodieWriteConfig. + * @see HoodieWriteConfig#getUserDefinedBulkInsertPartitionerClass() + */ + private static Option createUserDefinedBulkInsertPartitioner(HoodieWriteConfig config) + throws HoodieException { + String bulkInsertPartitionerClass = config.getUserDefinedBulkInsertPartitionerClass(); + try { + return StringUtils.isNullOrEmpty(bulkInsertPartitionerClass) + ? Option.empty() : + Option.of((UserDefinedBulkInsertPartitioner) ReflectionUtils.loadClass(bulkInsertPartitionerClass)); + } catch (Throwable e) { + throw new HoodieException("Could not create UserDefinedBulkInsertPartitioner class " + bulkInsertPartitionerClass, e); + } + } + /** * Create a payload class via reflection, passing in an ordering/precombine value. */ @@ -196,9 +217,11 @@ public class DataSourceUtils { } public static JavaRDD doWriteOperation(HoodieWriteClient client, JavaRDD hoodieRecords, - String instantTime, String operation) { + String instantTime, String operation) throws HoodieException { if (operation.equals(DataSourceWriteOptions.BULK_INSERT_OPERATION_OPT_VAL())) { - return client.bulkInsert(hoodieRecords, instantTime); + Option userDefinedBulkInsertPartitioner = + createUserDefinedBulkInsertPartitioner(client.getConfig()); + return client.bulkInsert(hoodieRecords, instantTime, userDefinedBulkInsertPartitioner); } else if (operation.equals(DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL())) { return client.insert(hoodieRecords, instantTime); } else { diff --git a/hudi-spark/src/test/java/DataSourceTestUtils.java b/hudi-spark/src/test/java/DataSourceTestUtils.java index 036e6c221..0d801bb96 100644 --- a/hudi-spark/src/test/java/DataSourceTestUtils.java +++ b/hudi-spark/src/test/java/DataSourceTestUtils.java @@ -19,7 +19,10 @@ import org.apache.hudi.common.TestRawTripPayload; import org.apache.hudi.common.model.HoodieKey; import org.apache.hudi.common.model.HoodieRecord; +import 
org.apache.hudi.common.model.HoodieRecordPayload; import org.apache.hudi.common.util.Option; +import org.apache.hudi.table.UserDefinedBulkInsertPartitioner; +import org.apache.spark.api.java.JavaRDD; import java.io.IOException; import java.util.List; @@ -52,4 +55,14 @@ public class DataSourceTestUtils { .map(hr -> "{\"_row_key\":\"" + hr.getRecordKey() + "\",\"partition\":\"" + hr.getPartitionPath() + "\"}") .collect(Collectors.toList()); } + + public static class NoOpBulkInsertPartitioner<T extends HoodieRecordPayload> + implements UserDefinedBulkInsertPartitioner<T> { + + @Override + public JavaRDD<HoodieRecord<T>> repartitionRecords(JavaRDD<HoodieRecord<T>> records, int outputSparkPartitions) { + return records; + } + } + } diff --git a/hudi-spark/src/test/java/DataSourceUtilsTest.java b/hudi-spark/src/test/java/DataSourceUtilsTest.java index 4bacb7c8d..c14b852f6 100644 --- a/hudi-spark/src/test/java/DataSourceUtilsTest.java +++ b/hudi-spark/src/test/java/DataSourceUtilsTest.java @@ -17,18 +17,58 @@ */ import org.apache.hudi.DataSourceUtils; +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.client.HoodieWriteClient; +import org.apache.hudi.common.model.HoodieRecord; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.exception.HoodieException; import org.apache.avro.Schema; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; +import org.apache.spark.api.java.JavaRDD; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import java.time.LocalDate; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static 
org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +@ExtendWith(MockitoExtension.class) public class DataSourceUtilsTest { + @Mock + private HoodieWriteClient hoodieWriteClient; + + @Mock + private JavaRDD hoodieRecords; + + @Captor + private ArgumentCaptor