diff --git a/hudi-flink/src/main/java/org/apache/hudi/operator/partitioner/BucketAssignFunction.java b/hudi-flink/src/main/java/org/apache/hudi/operator/partitioner/BucketAssignFunction.java index 70289b7b5..8ce7f4043 100644 --- a/hudi-flink/src/main/java/org/apache/hudi/operator/partitioner/BucketAssignFunction.java +++ b/hudi-flink/src/main/java/org/apache/hudi/operator/partitioner/BucketAssignFunction.java @@ -55,6 +55,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; /** @@ -108,7 +109,7 @@ public class BucketAssignFunction> * All the partition paths when the task starts. It is used to help checking whether all the partitions * are loaded into the state. */ - private transient List<String> initialPartitionsToLoad; + private transient Set<String> initialPartitionsToLoad; /** * State to book-keep which partition is loaded into the index state {@code indexState}. @@ -136,15 +137,10 @@ public class BucketAssignFunction> new SerializableConfiguration(this.hadoopConf), new FlinkTaskContextSupplier(getRuntimeContext())); this.bucketAssigner = new BucketAssigner(context, writeConfig); - List<String> allPartitionPaths = FSUtils.getAllPartitionPaths(this.context, - this.conf.getString(FlinkOptions.PATH), false, false, false); - final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); - final int maxParallelism = getRuntimeContext().getMaxNumberOfParallelSubtasks(); - final int taskID = getRuntimeContext().getIndexOfThisSubtask(); - // reference: org.apache.flink.streaming.api.datastream.KeyedStream - this.initialPartitionsToLoad = allPartitionPaths.stream() - .filter(partition -> KeyGroupRangeAssignment.assignKeyToParallelOperator(partition, maxParallelism, parallelism) == taskID) - .collect(Collectors.toList()); + + // initialize and check the partitions load state + loadInitialPartitions(); + checkPartitionsLoaded(); } @Override @@ -163,9 +159,6 @@ public class BucketAssignFunction> 
 MapStateDescriptor<String, Integer> partitionLoadStateDesc = new MapStateDescriptor<>("partitionLoadState", Types.STRING, Types.INT); partitionLoadState = context.getKeyedStateStore().getMapState(partitionLoadStateDesc); - if (context.isRestored()) { - checkPartitionsLoaded(); - } } @SuppressWarnings("unchecked") @@ -178,7 +171,9 @@ public class BucketAssignFunction> final HoodieKey hoodieKey = record.getKey(); final BucketInfo bucketInfo; final HoodieRecordLocation location; - if (!allPartitionsLoaded && !partitionLoadState.contains(hoodieKey.getPartitionPath())) { + if (!allPartitionsLoaded + && initialPartitionsToLoad.contains(hoodieKey.getPartitionPath()) // this is an existing partition + && !partitionLoadState.contains(hoodieKey.getPartitionPath())) { // If the partition records are never loaded, load the records first. loadRecords(hoodieKey.getPartitionPath()); } @@ -244,6 +239,21 @@ public class BucketAssignFunction> partitionLoadState.put(partitionPath, 0); } + /** + * Loads the existing partitions for this task. + */ + private void loadInitialPartitions() { + List<String> allPartitionPaths = FSUtils.getAllPartitionPaths(this.context, + this.conf.getString(FlinkOptions.PATH), false, false, false); + final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); + final int maxParallelism = getRuntimeContext().getMaxNumberOfParallelSubtasks(); + final int taskID = getRuntimeContext().getIndexOfThisSubtask(); + // reference: org.apache.flink.streaming.api.datastream.KeyedStream + this.initialPartitionsToLoad = allPartitionPaths.stream() + .filter(partition -> KeyGroupRangeAssignment.assignKeyToParallelOperator(partition, maxParallelism, parallelism) == taskID) + .collect(Collectors.toSet()); + } + /** * Checks whether all the partitions of the table are loaded into the state, * set the flag {@code allPartitionsLoaded} to true if it is. 
@@ -271,6 +281,7 @@ public class BucketAssignFunction> public void clearIndexState() { this.allPartitionsLoaded = false; this.indexState.clear(); + loadInitialPartitions(); } @VisibleForTesting