diff --git a/hudi-client/src/main/java/org/apache/hudi/config/HoodieWriteConfig.java b/hudi-client/src/main/java/org/apache/hudi/config/HoodieWriteConfig.java
index 5ac87da2e..50af72569 100644
--- a/hudi-client/src/main/java/org/apache/hudi/config/HoodieWriteConfig.java
+++ b/hudi-client/src/main/java/org/apache/hudi/config/HoodieWriteConfig.java
@@ -57,6 +57,7 @@ public class HoodieWriteConfig extends DefaultHoodieConfig {
private static final String DEFAULT_PARALLELISM = "1500";
private static final String INSERT_PARALLELISM = "hoodie.insert.shuffle.parallelism";
private static final String BULKINSERT_PARALLELISM = "hoodie.bulkinsert.shuffle.parallelism";
+ private static final String BULKINSERT_USER_DEFINED_PARTITIONER_CLASS = "hoodie.bulkinsert.user.defined.partitioner.class";
private static final String UPSERT_PARALLELISM = "hoodie.upsert.shuffle.parallelism";
private static final String DELETE_PARALLELISM = "hoodie.delete.shuffle.parallelism";
private static final String DEFAULT_ROLLBACK_PARALLELISM = "100";
@@ -157,6 +158,10 @@ public class HoodieWriteConfig extends DefaultHoodieConfig {
return Integer.parseInt(props.getProperty(BULKINSERT_PARALLELISM));
}
+ public String getUserDefinedBulkInsertPartitionerClass() {
+ return props.getProperty(BULKINSERT_USER_DEFINED_PARTITIONER_CLASS);
+ }
+
public int getInsertShuffleParallelism() {
return Integer.parseInt(props.getProperty(INSERT_PARALLELISM));
}
@@ -603,6 +608,11 @@ public class HoodieWriteConfig extends DefaultHoodieConfig {
return this;
}
+ public Builder withUserDefinedBulkInsertPartitionerClass(String className) {
+ props.setProperty(BULKINSERT_USER_DEFINED_PARTITIONER_CLASS, className);
+ return this;
+ }
+
public Builder withParallelism(int insertShuffleParallelism, int upsertShuffleParallelism) {
props.setProperty(INSERT_PARALLELISM, String.valueOf(insertShuffleParallelism));
props.setProperty(UPSERT_PARALLELISM, String.valueOf(upsertShuffleParallelism));
diff --git a/hudi-spark/src/main/java/org/apache/hudi/DataSourceUtils.java b/hudi-spark/src/main/java/org/apache/hudi/DataSourceUtils.java
index 7a4caac05..34f2ef297 100644
--- a/hudi-spark/src/main/java/org/apache/hudi/DataSourceUtils.java
+++ b/hudi-spark/src/main/java/org/apache/hudi/DataSourceUtils.java
@@ -25,7 +25,9 @@ import org.apache.hudi.common.config.TypedProperties;
import org.apache.hudi.common.model.HoodieKey;
import org.apache.hudi.common.model.HoodieRecord;
import org.apache.hudi.common.model.HoodieRecordPayload;
+import org.apache.hudi.common.util.Option;
import org.apache.hudi.common.util.ReflectionUtils;
+import org.apache.hudi.common.util.StringUtils;
import org.apache.hudi.config.HoodieCompactionConfig;
import org.apache.hudi.config.HoodieIndexConfig;
import org.apache.hudi.config.HoodieWriteConfig;
@@ -36,6 +38,7 @@ import org.apache.hudi.hive.HiveSyncConfig;
import org.apache.hudi.hive.SlashEncodedDayPartitionValueExtractor;
import org.apache.hudi.index.HoodieIndex;
import org.apache.hudi.keygen.KeyGenerator;
+import org.apache.hudi.table.UserDefinedBulkInsertPartitioner;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
@@ -152,6 +155,24 @@ public class DataSourceUtils {
}
}
+ /**
+ * Create a UserDefinedBulkInsertPartitioner class via reflection,
+ *
+ * if the class name of UserDefinedBulkInsertPartitioner is configured through the HoodieWriteConfig.
+ * @see HoodieWriteConfig#getUserDefinedBulkInsertPartitionerClass()
+ */
+ private static Option createUserDefinedBulkInsertPartitioner(HoodieWriteConfig config)
+ throws HoodieException {
+ String bulkInsertPartitionerClass = config.getUserDefinedBulkInsertPartitionerClass();
+ try {
+ return StringUtils.isNullOrEmpty(bulkInsertPartitionerClass)
+ ? Option.empty() :
+ Option.of((UserDefinedBulkInsertPartitioner) ReflectionUtils.loadClass(bulkInsertPartitionerClass));
+ } catch (Throwable e) {
+ throw new HoodieException("Could not create UserDefinedBulkInsertPartitioner class " + bulkInsertPartitionerClass, e);
+ }
+ }
+
/**
* Create a payload class via reflection, passing in an ordering/precombine value.
*/
@@ -196,9 +217,11 @@ public class DataSourceUtils {
}
public static JavaRDD doWriteOperation(HoodieWriteClient client, JavaRDD hoodieRecords,
- String instantTime, String operation) {
+ String instantTime, String operation) throws HoodieException {
if (operation.equals(DataSourceWriteOptions.BULK_INSERT_OPERATION_OPT_VAL())) {
- return client.bulkInsert(hoodieRecords, instantTime);
+ Option userDefinedBulkInsertPartitioner =
+ createUserDefinedBulkInsertPartitioner(client.getConfig());
+ return client.bulkInsert(hoodieRecords, instantTime, userDefinedBulkInsertPartitioner);
} else if (operation.equals(DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL())) {
return client.insert(hoodieRecords, instantTime);
} else {
diff --git a/hudi-spark/src/test/java/DataSourceTestUtils.java b/hudi-spark/src/test/java/DataSourceTestUtils.java
index 036e6c221..0d801bb96 100644
--- a/hudi-spark/src/test/java/DataSourceTestUtils.java
+++ b/hudi-spark/src/test/java/DataSourceTestUtils.java
@@ -19,7 +19,10 @@
import org.apache.hudi.common.TestRawTripPayload;
import org.apache.hudi.common.model.HoodieKey;
import org.apache.hudi.common.model.HoodieRecord;
+import org.apache.hudi.common.model.HoodieRecordPayload;
import org.apache.hudi.common.util.Option;
+import org.apache.hudi.table.UserDefinedBulkInsertPartitioner;
+import org.apache.spark.api.java.JavaRDD;
import java.io.IOException;
import java.util.List;
@@ -52,4 +55,14 @@ public class DataSourceTestUtils {
.map(hr -> "{\"_row_key\":\"" + hr.getRecordKey() + "\",\"partition\":\"" + hr.getPartitionPath() + "\"}")
.collect(Collectors.toList());
}
+
+ public static class NoOpBulkInsertPartitioner
+ implements UserDefinedBulkInsertPartitioner {
+
+ @Override
+ public JavaRDD> repartitionRecords(JavaRDD> records, int outputSparkPartitions) {
+ return records;
+ }
+ }
+
}
diff --git a/hudi-spark/src/test/java/DataSourceUtilsTest.java b/hudi-spark/src/test/java/DataSourceUtilsTest.java
index 4bacb7c8d..c14b852f6 100644
--- a/hudi-spark/src/test/java/DataSourceUtilsTest.java
+++ b/hudi-spark/src/test/java/DataSourceUtilsTest.java
@@ -17,18 +17,58 @@
*/
import org.apache.hudi.DataSourceUtils;
+import org.apache.hudi.DataSourceWriteOptions;
+import org.apache.hudi.client.HoodieWriteClient;
+import org.apache.hudi.common.model.HoodieRecord;
+import org.apache.hudi.common.util.Option;
+import org.apache.hudi.config.HoodieWriteConfig;
+import org.apache.hudi.exception.HoodieException;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
+import org.apache.spark.api.java.JavaRDD;
+import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
import java.time.LocalDate;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+@ExtendWith(MockitoExtension.class)
public class DataSourceUtilsTest {
+ @Mock
+ private HoodieWriteClient hoodieWriteClient;
+
+ @Mock
+ private JavaRDD hoodieRecords;
+
+ @Captor
+ private ArgumentCaptor