diff --git a/README.md b/README.md
index 6d3475755..be38999ae 100644
--- a/README.md
+++ b/README.md
@@ -90,21 +90,9 @@ mvn clean package -DskipTests -Dspark3
mvn clean package -DskipTests -Dspark3.1.x
```
-### Build without spark-avro module
+### What about the "spark-avro" module?
-The default hudi-jar bundles spark-avro module. To build without spark-avro module, build using `spark-shade-unbundle-avro` profile
-
-```
-# Checkout code and build
-git clone https://github.com/apache/hudi.git && cd hudi
-mvn clean package -DskipTests -Pspark-shade-unbundle-avro
-
-# Start command
-spark-2.4.4-bin-hadoop2.7/bin/spark-shell \
- --packages org.apache.spark:spark-avro_2.11:2.4.4 \
- --jars `ls packaging/hudi-spark-bundle/target/hudi-spark-bundle_2.11-*.*.*-SNAPSHOT.jar` \
- --conf 'spark.serializer=org.apache.spark.serializer.KryoSerializer'
-```
+Starting from version 0.11, Hudi no longer requires `spark-avro` to be specified using `--packages`.
## Running Tests
diff --git a/docker/demo/config/test-suite/templates/spark_command.txt.template b/docker/demo/config/test-suite/templates/spark_command.txt.template
index 563d98b7f..bf19631b0 100644
--- a/docker/demo/config/test-suite/templates/spark_command.txt.template
+++ b/docker/demo/config/test-suite/templates/spark_command.txt.template
@@ -15,7 +15,6 @@
# limitations under the License.
spark-submit \
---packages org.apache.spark:spark-avro_2.11:2.4.0 \
--conf spark.task.cpus=1 \
--conf spark.executor.cores=1 \
--conf spark.task.maxFailures=100 \
diff --git a/hudi-cli/pom.xml b/hudi-cli/pom.xml
index 29bdf85ab..5f499d6e7 100644
--- a/hudi-cli/pom.xml
+++ b/hudi-cli/pom.xml
@@ -225,10 +225,6 @@
org.apache.spark
spark-sql_${scala.binary.version}
-
- org.apache.spark
- spark-avro_${scala.binary.version}
-
org.springframework.shell
diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/HoodieTable.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/HoodieTable.java
index 62a4f089a..032961aec 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/HoodieTable.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/table/HoodieTable.java
@@ -18,6 +18,11 @@
package org.apache.hudi.table;
+import org.apache.avro.Schema;
+import org.apache.avro.specific.SpecificRecordBase;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
import org.apache.hudi.avro.HoodieAvroUtils;
import org.apache.hudi.avro.model.HoodieCleanMetadata;
import org.apache.hudi.avro.model.HoodieCleanerPlan;
@@ -72,17 +77,9 @@ import org.apache.hudi.table.marker.WriteMarkers;
import org.apache.hudi.table.marker.WriteMarkersFactory;
import org.apache.hudi.table.storage.HoodieLayoutFactory;
import org.apache.hudi.table.storage.HoodieStorageLayout;
-
-import org.apache.avro.Schema;
-import org.apache.avro.specific.SpecificRecordBase;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
-import javax.annotation.Nonnull;
-
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
@@ -261,19 +258,6 @@ public abstract class HoodieTable implem
*/
public abstract HoodieWriteMetadata insertOverwriteTable(HoodieEngineContext context, String instantTime, I records);
- /**
- * Updates Metadata Indexes (like Column Stats index)
- * TODO rebase onto metadata table (post RFC-27)
- *
- * @param context instance of {@link HoodieEngineContext}
- * @param instantTime instant of the carried operation triggering the update
- */
- public abstract void updateMetadataIndexes(
- @Nonnull HoodieEngineContext context,
- @Nonnull List stats,
- @Nonnull String instantTime
- ) throws Exception;
-
public HoodieWriteConfig getConfig() {
return config;
}
diff --git a/hudi-client/hudi-flink-client/src/main/java/org/apache/hudi/table/HoodieFlinkCopyOnWriteTable.java b/hudi-client/hudi-flink-client/src/main/java/org/apache/hudi/table/HoodieFlinkCopyOnWriteTable.java
index 14937d6fe..f00dbfdf4 100644
--- a/hudi-client/hudi-flink-client/src/main/java/org/apache/hudi/table/HoodieFlinkCopyOnWriteTable.java
+++ b/hudi-client/hudi-flink-client/src/main/java/org/apache/hudi/table/HoodieFlinkCopyOnWriteTable.java
@@ -34,7 +34,6 @@ import org.apache.hudi.common.model.HoodieBaseFile;
import org.apache.hudi.common.model.HoodieKey;
import org.apache.hudi.common.model.HoodieRecord;
import org.apache.hudi.common.model.HoodieRecordPayload;
-import org.apache.hudi.common.model.HoodieWriteStat;
import org.apache.hudi.common.table.HoodieTableMetaClient;
import org.apache.hudi.common.table.timeline.HoodieInstant;
import org.apache.hudi.common.table.timeline.HoodieTimeline;
@@ -63,12 +62,9 @@ import org.apache.hudi.table.action.commit.FlinkUpsertCommitActionExecutor;
import org.apache.hudi.table.action.commit.FlinkUpsertPreppedCommitActionExecutor;
import org.apache.hudi.table.action.rollback.BaseRollbackPlanActionExecutor;
import org.apache.hudi.table.action.rollback.CopyOnWriteRollbackActionExecutor;
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import javax.annotation.Nonnull;
-
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
@@ -246,11 +242,6 @@ public class HoodieFlinkCopyOnWriteTable
throw new HoodieNotSupportedException("DeletePartitions is not supported yet");
}
- @Override
- public void updateMetadataIndexes(@Nonnull HoodieEngineContext context, @Nonnull List stats, @Nonnull String instantTime) {
- throw new HoodieNotSupportedException("update statistics is not supported yet");
- }
-
@Override
public HoodieWriteMetadata> upsertPrepped(HoodieEngineContext context, String instantTime, List> preppedRecords) {
throw new HoodieNotSupportedException("This method should not be invoked");
diff --git a/hudi-client/hudi-java-client/src/main/java/org/apache/hudi/table/HoodieJavaCopyOnWriteTable.java b/hudi-client/hudi-java-client/src/main/java/org/apache/hudi/table/HoodieJavaCopyOnWriteTable.java
index 06c23049d..dd40a0afa 100644
--- a/hudi-client/hudi-java-client/src/main/java/org/apache/hudi/table/HoodieJavaCopyOnWriteTable.java
+++ b/hudi-client/hudi-java-client/src/main/java/org/apache/hudi/table/HoodieJavaCopyOnWriteTable.java
@@ -34,7 +34,6 @@ import org.apache.hudi.common.model.HoodieBaseFile;
import org.apache.hudi.common.model.HoodieKey;
import org.apache.hudi.common.model.HoodieRecord;
import org.apache.hudi.common.model.HoodieRecordPayload;
-import org.apache.hudi.common.model.HoodieWriteStat;
import org.apache.hudi.common.table.HoodieTableMetaClient;
import org.apache.hudi.common.table.timeline.HoodieInstant;
import org.apache.hudi.common.table.timeline.HoodieTimeline;
@@ -66,12 +65,9 @@ import org.apache.hudi.table.action.rollback.BaseRollbackPlanActionExecutor;
import org.apache.hudi.table.action.rollback.CopyOnWriteRollbackActionExecutor;
import org.apache.hudi.table.action.rollback.RestorePlanActionExecutor;
import org.apache.hudi.table.action.savepoint.SavepointActionExecutor;
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import javax.annotation.Nonnull;
-
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
@@ -173,11 +169,6 @@ public class HoodieJavaCopyOnWriteTable
context, config, this, instantTime, records).execute();
}
- @Override
- public void updateMetadataIndexes(@Nonnull HoodieEngineContext context, @Nonnull List stats, @Nonnull String instantTime) {
- throw new HoodieNotSupportedException("update statistics is not supported yet");
- }
-
@Override
public Option scheduleCompaction(HoodieEngineContext context,
String instantTime,
diff --git a/hudi-client/hudi-spark-client/pom.xml b/hudi-client/hudi-spark-client/pom.xml
index 0688fedac..fa9390d56 100644
--- a/hudi-client/hudi-spark-client/pom.xml
+++ b/hudi-client/hudi-spark-client/pom.xml
@@ -53,11 +53,6 @@
org.apache.spark
spark-sql_${scala.binary.version}
-
- org.apache.spark
- spark-avro_${scala.binary.version}
- provided
-
diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/SparkRDDWriteClient.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/SparkRDDWriteClient.java
index ac9259c51..6627eeecb 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/SparkRDDWriteClient.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/SparkRDDWriteClient.java
@@ -389,9 +389,6 @@ public class SparkRDDWriteClient extends
finalizeWrite(table, clusteringCommitTime, writeStats);
// Update table's metadata (table)
updateTableMetadata(table, metadata, clusteringInstant);
- // Update tables' metadata indexes
- // NOTE: This overlaps w/ metadata table (above) and will be reconciled in the future
- table.updateMetadataIndexes(context, writeStats, clusteringCommitTime);
LOG.info("Committing Clustering " + clusteringCommitTime + ". Finished with result " + metadata);
diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/table/HoodieSparkCopyOnWriteTable.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/table/HoodieSparkCopyOnWriteTable.java
index 8f5211212..3a46e3531 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/table/HoodieSparkCopyOnWriteTable.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/table/HoodieSparkCopyOnWriteTable.java
@@ -18,8 +18,6 @@
package org.apache.hudi.table;
-import org.apache.hudi.AvroConversionUtils;
-import org.apache.hudi.avro.HoodieAvroUtils;
import org.apache.hudi.avro.model.HoodieCleanMetadata;
import org.apache.hudi.avro.model.HoodieCleanerPlan;
import org.apache.hudi.avro.model.HoodieClusteringPlan;
@@ -38,18 +36,14 @@ import org.apache.hudi.common.model.HoodieBaseFile;
import org.apache.hudi.common.model.HoodieKey;
import org.apache.hudi.common.model.HoodieRecord;
import org.apache.hudi.common.model.HoodieRecordPayload;
-import org.apache.hudi.common.model.HoodieWriteStat;
import org.apache.hudi.common.table.HoodieTableMetaClient;
-import org.apache.hudi.common.table.TableSchemaResolver;
import org.apache.hudi.common.table.timeline.HoodieInstant;
import org.apache.hudi.common.table.timeline.HoodieTimeline;
import org.apache.hudi.common.util.Option;
-import org.apache.hudi.common.util.StringUtils;
import org.apache.hudi.config.HoodieWriteConfig;
import org.apache.hudi.exception.HoodieIOException;
import org.apache.hudi.exception.HoodieNotSupportedException;
import org.apache.hudi.exception.HoodieUpsertException;
-import org.apache.hudi.index.columnstats.ColumnStatsIndexHelper;
import org.apache.hudi.io.HoodieCreateHandle;
import org.apache.hudi.io.HoodieMergeHandle;
import org.apache.hudi.io.HoodieSortedMergeHandle;
@@ -78,21 +72,14 @@ import org.apache.hudi.table.action.rollback.BaseRollbackPlanActionExecutor;
import org.apache.hudi.table.action.rollback.CopyOnWriteRollbackActionExecutor;
import org.apache.hudi.table.action.rollback.RestorePlanActionExecutor;
import org.apache.hudi.table.action.savepoint.SavepointActionExecutor;
-
-import org.apache.avro.Schema;
-import org.apache.hadoop.fs.Path;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
-import javax.annotation.Nonnull;
-
import java.io.IOException;
-import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.stream.Collectors;
/**
* Implementation of a very heavily read-optimized Hoodie Table where, all data is stored in base files, with
@@ -172,63 +159,6 @@ public class HoodieSparkCopyOnWriteTable
return new SparkInsertOverwriteTableCommitActionExecutor(context, config, this, instantTime, records).execute();
}
- @Override
- public void updateMetadataIndexes(@Nonnull HoodieEngineContext context, @Nonnull List stats, @Nonnull String instantTime) throws Exception {
- updateColumnsStatsIndex(context, stats, instantTime);
- }
-
- private void updateColumnsStatsIndex(
- @Nonnull HoodieEngineContext context,
- @Nonnull List updatedFilesStats,
- @Nonnull String instantTime
- ) throws Exception {
- String sortColsList = config.getClusteringSortColumns();
- String basePath = metaClient.getBasePath();
- String indexPath = metaClient.getColumnStatsIndexPath();
-
- List touchedFiles =
- updatedFilesStats.stream()
- .map(s -> new Path(basePath, s.getPath()).toString())
- .collect(Collectors.toList());
-
- if (touchedFiles.isEmpty() || StringUtils.isNullOrEmpty(sortColsList) || StringUtils.isNullOrEmpty(indexPath)) {
- return;
- }
-
- LOG.info(String.format("Updating column-statistics index table (%s)", indexPath));
-
- List sortCols = Arrays.stream(sortColsList.split(","))
- .map(String::trim)
- .collect(Collectors.toList());
-
- HoodieSparkEngineContext sparkEngineContext = (HoodieSparkEngineContext)context;
-
- // Fetch table schema to appropriately construct col-stats index schema
- Schema tableWriteSchema =
- HoodieAvroUtils.createHoodieWriteSchema(
- new TableSchemaResolver(metaClient).getTableAvroSchemaWithoutMetadataFields()
- );
-
- List completedCommits =
- metaClient.getCommitsTimeline()
- .filterCompletedInstants()
- .getInstants()
- .map(HoodieInstant::getTimestamp)
- .collect(Collectors.toList());
-
- ColumnStatsIndexHelper.updateColumnStatsIndexFor(
- sparkEngineContext.getSqlContext().sparkSession(),
- AvroConversionUtils.convertAvroSchemaToStructType(tableWriteSchema),
- touchedFiles,
- sortCols,
- indexPath,
- instantTime,
- completedCommits
- );
-
- LOG.info(String.format("Successfully updated column-statistics index at instant (%s)", instantTime));
- }
-
@Override
public Option scheduleCompaction(HoodieEngineContext context, String instantTime, Option
-
- org.apache.spark
- spark-avro_${scala.binary.version}
-
diff --git a/hudi-integ-test/README.md b/hudi-integ-test/README.md
index 7ee4598ba..6c1bad138 100644
--- a/hudi-integ-test/README.md
+++ b/hudi-integ-test/README.md
@@ -126,7 +126,7 @@ NOTE : The properties-file should have all the necessary information required to
information on what properties need to be set, take a look at the test suite section under demo steps.
```
shell$ ./prepare_integration_suite.sh --spark-command
-spark-submit --packages com.databricks:spark-avro_2.11:4.0.0 --master prepare_integration_suite.sh --deploy-mode
+spark-submit --master prepare_integration_suite.sh --deploy-mode
--properties-file --class org.apache.hudi.integ.testsuite.HoodieTestSuiteJob target/hudi-integ-test-0.6
.0-SNAPSHOT.jar --source-class --source-ordering-field --input-base-path --target-base-path --target-table --props --storage-type --payload-class --workload-yaml-path --input-file-size --
```
@@ -198,7 +198,6 @@ Launch a Copy-on-Write job:
=========================
## Run the following command to start the test suite
spark-submit \
---packages org.apache.spark:spark-avro_2.11:2.4.0 \
--conf spark.task.cpus=1 \
--conf spark.executor.cores=1 \
--conf spark.task.maxFailures=100 \
@@ -245,7 +244,6 @@ Or a Merge-on-Read job:
=========================
## Run the following command to start the test suite
spark-submit \
---packages org.apache.spark:spark-avro_2.11:2.4.0 \
--conf spark.task.cpus=1 \
--conf spark.executor.cores=1 \
--conf spark.task.maxFailures=100 \
@@ -438,7 +436,6 @@ docker exec -it adhoc-2 /bin/bash
Sample COW command
```
spark-submit \
---packages org.apache.spark:spark-avro_2.11:2.4.0 \
--conf spark.task.cpus=1 \
--conf spark.executor.cores=1 \
--conf spark.task.maxFailures=100 \
diff --git a/hudi-integ-test/src/test/java/org/apache/hudi/integ/ITTestBase.java b/hudi-integ-test/src/test/java/org/apache/hudi/integ/ITTestBase.java
index 7ec2ba506..db87f5dce 100644
--- a/hudi-integ-test/src/test/java/org/apache/hudi/integ/ITTestBase.java
+++ b/hudi-integ-test/src/test/java/org/apache/hudi/integ/ITTestBase.java
@@ -115,7 +115,7 @@ public abstract class ITTestBase {
.append(" --master local[2] --driver-class-path ").append(HADOOP_CONF_DIR)
.append(
" --conf spark.sql.hive.convertMetastoreParquet=false --deploy-mode client --driver-memory 1G --executor-memory 1G --num-executors 1 ")
- .append(" --packages org.apache.spark:spark-avro_2.11:2.4.4 ").append(" -i ").append(commandFile).toString();
+ .append(" -i ").append(commandFile).toString();
}
static String getPrestoConsoleCommand(String commandFile) {
diff --git a/hudi-integ-test/src/test/java/org/apache/hudi/integ/command/ITTestHoodieSyncCommand.java b/hudi-integ-test/src/test/java/org/apache/hudi/integ/command/ITTestHoodieSyncCommand.java
index a6a4c3ec4..e6a4b6146 100644
--- a/hudi-integ-test/src/test/java/org/apache/hudi/integ/command/ITTestHoodieSyncCommand.java
+++ b/hudi-integ-test/src/test/java/org/apache/hudi/integ/command/ITTestHoodieSyncCommand.java
@@ -60,7 +60,7 @@ public class ITTestHoodieSyncCommand extends HoodieTestHiveBase {
}
private void syncHoodieTable(String hiveTableName, String op) throws Exception {
- StringBuilder cmdBuilder = new StringBuilder("spark-submit --packages org.apache.spark:spark-avro_2.11:2.4.4 ")
+ StringBuilder cmdBuilder = new StringBuilder("spark-submit")
.append(" --class org.apache.hudi.utilities.deltastreamer.HoodieDeltaStreamer ").append(HUDI_UTILITIES_BUNDLE)
.append(" --table-type COPY_ON_WRITE ")
.append(" --base-file-format ").append(HoodieFileFormat.PARQUET.toString())
diff --git a/hudi-spark-datasource/hudi-spark-common/pom.xml b/hudi-spark-datasource/hudi-spark-common/pom.xml
index 31ac48025..0960e7a94 100644
--- a/hudi-spark-datasource/hudi-spark-common/pom.xml
+++ b/hudi-spark-datasource/hudi-spark-common/pom.xml
@@ -211,13 +211,6 @@
test
-
-
- org.apache.spark
- spark-avro_${scala.binary.version}
- provided
-
-
org.apache.hudi
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
index 4fc86f729..1528e7f0b 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
@@ -70,7 +70,7 @@ abstract class HoodieBaseRelation(val sqlContext: SQLContext,
val metaClient: HoodieTableMetaClient,
val optParams: Map[String, String],
userSchema: Option[StructType])
- extends BaseRelation with PrunedFilteredScan with Logging {
+ extends BaseRelation with PrunedFilteredScan with Logging with SparkAdapterSupport {
type FileSplit <: HoodieFileSplit
@@ -120,7 +120,7 @@ abstract class HoodieBaseRelation(val sqlContext: SQLContext,
// If there is no commit in the table, we can't get the schema
// t/h [[TableSchemaResolver]], fallback to the provided [[userSchema]] instead.
userSchema match {
- case Some(s) => SchemaConverters.toAvroType(s)
+ case Some(s) => sparkAdapter.getAvroSchemaConverters.toAvroType(s, nullable = false, "record")
case _ => throw new IllegalArgumentException("User-provided schema is required in case the table is empty")
}
)
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/HoodieSparkAvroSchemaConverters.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/HoodieSparkAvroSchemaConverters.scala
new file mode 100644
index 000000000..65306ac44
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/HoodieSparkAvroSchemaConverters.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Schema
+import org.apache.spark.sql.avro.SchemaConverters.SchemaType
+import org.apache.spark.sql.types.DataType
+
+/**
+ * This interface is simply a facade abstracting away Spark's [[SchemaConverters]] implementation, allowing
+ * the rest of the code-base to not depend on it directly
+ */
+object HoodieSparkAvroSchemaConverters extends HoodieAvroSchemaConverters {
+
+ override def toSqlType(avroSchema: Schema): (DataType, Boolean) =
+ SchemaConverters.toSqlType(avroSchema) match {
+ case SchemaType(dataType, nullable) => (dataType, nullable)
+ }
+
+ override def toAvroType(catalystType: DataType, nullable: Boolean, recordName: String, nameSpace: String): Schema =
+ SchemaConverters.toAvroType(catalystType, nullable, recordName, nameSpace)
+
+}
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
new file mode 100644
index 000000000..a5b519b0e
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema.Type._
+import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.types.Decimal.minBytesForPrecision
+import org.apache.spark.sql.types._
+
+import scala.collection.JavaConverters._
+
+/**
+ * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice
+ * versa.
+ *
+ * NOTE: This code is borrowed from Spark 3.2.1
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (3.2.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+@DeveloperApi
+private[sql] object SchemaConverters {
+ private lazy val nullSchema = Schema.create(Schema.Type.NULL)
+
+ /**
+ * Internal wrapper for SQL data type and nullability.
+ *
+ * @since 2.4.0
+ */
+ case class SchemaType(dataType: DataType, nullable: Boolean)
+
+ /**
+ * Converts an Avro schema to a corresponding Spark SQL schema.
+ *
+ * @since 2.4.0
+ */
+ def toSqlType(avroSchema: Schema): SchemaType = {
+ toSqlTypeHelper(avroSchema, Set.empty)
+ }
+
+ private def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): SchemaType = {
+ avroSchema.getType match {
+ case INT => avroSchema.getLogicalType match {
+ case _: Date => SchemaType(DateType, nullable = false)
+ case _ => SchemaType(IntegerType, nullable = false)
+ }
+ case STRING => SchemaType(StringType, nullable = false)
+ case BOOLEAN => SchemaType(BooleanType, nullable = false)
+ case BYTES | FIXED => avroSchema.getLogicalType match {
+ // For FIXED type, if the precision requires more bytes than fixed size, the logical
+ // type will be null, which is handled by Avro library.
+ case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false)
+ case _ => SchemaType(BinaryType, nullable = false)
+ }
+
+ case DOUBLE => SchemaType(DoubleType, nullable = false)
+ case FLOAT => SchemaType(FloatType, nullable = false)
+ case LONG => avroSchema.getLogicalType match {
+ case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
+ case _ => SchemaType(LongType, nullable = false)
+ }
+
+ case ENUM => SchemaType(StringType, nullable = false)
+
+ case NULL => SchemaType(NullType, nullable = true)
+
+ case RECORD =>
+ if (existingRecordNames.contains(avroSchema.getFullName)) {
+ throw new IncompatibleSchemaException(
+ s"""
+ |Found recursive reference in Avro schema, which can not be processed by Spark:
+ |${avroSchema.toString(true)}
+ """.stripMargin)
+ }
+ val newRecordNames = existingRecordNames + avroSchema.getFullName
+ val fields = avroSchema.getFields.asScala.map { f =>
+ val schemaType = toSqlTypeHelper(f.schema(), newRecordNames)
+ StructField(f.name, schemaType.dataType, schemaType.nullable)
+ }
+
+ SchemaType(StructType(fields.toSeq), nullable = false)
+
+ case ARRAY =>
+ val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames)
+ SchemaType(
+ ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
+ nullable = false)
+
+ case MAP =>
+ val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames)
+ SchemaType(
+ MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
+ nullable = false)
+
+ case UNION =>
+ if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
+ // In case of a union with null, eliminate it and make a recursive call
+ val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
+ if (remainingUnionTypes.size == 1) {
+ toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true)
+ } else {
+ toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames)
+ .copy(nullable = true)
+ }
+ } else avroSchema.getTypes.asScala.map(_.getType).toSeq match {
+ case Seq(t1) =>
+ toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames)
+ case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
+ SchemaType(LongType, nullable = false)
+ case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
+ SchemaType(DoubleType, nullable = false)
+ case _ =>
+ // Convert complex unions to struct types where field names are member0, member1, etc.
+ // This is consistent with the behavior when converting between Avro and Parquet.
+ val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
+ case (s, i) =>
+ val schemaType = toSqlTypeHelper(s, existingRecordNames)
+ // All fields are nullable because only one of them is set at a time
+ StructField(s"member$i", schemaType.dataType, nullable = true)
+ }
+
+ SchemaType(StructType(fields.toSeq), nullable = false)
+ }
+
+ case other => throw new IncompatibleSchemaException(s"Unsupported type $other")
+ }
+ }
+
+ /**
+ * Converts a Spark SQL schema to a corresponding Avro schema.
+ *
+ * @since 2.4.0
+ */
+ def toAvroType(catalystType: DataType,
+ nullable: Boolean = false,
+ recordName: String = "topLevelRecord",
+ nameSpace: String = ""): Schema = {
+ val builder = SchemaBuilder.builder()
+
+ val schema = catalystType match {
+ case BooleanType => builder.booleanType()
+ case ByteType | ShortType | IntegerType => builder.intType()
+ case LongType => builder.longType()
+ case DateType =>
+ LogicalTypes.date().addToSchema(builder.intType())
+ case TimestampType =>
+ LogicalTypes.timestampMicros().addToSchema(builder.longType())
+
+ case FloatType => builder.floatType()
+ case DoubleType => builder.doubleType()
+ case StringType => builder.stringType()
+ case NullType => builder.nullType()
+ case d: DecimalType =>
+ val avroType = LogicalTypes.decimal(d.precision, d.scale)
+ val fixedSize = minBytesForPrecision(d.precision)
+ // Need to avoid naming conflict for the fixed fields
+ val name = nameSpace match {
+ case "" => s"$recordName.fixed"
+ case _ => s"$nameSpace.$recordName.fixed"
+ }
+ avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize))
+
+ case BinaryType => builder.bytesType()
+ case ArrayType(et, containsNull) =>
+ builder.array()
+ .items(toAvroType(et, containsNull, recordName, nameSpace))
+ case MapType(StringType, vt, valueContainsNull) =>
+ builder.map()
+ .values(toAvroType(vt, valueContainsNull, recordName, nameSpace))
+ case st: StructType =>
+ val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName
+ val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields()
+ st.foreach { f =>
+ val fieldAvroType =
+ toAvroType(f.dataType, f.nullable, f.name, childNameSpace)
+ fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
+ }
+ fieldsAssembler.endRecord()
+
+ // This should never happen.
+ case other => throw new IncompatibleSchemaException(s"Unexpected type $other.")
+ }
+ if (nullable && catalystType != NullType) {
+ Schema.createUnion(schema, nullSchema)
+ } else {
+ schema
+ }
+ }
+}
+
+private[avro] class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex)
+
+private[avro] class UnsupportedAvroTypeException(msg: String) extends Exception(msg)
diff --git a/hudi-spark-datasource/hudi-spark/pom.xml b/hudi-spark-datasource/hudi-spark/pom.xml
index 606f6fa89..a9c6ee78b 100644
--- a/hudi-spark-datasource/hudi-spark/pom.xml
+++ b/hudi-spark-datasource/hudi-spark/pom.xml
@@ -329,13 +329,6 @@
test
-
-
- org.apache.spark
- spark-avro_${scala.binary.version}
- provided
-
-
org.apache.hadoop
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala
similarity index 100%
rename from hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala
rename to hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
similarity index 100%
rename from hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
rename to hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/bucket/TestBucketIdentifier.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/index/bucket/TestBucketIdentifier.java
similarity index 100%
rename from hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/bucket/TestBucketIdentifier.java
rename to hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/index/bucket/TestBucketIdentifier.java
diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/testutils/KeyGeneratorTestUtilities.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/KeyGeneratorTestUtilities.java
similarity index 100%
rename from hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/testutils/KeyGeneratorTestUtilities.java
rename to hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/KeyGeneratorTestUtilities.java
diff --git a/hudi-spark-datasource/hudi-spark2/pom.xml b/hudi-spark-datasource/hudi-spark2/pom.xml
index 3fb6cf3dd..bf2b251aa 100644
--- a/hudi-spark-datasource/hudi-spark2/pom.xml
+++ b/hudi-spark-datasource/hudi-spark2/pom.xml
@@ -203,14 +203,6 @@
true
-
- org.apache.spark
- spark-avro_${scala.binary.version}
- ${spark2.version}
- provided
- true
-
-
io.netty
netty
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala
index 42ad66598..d685ce2ee 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.adapter
import org.apache.avro.Schema
import org.apache.hudi.Spark2RowSerDe
import org.apache.hudi.client.utils.SparkRowSerDe
-import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSerializer, HoodieSpark2AvroDeserializer, HoodieSparkAvroSerializer}
+import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSchemaConverters, HoodieAvroSerializer, HoodieSpark2_4AvroDeserializer, HoodieSpark2_4AvroSerializer, HoodieSparkAvroSchemaConverters}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
@@ -38,17 +38,19 @@ import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark2Catalyst
import scala.collection.mutable.ArrayBuffer
/**
- * The adapter for spark2.
+ * Implementation of [[SparkAdapter]] for Spark 2.4.x
*/
class Spark2Adapter extends SparkAdapter {
override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark2CatalystExpressionUtils
override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
- new HoodieSparkAvroSerializer(rootCatalystType, rootAvroType, nullable)
+ new HoodieSpark2_4AvroSerializer(rootCatalystType, rootAvroType, nullable)
override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
- new HoodieSpark2AvroDeserializer(rootAvroType, rootCatalystType)
+ new HoodieSpark2_4AvroDeserializer(rootAvroType, rootCatalystType)
+
+ override def getAvroSchemaConverters: HoodieAvroSchemaConverters = HoodieSparkAvroSchemaConverters
override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
new Spark2RowSerDe(encoder)
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/PatchedAvroDeserializer.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
similarity index 97%
rename from hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/PatchedAvroDeserializer.scala
rename to hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 8d9948c58..2e0946f1e 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/PatchedAvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -37,10 +37,16 @@ import scala.collection.mutable.ArrayBuffer
/**
* A deserializer to deserialize data in avro format to data in catalyst format.
*
- * NOTE: This is a version of {@code AvroDeserializer} impl from Spark 2.4.4 w/ the fix for SPARK-30267
+ * NOTE: This code is borrowed from Spark 2.4.4
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (2.4.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ *
+ * NOTE: This is a version of [[AvroDeserializer]] impl from Spark 2.4.4 w/ the fix for SPARK-30267
* applied on top of it
*/
-class PatchedAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
+class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
private lazy val decimalConversions = new DecimalConversion()
private val converter: Any => Any = rootCatalystType match {
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
new file mode 100644
index 000000000..31deb34be
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConverters._
+
+import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema
+import org.apache.avro.Schema.Type
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
+import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.util.Utf8
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
+import org.apache.spark.sql.types._
+
+/**
+ * A serializer to serialize data in catalyst format to data in avro format.
+ *
+ * NOTE: This code is borrowed from Spark 2.4.4
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (2.4.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) {
+
+ def serialize(catalystData: Any): Any = {
+ converter.apply(catalystData)
+ }
+
+ private val converter: Any => Any = {
+ val actualAvroType = resolveNullableType(rootAvroType, nullable)
+ val baseConverter = rootCatalystType match {
+ case st: StructType =>
+ newStructConverter(st, actualAvroType).asInstanceOf[Any => Any]
+ case _ =>
+ val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+ val converter = newConverter(rootCatalystType, actualAvroType)
+ (data: Any) =>
+ tmpRow.update(0, data)
+ converter.apply(tmpRow, 0)
+ }
+ if (nullable) {
+ (data: Any) =>
+ if (data == null) {
+ null
+ } else {
+ baseConverter.apply(data)
+ }
+ } else {
+ baseConverter
+ }
+ }
+
+ private type Converter = (SpecializedGetters, Int) => Any
+
+ private lazy val decimalConversions = new DecimalConversion()
+
+ private def newConverter(catalystType: DataType, avroType: Schema): Converter = {
+ (catalystType, avroType.getType) match {
+ case (NullType, NULL) =>
+ (getter, ordinal) => null
+ case (BooleanType, BOOLEAN) =>
+ (getter, ordinal) => getter.getBoolean(ordinal)
+ case (ByteType, INT) =>
+ (getter, ordinal) => getter.getByte(ordinal).toInt
+ case (ShortType, INT) =>
+ (getter, ordinal) => getter.getShort(ordinal).toInt
+ case (IntegerType, INT) =>
+ (getter, ordinal) => getter.getInt(ordinal)
+ case (LongType, LONG) =>
+ (getter, ordinal) => getter.getLong(ordinal)
+ case (FloatType, FLOAT) =>
+ (getter, ordinal) => getter.getFloat(ordinal)
+ case (DoubleType, DOUBLE) =>
+ (getter, ordinal) => getter.getDouble(ordinal)
+ case (d: DecimalType, FIXED)
+ if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
+ (getter, ordinal) =>
+ val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+ decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
+ LogicalTypes.decimal(d.precision, d.scale))
+
+ case (d: DecimalType, BYTES)
+ if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
+ (getter, ordinal) =>
+ val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+ decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
+ LogicalTypes.decimal(d.precision, d.scale))
+
+ case (StringType, ENUM) =>
+ val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
+ (getter, ordinal) =>
+ val data = getter.getUTF8String(ordinal).toString
+ if (!enumSymbols.contains(data)) {
+ throw new IncompatibleSchemaException(
+ "Cannot write \"" + data + "\" since it's not defined in enum \"" +
+ enumSymbols.mkString("\", \"") + "\"")
+ }
+ new EnumSymbol(avroType, data)
+
+ case (StringType, STRING) =>
+ (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
+
+ case (BinaryType, FIXED) =>
+ val size = avroType.getFixedSize()
+ (getter, ordinal) =>
+ val data: Array[Byte] = getter.getBinary(ordinal)
+ if (data.length != size) {
+ throw new IncompatibleSchemaException(
+ s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " +
+ "binary data into FIXED Type with size of " +
+ s"$size ${if (size > 1) "bytes" else "byte"}")
+ }
+ new Fixed(avroType, data)
+
+ case (BinaryType, BYTES) =>
+ (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+
+ case (DateType, INT) =>
+ (getter, ordinal) => getter.getInt(ordinal)
+
+ case (TimestampType, LONG) => avroType.getLogicalType match {
+ case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000
+ case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal)
+ // For backward compatibility, if the Avro type is Long and it is not logical type,
+ // output the timestamp value as with millisecond precision.
+ case null => (getter, ordinal) => getter.getLong(ordinal) / 1000
+ case other => throw new IncompatibleSchemaException(
+ s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}")
+ }
+
+ case (ArrayType(et, containsNull), ARRAY) =>
+ val elementConverter = newConverter(
+ et, resolveNullableType(avroType.getElementType, containsNull))
+ (getter, ordinal) => {
+ val arrayData = getter.getArray(ordinal)
+ val len = arrayData.numElements()
+ val result = new Array[Any](len)
+ var i = 0
+ while (i < len) {
+ if (containsNull && arrayData.isNullAt(i)) {
+ result(i) = null
+ } else {
+ result(i) = elementConverter(arrayData, i)
+ }
+ i += 1
+ }
+ // avro writer is expecting a Java Collection, so we convert it into
+ // `ArrayList` backed by the specified array without data copying.
+ java.util.Arrays.asList(result: _*)
+ }
+
+ case (st: StructType, RECORD) =>
+ val structConverter = newStructConverter(st, avroType)
+ val numFields = st.length
+ (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+
+ case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
+ val valueConverter = newConverter(
+ vt, resolveNullableType(avroType.getValueType, valueContainsNull))
+ (getter, ordinal) =>
+ val mapData = getter.getMap(ordinal)
+ val len = mapData.numElements()
+ val result = new java.util.HashMap[String, Any](len)
+ val keyArray = mapData.keyArray()
+ val valueArray = mapData.valueArray()
+ var i = 0
+ while (i < len) {
+ val key = keyArray.getUTF8String(i).toString
+ if (valueContainsNull && valueArray.isNullAt(i)) {
+ result.put(key, null)
+ } else {
+ result.put(key, valueConverter(valueArray, i))
+ }
+ i += 1
+ }
+ result
+
+ case other =>
+ throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " +
+ s"Avro type $avroType.")
+ }
+ }
+
+ private def newStructConverter(
+ catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = {
+ if (avroStruct.getType != RECORD || avroStruct.getFields.size() != catalystStruct.length) {
+ throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " +
+ s"Avro type $avroStruct.")
+ }
+ val fieldConverters = catalystStruct.zip(avroStruct.getFields.asScala).map {
+ case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable))
+ }
+ val numFields = catalystStruct.length
+ (row: InternalRow) =>
+ val result = new Record(avroStruct)
+ var i = 0
+ while (i < numFields) {
+ if (row.isNullAt(i)) {
+ result.put(i, null)
+ } else {
+ result.put(i, fieldConverters(i).apply(row, i))
+ }
+ i += 1
+ }
+ result
+ }
+
+ private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
+ if (nullable && avroType.getType != NULL) {
+ // avro uses union to represent nullable type.
+ val fields = avroType.getTypes.asScala
+ assert(fields.length == 2)
+ val actualType = fields.filter(_.getType != Type.NULL)
+ assert(actualType.length == 1)
+ actualType.head
+ } else {
+ avroType
+ }
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2_4AvroDeserializer.scala
similarity index 73%
rename from hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2AvroDeserializer.scala
rename to hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2_4AvroDeserializer.scala
index 2b55c6695..1c9bc88a3 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2_4AvroDeserializer.scala
@@ -20,14 +20,10 @@ package org.apache.spark.sql.avro
import org.apache.avro.Schema
import org.apache.spark.sql.types.DataType
-/**
- * This is Spark 2 implementation for the [[HoodieAvroDeserializer]] leveraging [[PatchedAvroDeserializer]],
- * which is just copied over version of [[AvroDeserializer]] from Spark 2.4.4 w/ SPARK-30267 being back-ported to it
- */
-class HoodieSpark2AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType)
+class HoodieSpark2_4AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType)
extends HoodieAvroDeserializer {
- private val avroDeserializer = new PatchedAvroDeserializer(rootAvroType, rootCatalystType)
+ private val avroDeserializer = new AvroDeserializer(rootAvroType, rootCatalystType)
// As of Spark 3.1, this will return data wrapped with Option, so we make sure these interfaces
// are aligned across Spark versions
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/HoodieSparkAvroSerializer.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2_4AvroSerializer.scala
similarity index 91%
rename from hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/HoodieSparkAvroSerializer.scala
rename to hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2_4AvroSerializer.scala
index 4a3a7c452..48009ca16 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/HoodieSparkAvroSerializer.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/HoodieSpark2_4AvroSerializer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.avro
import org.apache.avro.Schema
import org.apache.spark.sql.types.DataType
-class HoodieSparkAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean)
+class HoodieSpark2_4AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean)
extends HoodieAvroSerializer {
val avroSerializer = new AvroSerializer(rootCatalystType, rootAvroType, nullable)
diff --git a/hudi-spark-datasource/hudi-spark3-common/pom.xml b/hudi-spark-datasource/hudi-spark3-common/pom.xml
index 30e7bda2e..87366ef0b 100644
--- a/hudi-spark-datasource/hudi-spark3-common/pom.xml
+++ b/hudi-spark-datasource/hudi-spark3-common/pom.xml
@@ -172,14 +172,6 @@
true
-
- org.apache.spark
- spark-avro_2.12
- ${spark3.version}
- provided
- true
-
-
com.fasterxml.jackson.core
jackson-databind
diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala
index 33aae23df..681484034 100644
--- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala
+++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.adapter
-import org.apache.avro.Schema
import org.apache.hudi.Spark3RowSerDe
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.hudi.spark3.internal.ReflectUtil
-import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSerializer, HoodieSpark3AvroDeserializer, HoodieSparkAvroSerializer}
+import org.apache.spark.sql.avro.{HoodieAvroSchemaConverters, HoodieSparkAvroSchemaConverters}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
@@ -35,7 +34,6 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.hudi.SparkAdapter
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SparkSession}
/**
@@ -43,16 +41,12 @@ import org.apache.spark.sql.{Row, SparkSession}
*/
abstract class BaseSpark3Adapter extends SparkAdapter {
- override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
- new HoodieSparkAvroSerializer(rootCatalystType, rootAvroType, nullable)
-
- override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
- new HoodieSpark3AvroDeserializer(rootAvroType, rootCatalystType)
-
override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
new Spark3RowSerDe(encoder)
}
+ override def getAvroSchemaConverters: HoodieAvroSchemaConverters = HoodieSparkAvroSchemaConverters
+
override def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier = {
aliasId match {
case AliasIdentifier(name, Seq(database)) =>
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala
index 106939cbb..a9196173f 100644
--- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala
@@ -18,7 +18,10 @@
package org.apache.spark.sql.adapter
+import org.apache.avro.Schema
+import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSchemaConverters, HoodieAvroSerializer, HoodieSpark3_1AvroDeserializer, HoodieSpark3_1AvroSerializer, HoodieSparkAvroSchemaConverters}
import org.apache.spark.sql.hudi.SparkAdapter
+import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark3_1CatalystExpressionUtils}
/**
@@ -28,4 +31,10 @@ class Spark3_1Adapter extends BaseSpark3Adapter {
override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark3_1CatalystExpressionUtils
+ override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
+ new HoodieSpark3_1AvroSerializer(rootCatalystType, rootAvroType, nullable)
+
+ override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
+ new HoodieSpark3_1AvroDeserializer(rootAvroType, rootCatalystType)
+
}
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
new file mode 100644
index 000000000..717df0f40
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic._
+import org.apache.avro.util.Utf8
+import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.spark.sql.avro.AvroDeserializer.{createDateRebaseFuncInRead, createTimestampRebaseFuncInRead}
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import java.math.BigDecimal
+import java.nio.ByteBuffer
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * A deserializer to deserialize data in avro format to data in catalyst format.
+ *
+ * NOTE: This code is borrowed from Spark 3.1.2
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (3.2.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+private[sql] class AvroDeserializer(rootAvroType: Schema,
+ rootCatalystType: DataType,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ filters: StructFilters) {
+
+ def this(rootAvroType: Schema, rootCatalystType: DataType) = {
+ this(
+ rootAvroType,
+ rootCatalystType,
+ LegacyBehaviorPolicy.withName(SQLConf.get.getConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ)),
+ new NoopFilters)
+ }
+
+ private lazy val decimalConversions = new DecimalConversion()
+
+ private val dateRebaseFunc = createDateRebaseFuncInRead(
+ datetimeRebaseMode, "Avro")
+
+ private val timestampRebaseFunc = createTimestampRebaseFuncInRead(
+ datetimeRebaseMode, "Avro")
+
+ private val converter: Any => Option[Any] = rootCatalystType match {
+ // A shortcut for empty schema.
+ case st: StructType if st.isEmpty =>
+ (data: Any) => Some(InternalRow.empty)
+
+ case st: StructType =>
+ val resultRow = new SpecificInternalRow(st.map(_.dataType))
+ val fieldUpdater = new RowUpdater(resultRow)
+ val applyFilters = filters.skipRow(resultRow, _)
+ val writer = getRecordWriter(rootAvroType, st, Nil, applyFilters)
+ (data: Any) => {
+ val record = data.asInstanceOf[GenericRecord]
+ val skipRow = writer(fieldUpdater, record)
+ if (skipRow) None else Some(resultRow)
+ }
+
+ case _ =>
+ val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+ val fieldUpdater = new RowUpdater(tmpRow)
+ val writer = newWriter(rootAvroType, rootCatalystType, Nil)
+ (data: Any) => {
+ writer(fieldUpdater, 0, data)
+ Some(tmpRow.get(0, rootCatalystType))
+ }
+ }
+
+ def deserialize(data: Any): Option[Any] = converter(data)
+
+ /**
+ * Creates a writer to write avro values to Catalyst values at the given ordinal with the given
+ * updater.
+ */
+ private def newWriter(avroType: Schema,
+ catalystType: DataType,
+ path: List[String]): (CatalystDataUpdater, Int, Any) => Unit =
+ (avroType.getType, catalystType) match {
+ case (NULL, NullType) => (updater, ordinal, _) =>
+ updater.setNullAt(ordinal)
+
+ // TODO: we can avoid boxing if future version of avro provide primitive accessors.
+ case (BOOLEAN, BooleanType) => (updater, ordinal, value) =>
+ updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
+
+ case (INT, IntegerType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, value.asInstanceOf[Int])
+
+ case (INT, DateType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int]))
+
+ case (LONG, LongType) => (updater, ordinal, value) =>
+ updater.setLong(ordinal, value.asInstanceOf[Long])
+
+ case (LONG, TimestampType) => avroType.getLogicalType match {
+ // For backward compatibility, if the Avro type is Long and it is not logical type
+ // (the `null` case), the value is processed as timestamp type with millisecond precision.
+ case null | _: TimestampMillis => (updater, ordinal, value) =>
+ val millis = value.asInstanceOf[Long]
+ val micros = DateTimeUtils.millisToMicros(millis)
+ updater.setLong(ordinal, timestampRebaseFunc(micros))
+ case _: TimestampMicros => (updater, ordinal, value) =>
+ val micros = value.asInstanceOf[Long]
+ updater.setLong(ordinal, timestampRebaseFunc(micros))
+ case other => throw new IncompatibleSchemaException(
+ s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.")
+ }
+
+ // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date.
+ // For backward compatibility, we still keep this conversion.
+ case (LONG, DateType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt)
+
+ case (FLOAT, FloatType) => (updater, ordinal, value) =>
+ updater.setFloat(ordinal, value.asInstanceOf[Float])
+
+ case (DOUBLE, DoubleType) => (updater, ordinal, value) =>
+ updater.setDouble(ordinal, value.asInstanceOf[Double])
+
+ case (STRING, StringType) => (updater, ordinal, value) =>
+ val str = value match {
+ case s: String => UTF8String.fromString(s)
+ case s: Utf8 =>
+ val bytes = new Array[Byte](s.getByteLength)
+ System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength)
+ UTF8String.fromBytes(bytes)
+ }
+ updater.set(ordinal, str)
+
+ case (ENUM, StringType) => (updater, ordinal, value) =>
+ updater.set(ordinal, UTF8String.fromString(value.toString))
+
+ case (FIXED, BinaryType) => (updater, ordinal, value) =>
+ updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone())
+
+ case (BYTES, BinaryType) => (updater, ordinal, value) =>
+ val bytes = value match {
+ case b: ByteBuffer =>
+ val bytes = new Array[Byte](b.remaining)
+ b.get(bytes)
+ bytes
+ case b: Array[Byte] => b
+ case other => throw new RuntimeException(s"$other is not a valid avro binary.")
+ }
+ updater.set(ordinal, bytes)
+
+ case (FIXED, _: DecimalType) => (updater, ordinal, value) =>
+ val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal]
+ val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d)
+ val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale)
+ updater.setDecimal(ordinal, decimal)
+
+ case (BYTES, _: DecimalType) => (updater, ordinal, value) =>
+ val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal]
+ val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d)
+ val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale)
+ updater.setDecimal(ordinal, decimal)
+
+ case (RECORD, st: StructType) =>
+ // Avro datasource doesn't accept filters with nested attributes. See SPARK-32328.
+ // We can always return `false` from `applyFilters` for nested records.
+ val writeRecord = getRecordWriter(avroType, st, path, applyFilters = _ => false)
+ (updater, ordinal, value) =>
+ val row = new SpecificInternalRow(st)
+ writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord])
+ updater.set(ordinal, row)
+
+ case (ARRAY, ArrayType(elementType, containsNull)) =>
+ val elementWriter = newWriter(avroType.getElementType, elementType, path)
+ (updater, ordinal, value) =>
+ val collection = value.asInstanceOf[java.util.Collection[Any]]
+ val result = createArrayData(elementType, collection.size())
+ val elementUpdater = new ArrayDataUpdater(result)
+
+ var i = 0
+ val iter = collection.iterator()
+ while (iter.hasNext) {
+ val element = iter.next()
+ if (element == null) {
+ if (!containsNull) {
+ throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " +
+ "allowed to be null")
+ } else {
+ elementUpdater.setNullAt(i)
+ }
+ } else {
+ elementWriter(elementUpdater, i, element)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+
+ case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType =>
+ val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path)
+ val valueWriter = newWriter(avroType.getValueType, valueType, path)
+ (updater, ordinal, value) =>
+ val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]]
+ val keyArray = createArrayData(keyType, map.size())
+ val keyUpdater = new ArrayDataUpdater(keyArray)
+ val valueArray = createArrayData(valueType, map.size())
+ val valueUpdater = new ArrayDataUpdater(valueArray)
+ val iter = map.entrySet().iterator()
+ var i = 0
+ while (iter.hasNext) {
+ val entry = iter.next()
+ assert(entry.getKey != null)
+ keyWriter(keyUpdater, i, entry.getKey)
+ if (entry.getValue == null) {
+ if (!valueContainsNull) {
+ throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " +
+ "allowed to be null")
+ } else {
+ valueUpdater.setNullAt(i)
+ }
+ } else {
+ valueWriter(valueUpdater, i, entry.getValue)
+ }
+ i += 1
+ }
+
+ // The Avro map will never have null or duplicated map keys, it's safe to create a
+ // ArrayBasedMapData directly here.
+ updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
+
+ case (UNION, _) =>
+ val allTypes = avroType.getTypes.asScala
+ val nonNullTypes = allTypes.filter(_.getType != NULL)
+ val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava)
+ if (nonNullTypes.nonEmpty) {
+ if (nonNullTypes.length == 1) {
+ newWriter(nonNullTypes.head, catalystType, path)
+ } else {
+ nonNullTypes.map(_.getType).toSeq match {
+ case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType =>
+ (updater, ordinal, value) =>
+ value match {
+ case null => updater.setNullAt(ordinal)
+ case l: java.lang.Long => updater.setLong(ordinal, l)
+ case i: java.lang.Integer => updater.setLong(ordinal, i.longValue())
+ }
+
+ case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType =>
+ (updater, ordinal, value) =>
+ value match {
+ case null => updater.setNullAt(ordinal)
+ case d: java.lang.Double => updater.setDouble(ordinal, d)
+ case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue())
+ }
+
+ case _ =>
+ catalystType match {
+ case st: StructType if st.length == nonNullTypes.size =>
+ val fieldWriters = nonNullTypes.zip(st.fields).map {
+ case (schema, field) => newWriter(schema, field.dataType, path :+ field.name)
+ }.toArray
+ (updater, ordinal, value) => {
+ val row = new SpecificInternalRow(st)
+ val fieldUpdater = new RowUpdater(row)
+ val i = GenericData.get().resolveUnion(nonNullAvroType, value)
+ fieldWriters(i)(fieldUpdater, i, value)
+ updater.set(ordinal, row)
+ }
+
+ case _ =>
+ throw new IncompatibleSchemaException(
+ s"Cannot convert Avro to catalyst because schema at path " +
+ s"${path.mkString(".")} is not compatible " +
+ s"(avroType = $avroType, sqlType = $catalystType).\n" +
+ s"Source Avro schema: $rootAvroType.\n" +
+ s"Target Catalyst type: $rootCatalystType")
+ }
+ }
+ }
+ } else {
+ (updater, ordinal, value) => updater.setNullAt(ordinal)
+ }
+
+ case _ =>
+ throw new IncompatibleSchemaException(
+ s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " +
+ s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" +
+ s"Source Avro schema: $rootAvroType.\n" +
+ s"Target Catalyst type: $rootCatalystType")
+ }
+
+  // TODO: move the following method in Decimal object on creating Decimal from BigDecimal?
+  // Converts a java.math.BigDecimal (produced by Avro's DecimalConversion) into a catalyst
+  // Decimal, choosing the compact Long-backed representation when the precision allows it.
+  private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
+    if (precision <= Decimal.MAX_LONG_DIGITS) {
+      // Constructs a `Decimal` with an unscaled `Long` value if possible.
+      Decimal(decimal.unscaledValue().longValue(), precision, scale)
+    } else {
+      // Otherwise, resorts to an unscaled `BigInteger` instead.
+      Decimal(decimal, precision, scale)
+    }
+  }
+
+  /**
+   * Builds a writer that copies the fields of an Avro [[GenericRecord]] into a catalyst row
+   * shaped by `sqlType`, returning `true` when the row should be skipped by pushed-down filters.
+   *
+   * Catalyst fields are resolved against the Avro schema by name (case sensitivity per
+   * SQLConf.resolver, via [[AvroUtils.AvroSchemaHelper]]). A catalyst field with no Avro
+   * counterpart is only an error when it is non-nullable; otherwise it is left untouched.
+   */
+  private def getRecordWriter(avroType: Schema,
+                              sqlType: StructType,
+                              path: List[String],
+                              applyFilters: Int => Boolean): (CatalystDataUpdater, GenericRecord) => Boolean = {
+    val validFieldIndexes = ArrayBuffer.empty[Int]
+    val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit]
+
+    val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(avroType)
+    val length = sqlType.length
+    var i = 0
+    while (i < length) {
+      val sqlField = sqlType.fields(i)
+      avroSchemaHelper.getFieldByName(sqlField.name) match {
+        case Some(avroField) =>
+          validFieldIndexes += avroField.pos()
+
+          val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name)
+          // Capture the target ordinal in a stable local so the closure below does not observe
+          // later mutations of the loop counter `i`.
+          val ordinal = i
+          val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
+            if (value == null) {
+              fieldUpdater.setNullAt(ordinal)
+            } else {
+              baseWriter(fieldUpdater, ordinal, value)
+            }
+          }
+          fieldWriters += fieldWriter
+        case None if !sqlField.nullable =>
+          val fieldStr = s"${path.mkString(".")}.${sqlField.name}"
+          throw new IncompatibleSchemaException(
+            s"""
+               |Cannot find non-nullable field $fieldStr in Avro schema.
+               |Source Avro schema: $rootAvroType.
+               |Target Catalyst type: $rootCatalystType.
+             """.stripMargin)
+        case _ => // nothing to do
+      }
+      i += 1
+    }
+
+    (fieldUpdater, record) => {
+      var i = 0
+      var skipRow = false
+      // Write fields in order, bailing out early as soon as a pushed-down filter rejects the row.
+      while (i < validFieldIndexes.length && !skipRow) {
+        fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i)))
+        skipRow = applyFilters(i)
+        i += 1
+      }
+      skipRow
+    }
+  }
+
+  // Preallocates the backing ArrayData for an Avro array/map of the given element type and
+  // length: primitive element types get an UnsafeArrayData over a primitive array (no boxing),
+  // everything else falls back to a GenericArrayData of boxed values.
+  private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
+    case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
+    case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
+    case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
+    case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+    case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+    case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
+    case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
+    case _ => new GenericArrayData(new Array[Any](length))
+  }
+
+  /**
+   * A base interface for updating values inside catalyst data structure like `InternalRow` and
+   * `ArrayData`.
+   */
+  sealed trait CatalystDataUpdater {
+    // Stores a (possibly boxed) value at the given ordinal of the underlying container.
+    def set(ordinal: Int, value: Any): Unit
+
+    // The typed setters below default to delegating to the boxed `set`; concrete subclasses
+    // override them to write primitives without boxing where the container supports it.
+    def setNullAt(ordinal: Int): Unit = set(ordinal, null)
+
+    def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
+
+    def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
+
+    def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
+
+    def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
+
+    def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
+
+    def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
+
+    def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+
+    def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
+  }
+
+  /** [[CatalystDataUpdater]] backed by a mutable [[InternalRow]]; delegates to the row's typed mutators. */
+  final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
+    override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
+
+    override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
+
+    override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
+
+    override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
+
+    override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
+
+    override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
+
+    override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
+
+    override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
+
+    override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
+
+    // InternalRow.setDecimal needs the precision to choose the in-row representation.
+    override def setDecimal(ordinal: Int, value: Decimal): Unit =
+      row.setDecimal(ordinal, value, value.precision)
+  }
+
+  /** [[CatalystDataUpdater]] backed by a mutable [[ArrayData]]; delegates to the array's typed mutators. */
+  final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
+    override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
+
+    override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
+
+    override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
+
+    override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
+
+    override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
+
+    override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
+
+    override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
+
+    override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
+
+    override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
+
+    // ArrayData has no typed decimal setter, so decimals go through the generic update.
+    override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
+  }
+}
+
+object AvroDeserializer {
+
+  // NOTE: Following methods have been renamed in Spark 3.1.3 [1] making [[AvroDeserializer]] implementation
+  // (which relies on it) be only compatible with the exact same version of [[DataSourceUtils]].
+  // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.1.x branch,
+  // we've preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.1.2 as well as
+  // w/ Spark >= 3.1.3
+  //
+  // [1] https://github.com/apache/spark/pull/34978
+
+  /**
+   * Returns the read-side rebase function for date values (days since epoch), according to the
+   * configured [[LegacyBehaviorPolicy]]: EXCEPTION fails on pre-switch Julian days, LEGACY
+   * rebases Julian to Proleptic Gregorian, CORRECTED passes values through unchanged.
+   */
+  def createDateRebaseFuncInRead(rebaseMode: LegacyBehaviorPolicy.Value,
+                                 format: String): Int => Int = rebaseMode match {
+    case LegacyBehaviorPolicy.EXCEPTION => days: Int =>
+      // Only ancient dates (before the Julian/Gregorian switch) are ambiguous; fail fast on them.
+      if (days < RebaseDateTime.lastSwitchJulianDay) {
+        throw DataSourceUtils.newRebaseExceptionInRead(format)
+      }
+      days
+    case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays
+    case LegacyBehaviorPolicy.CORRECTED => identity[Int]
+  }
+
+  /**
+   * Read-side rebase function for timestamps in microseconds; same policy semantics as the
+   * date variant above.
+   */
+  def createTimestampRebaseFuncInRead(rebaseMode: LegacyBehaviorPolicy.Value,
+                                      format: String): Long => Long = rebaseMode match {
+    case LegacyBehaviorPolicy.EXCEPTION => micros: Long =>
+      if (micros < RebaseDateTime.lastSwitchJulianTs) {
+        throw DataSourceUtils.newRebaseExceptionInRead(format)
+      }
+      micros
+    case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianMicros
+    case LegacyBehaviorPolicy.CORRECTED => identity[Long]
+  }
+}
+
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
new file mode 100644
index 000000000..b423f9b96
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
+import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.Schema.Type
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
+import org.apache.avro.util.Utf8
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.avro.AvroSerializer.{createDateRebaseFuncInWrite, createTimestampRebaseFuncInWrite}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime}
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types._
+
+import java.nio.ByteBuffer
+import scala.collection.JavaConverters._
+
+/**
+ * A serializer to serialize data in catalyst format to data in avro format.
+ *
+ * NOTE: This code is borrowed from Spark 3.1.2
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (3.2.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+private[sql] class AvroSerializer(rootCatalystType: DataType,
+                                  rootAvroType: Schema,
+                                  nullable: Boolean,
+                                  datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging {
+
+  // Auxiliary constructor: derives the datetime rebase mode from the active SQL session config.
+  def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = {
+    this(rootCatalystType, rootAvroType, nullable,
+      LegacyBehaviorPolicy.withName(SQLConf.get.getConf(
+        SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE)))
+  }
+
+  /** Converts a single catalyst value of `rootCatalystType` into its Avro representation. */
+  def serialize(catalystData: Any): Any = {
+    converter.apply(catalystData)
+  }
+
+  private val dateRebaseFunc = createDateRebaseFuncInWrite(
+    datetimeRebaseMode, "Avro")
+
+  private val timestampRebaseFunc = createTimestampRebaseFuncInWrite(
+    datetimeRebaseMode, "Avro")
+
+  // Root converter, built once. A nullable root is wrapped with a null pass-through; a
+  // non-struct root value is funneled through a reusable single-field row so the generic
+  // (SpecializedGetters, ordinal) converter machinery can be applied to it.
+  private val converter: Any => Any = {
+    val actualAvroType = resolveNullableType(rootAvroType, nullable)
+    val baseConverter = rootCatalystType match {
+      case st: StructType =>
+        newStructConverter(st, actualAvroType).asInstanceOf[Any => Any]
+      case _ =>
+        val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+        val converter = newConverter(rootCatalystType, actualAvroType)
+        (data: Any) =>
+          tmpRow.update(0, data)
+          converter.apply(tmpRow, 0)
+    }
+    if (nullable) {
+      (data: Any) =>
+        if (data == null) {
+          null
+        } else {
+          baseConverter.apply(data)
+        }
+    } else {
+      baseConverter
+    }
+  }
+
+  // Converts the value found at `ordinal` of a catalyst container into an Avro value.
+  private type Converter = (SpecializedGetters, Int) => Any
+
+  private lazy val decimalConversions = new DecimalConversion()
+
+  // Builds a converter for one (catalystType, avroType) pair; throws
+  // IncompatibleSchemaException when the pair cannot be converted.
+  private def newConverter(catalystType: DataType, avroType: Schema): Converter = {
+    (catalystType, avroType.getType) match {
+      case (NullType, NULL) =>
+        (getter, ordinal) => null
+      case (BooleanType, BOOLEAN) =>
+        (getter, ordinal) => getter.getBoolean(ordinal)
+      case (ByteType, INT) =>
+        (getter, ordinal) => getter.getByte(ordinal).toInt
+      case (ShortType, INT) =>
+        (getter, ordinal) => getter.getShort(ordinal).toInt
+      case (IntegerType, INT) =>
+        (getter, ordinal) => getter.getInt(ordinal)
+      case (LongType, LONG) =>
+        (getter, ordinal) => getter.getLong(ordinal)
+      case (FloatType, FLOAT) =>
+        (getter, ordinal) => getter.getFloat(ordinal)
+      case (DoubleType, DOUBLE) =>
+        (getter, ordinal) => getter.getDouble(ordinal)
+      case (d: DecimalType, FIXED)
+        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
+        (getter, ordinal) =>
+          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+          decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
+            LogicalTypes.decimal(d.precision, d.scale))
+
+      case (d: DecimalType, BYTES)
+        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
+        (getter, ordinal) =>
+          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+          decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
+            LogicalTypes.decimal(d.precision, d.scale))
+
+      case (StringType, ENUM) =>
+        val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
+        (getter, ordinal) =>
+          val data = getter.getUTF8String(ordinal).toString
+          if (!enumSymbols.contains(data)) {
+            throw new IncompatibleSchemaException(
+              "Cannot write \"" + data + "\" since it's not defined in enum \"" +
+                enumSymbols.mkString("\", \"") + "\"")
+          }
+          new EnumSymbol(avroType, data)
+
+      case (StringType, STRING) =>
+        (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
+
+      case (BinaryType, FIXED) =>
+        val size = avroType.getFixedSize()
+        (getter, ordinal) =>
+          val data: Array[Byte] = getter.getBinary(ordinal)
+          if (data.length != size) {
+            throw new IncompatibleSchemaException(
+              s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " +
+                "binary data into FIXED Type with size of " +
+                s"$size ${if (size > 1) "bytes" else "byte"}")
+          }
+          new Fixed(avroType, data)
+
+      case (BinaryType, BYTES) =>
+        (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+
+      case (DateType, INT) =>
+        (getter, ordinal) => dateRebaseFunc(getter.getInt(ordinal))
+
+      case (TimestampType, LONG) => avroType.getLogicalType match {
+        // For backward compatibility, if the Avro type is Long and it is not logical type
+        // (the `null` case), output the timestamp value as with millisecond precision.
+        case null | _: TimestampMillis => (getter, ordinal) =>
+          DateTimeUtils.microsToMillis(timestampRebaseFunc(getter.getLong(ordinal)))
+        case _: TimestampMicros => (getter, ordinal) =>
+          timestampRebaseFunc(getter.getLong(ordinal))
+        case other => throw new IncompatibleSchemaException(
+          s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}")
+      }
+
+      case (ArrayType(et, containsNull), ARRAY) =>
+        val elementConverter = newConverter(
+          et, resolveNullableType(avroType.getElementType, containsNull))
+        (getter, ordinal) => {
+          val arrayData = getter.getArray(ordinal)
+          val len = arrayData.numElements()
+          val result = new Array[Any](len)
+          var i = 0
+          while (i < len) {
+            if (containsNull && arrayData.isNullAt(i)) {
+              result(i) = null
+            } else {
+              result(i) = elementConverter(arrayData, i)
+            }
+            i += 1
+          }
+          // avro writer is expecting a Java Collection, so we convert it into
+          // `ArrayList` backed by the specified array without data copying.
+          java.util.Arrays.asList(result: _*)
+        }
+
+      case (st: StructType, RECORD) =>
+        val structConverter = newStructConverter(st, avroType)
+        val numFields = st.length
+        (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+
+      case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
+        val valueConverter = newConverter(
+          vt, resolveNullableType(avroType.getValueType, valueContainsNull))
+        (getter, ordinal) =>
+          val mapData = getter.getMap(ordinal)
+          val len = mapData.numElements()
+          val result = new java.util.HashMap[String, Any](len)
+          val keyArray = mapData.keyArray()
+          val valueArray = mapData.valueArray()
+          var i = 0
+          while (i < len) {
+            val key = keyArray.getUTF8String(i).toString
+            if (valueContainsNull && valueArray.isNullAt(i)) {
+              result.put(key, null)
+            } else {
+              result.put(key, valueConverter(valueArray, i))
+            }
+            i += 1
+          }
+          result
+
+      case other =>
+        throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " +
+          s"Avro type $avroType.")
+    }
+  }
+
+  // Builds a row-level converter from a catalyst struct to an Avro Record, matching catalyst
+  // fields to Avro fields by name and writing values at the Avro field positions.
+  private def newStructConverter(
+      catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = {
+    if (avroStruct.getType != RECORD || avroStruct.getFields.size() != catalystStruct.length) {
+      throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " +
+        s"Avro type $avroStruct.")
+    }
+    val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(avroStruct)
+
+    val (avroIndices: Array[Int], fieldConverters: Array[Converter]) =
+      catalystStruct.map { catalystField =>
+        val avroField = avroSchemaHelper.getFieldByName(catalystField.name) match {
+          case Some(f) => f
+          case None => throw new IncompatibleSchemaException(
+            s"Cannot find ${catalystField.name} in Avro schema")
+        }
+        val converter = newConverter(catalystField.dataType, resolveNullableType(
+          avroField.schema(), catalystField.nullable))
+        (avroField.pos(), converter)
+      }.toArray.unzip
+
+    val numFields = catalystStruct.length
+    row: InternalRow =>
+      val result = new Record(avroStruct)
+      var i = 0
+      while (i < numFields) {
+        if (row.isNullAt(i)) {
+          result.put(avroIndices(i), null)
+        } else {
+          result.put(avroIndices(i), fieldConverters(i).apply(row, i))
+        }
+        i += 1
+      }
+      result
+  }
+
+  /**
+   * Resolve a possibly nullable Avro Type.
+   *
+   * An Avro type is nullable when it is a [[UNION]] of two types: one null type and another
+   * non-null type. This method will check the nullability of the input Avro type and return the
+   * non-null type within when it is nullable. Otherwise it will return the input Avro type
+   * unchanged. It will throw an [[UnsupportedAvroTypeException]] when the input Avro type is an
+   * unsupported nullable type.
+   *
+   * It will also log a warning message if the nullability for Avro and catalyst types are
+   * different.
+   */
+  private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
+    val (avroNullable, resolvedAvroType) = resolveAvroType(avroType)
+    warnNullabilityDifference(avroNullable, nullable)
+    resolvedAvroType
+  }
+
+  /**
+   * Check the nullability of the input Avro type and resolve it when it is nullable. The first
+   * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second
+   * return value is the possibly resolved type.
+   */
+  private def resolveAvroType(avroType: Schema): (Boolean, Schema) = {
+    if (avroType.getType == Type.UNION) {
+      val fields = avroType.getTypes.asScala
+      val actualType = fields.filter(_.getType != Type.NULL)
+      // Only the canonical two-branch [null, T] union is treated as "nullable T".
+      if (fields.length != 2 || actualType.length != 1) {
+        throw new UnsupportedAvroTypeException(
+          s"Unsupported Avro UNION type $avroType: Only UNION of a null type and a non-null " +
+            "type is supported")
+      }
+      (true, actualType.head)
+    } else {
+      (false, avroType)
+    }
+  }
+
+  /**
+   * log a warning message if the nullability for Avro and catalyst types are different.
+   */
+  private def warnNullabilityDifference(avroNullable: Boolean, catalystNullable: Boolean): Unit = {
+    if (avroNullable && !catalystNullable) {
+      logWarning("Writing Avro files with nullable Avro schema and non-nullable catalyst schema.")
+    }
+    if (!avroNullable && catalystNullable) {
+      logWarning("Writing Avro files with non-nullable Avro schema and nullable catalyst " +
+        "schema will throw runtime exception if there is a record with null value.")
+    }
+  }
+}
+
+object AvroSerializer {
+
+  // NOTE: Following methods have been renamed in Spark 3.1.3 [1] making [[AvroSerializer]] implementation
+  // (which relies on it) be only compatible with the exact same version of [[DataSourceUtils]].
+  // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.1.x branch,
+  // we've preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.1.2 as well as
+  // w/ Spark >= 3.1.3
+  //
+  // [1] https://github.com/apache/spark/pull/34978
+
+  /**
+   * Write-side rebase function for date values (days since epoch): EXCEPTION fails on
+   * pre-switch Gregorian days, LEGACY rebases Proleptic Gregorian to Julian, CORRECTED is a
+   * pass-through.
+   */
+  def createDateRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value,
+                                  format: String): Int => Int = rebaseMode match {
+    case LegacyBehaviorPolicy.EXCEPTION => days: Int =>
+      if (days < RebaseDateTime.lastSwitchGregorianDay) {
+        throw DataSourceUtils.newRebaseExceptionInWrite(format)
+      }
+      days
+    case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays
+    case LegacyBehaviorPolicy.CORRECTED => identity[Int]
+  }
+
+  /**
+   * Write-side rebase function for timestamps in microseconds; same policy semantics as the
+   * date variant above.
+   */
+  def createTimestampRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value,
+                                       format: String): Long => Long = rebaseMode match {
+    case LegacyBehaviorPolicy.EXCEPTION => micros: Long =>
+      if (micros < RebaseDateTime.lastSwitchGregorianTs) {
+        throw DataSourceUtils.newRebaseExceptionInWrite(format)
+      }
+      micros
+    case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianMicros
+    case LegacyBehaviorPolicy.CORRECTED => identity[Long]
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
new file mode 100644
index 000000000..54eacbaa0
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Schema
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.internal.SQLConf
+
+import java.util.Locale
+import scala.collection.JavaConverters._
+
+/**
+ * NOTE: This code is borrowed from Spark 3.1.3
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (3.2.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+private[avro] object AvroUtils extends Logging {
+
+  /**
+   * Wraps an Avro Schema object so that field lookups are faster.
+   *
+   * @param avroSchema The schema in which to search for fields. Must be of type RECORD.
+   */
+  class AvroSchemaHelper(avroSchema: Schema) {
+    if (avroSchema.getType != Schema.Type.RECORD) {
+      throw new IncompatibleSchemaException(
+        s"Attempting to treat ${avroSchema.getName} as a RECORD, but it was: ${avroSchema.getType}")
+    }
+
+    // Index fields by lower-cased name once; case sensitivity is applied at lookup time via
+    // SQLConf.resolver, so a case-insensitive bucket may hold several candidates.
+    private[this] val fieldMap = avroSchema.getFields.asScala
+      .groupBy(_.name.toLowerCase(Locale.ROOT))
+      .mapValues(_.toSeq) // toSeq needed for scala 2.13
+
+    /**
+     * Extract a single field from the contained avro schema which has the desired field name,
+     * performing the matching with proper case sensitivity according to SQLConf.resolver.
+     *
+     * @param name The name of the field to search for.
+     * @return `Some(match)` if a matching Avro field is found, otherwise `None`.
+     */
+    def getFieldByName(name: String): Option[Schema.Field] = {
+
+      // get candidates, ignoring case of field name
+      val candidates = fieldMap.get(name.toLowerCase(Locale.ROOT))
+        .getOrElse(Seq.empty[Schema.Field])
+
+      // search candidates, taking into account case sensitivity settings
+      candidates.filter(f => SQLConf.get.resolver(f.name(), name)) match {
+        case Seq(avroField) => Some(avroField)
+        case Seq() => None
+        case matches => throw new IncompatibleSchemaException(
+          s"Searching for '$name' in Avro schema gave ${matches.size} matches. Candidates: " +
+            matches.map(_.name()).mkString("[", ", ", "]")
+        )
+      }
+    }
+  }
+
+  /**
+   * Extract a single field from `avroSchema` which has the desired field name,
+   * performing the matching with proper case sensitivity according to [[SQLConf.resolver]].
+   *
+   * @param avroSchema The schema in which to search for the field. Must be of type RECORD.
+   * @param name The name of the field to search for.
+   * @return `Some(match)` if a matching Avro field is found, otherwise `None`.
+   * @throws IncompatibleSchemaException if `avroSchema` is not a RECORD or contains multiple
+   *                                     fields matching `name` (i.e., case-insensitive matching
+   *                                     is used and `avroSchema` has two or more fields that have
+   *                                     the same name with difference case).
+   */
+  private[avro] def getAvroFieldByName(
+      avroSchema: Schema,
+      name: String): Option[Schema.Field] = {
+    if (avroSchema.getType != Schema.Type.RECORD) {
+      throw new IncompatibleSchemaException(
+        s"Attempting to treat ${avroSchema.getName} as a RECORD, but it was: ${avroSchema.getType}")
+    }
+    avroSchema.getFields.asScala.filter(f => SQLConf.get.resolver(f.name(), name)).toSeq match {
+      case Seq(avroField) => Some(avroField)
+      case Seq() => None
+      case matches => throw new IncompatibleSchemaException(
+        s"Searching for '$name' in Avro schema gave ${matches.size} matches. Candidates: " +
+          matches.map(_.name()).mkString("[", ", ", "]")
+      )
+    }
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_1AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_1AvroDeserializer.scala
new file mode 100644
index 000000000..bf6fcbee7
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_1AvroDeserializer.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Schema
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Spark 3.1.x implementation of Hudi's [[HoodieAvroDeserializer]] abstraction, delegating to the
+ * vendored Spark [[AvroDeserializer]] in this module.
+ */
+class HoodieSpark3_1AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType)
+  extends HoodieAvroDeserializer {
+
+  private val avroDeserializer = new AvroDeserializer(rootAvroType, rootCatalystType)
+
+  // Returns None when the vendored deserializer signals a filtered/skipped row.
+  def deserialize(data: Any): Option[Any] = avroDeserializer.deserialize(data)
+}
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_1AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_1AvroSerializer.scala
new file mode 100644
index 000000000..942a4e1b3
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_1AvroSerializer.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Schema
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Spark 3.1.x implementation of Hudi's [[HoodieAvroSerializer]] abstraction, delegating to the
+ * vendored Spark [[AvroSerializer]] in this module.
+ */
+class HoodieSpark3_1AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean)
+  extends HoodieAvroSerializer {
+
+  val avroSerializer = new AvroSerializer(rootCatalystType, rootAvroType, nullable)
+
+  override def serialize(catalystData: Any): Any = avroSerializer.serialize(catalystData)
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala
index 1256344c3..c8193699d 100644
--- a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala
@@ -17,17 +17,26 @@
package org.apache.spark.sql.adapter
-import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark3_2CatalystExpressionUtils, SparkSession}
+import org.apache.avro.Schema
+import org.apache.spark.sql.avro._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.parser.HoodieSpark3_2ExtendedSqlParser
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark3_2CatalystExpressionUtils, SparkSession}
/**
* Implementation of [[SparkAdapter]] for Spark 3.2.x branch
*/
class Spark3_2Adapter extends BaseSpark3Adapter {
+ override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
+ new HoodieSpark3_2AvroSerializer(rootCatalystType, rootAvroType, nullable)
+
+ override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
+ new HoodieSpark3_2AvroDeserializer(rootAvroType, rootCatalystType)
+
override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark3_2CatalystExpressionUtils
/**
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
new file mode 100644
index 000000000..ef9b59092
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -0,0 +1,510 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.math.BigDecimal
+import java.nio.ByteBuffer
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic._
+import org.apache.avro.util.Utf8
+import org.apache.spark.sql.avro.AvroDeserializer.{RebaseSpec, createDateRebaseFuncInRead, createTimestampRebaseFuncInRead}
+import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr}
+import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData, RebaseDateTime}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import java.util.TimeZone
+
+/**
+ * A deserializer to deserialize data in avro format to data in catalyst format.
+ *
+ * NOTE: This code is borrowed from Spark 3.2.1
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (3.2.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+private[sql] class AvroDeserializer(rootAvroType: Schema,
+ rootCatalystType: DataType,
+ positionalFieldMatch: Boolean,
+ datetimeRebaseSpec: RebaseSpec,
+ filters: StructFilters) {
+
+ def this(rootAvroType: Schema,
+ rootCatalystType: DataType,
+ datetimeRebaseMode: String) = {
+ this(
+ rootAvroType,
+ rootCatalystType,
+ positionalFieldMatch = false,
+ RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)),
+ new NoopFilters)
+ }
+
+ private lazy val decimalConversions = new DecimalConversion()
+
+ private val dateRebaseFunc = createDateRebaseFuncInRead(datetimeRebaseSpec.mode, "Avro")
+
+ private val timestampRebaseFunc = createTimestampRebaseFuncInRead(datetimeRebaseSpec, "Avro")
+
+ private val converter: Any => Option[Any] = try {
+ rootCatalystType match {
+ // A shortcut for empty schema.
+ case st: StructType if st.isEmpty =>
+ (_: Any) => Some(InternalRow.empty)
+
+ case st: StructType =>
+ val resultRow = new SpecificInternalRow(st.map(_.dataType))
+ val fieldUpdater = new RowUpdater(resultRow)
+ val applyFilters = filters.skipRow(resultRow, _)
+ val writer = getRecordWriter(rootAvroType, st, Nil, Nil, applyFilters)
+ (data: Any) => {
+ val record = data.asInstanceOf[GenericRecord]
+ val skipRow = writer(fieldUpdater, record)
+ if (skipRow) None else Some(resultRow)
+ }
+
+ case _ =>
+ val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+ val fieldUpdater = new RowUpdater(tmpRow)
+ val writer = newWriter(rootAvroType, rootCatalystType, Nil, Nil)
+ (data: Any) => {
+ writer(fieldUpdater, 0, data)
+ Some(tmpRow.get(0, rootCatalystType))
+ }
+ }
+ } catch {
+ case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException(
+ s"Cannot convert Avro type $rootAvroType to SQL type ${rootCatalystType.sql}.", ise)
+ }
+
+ def deserialize(data: Any): Option[Any] = converter(data)
+
+ /**
+ * Creates a writer to write avro values to Catalyst values at the given ordinal with the given
+ * updater.
+ */
+ private def newWriter(avroType: Schema,
+ catalystType: DataType,
+ avroPath: Seq[String],
+ catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
+ val errorPrefix = s"Cannot convert Avro ${toFieldStr(avroPath)} to " +
+ s"SQL ${toFieldStr(catalystPath)} because "
+ val incompatibleMsg = errorPrefix +
+ s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})"
+
+ (avroType.getType, catalystType) match {
+ case (NULL, NullType) => (updater, ordinal, _) =>
+ updater.setNullAt(ordinal)
+
+ // TODO: we can avoid boxing if future version of avro provide primitive accessors.
+ case (BOOLEAN, BooleanType) => (updater, ordinal, value) =>
+ updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
+
+ case (INT, IntegerType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, value.asInstanceOf[Int])
+
+ case (INT, DateType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int]))
+
+ case (LONG, LongType) => (updater, ordinal, value) =>
+ updater.setLong(ordinal, value.asInstanceOf[Long])
+
+ case (LONG, TimestampType) => avroType.getLogicalType match {
+ // For backward compatibility, if the Avro type is Long and it is not logical type
+ // (the `null` case), the value is processed as timestamp type with millisecond precision.
+ case null | _: TimestampMillis => (updater, ordinal, value) =>
+ val millis = value.asInstanceOf[Long]
+ val micros = DateTimeUtils.millisToMicros(millis)
+ updater.setLong(ordinal, timestampRebaseFunc(micros))
+ case _: TimestampMicros => (updater, ordinal, value) =>
+ val micros = value.asInstanceOf[Long]
+ updater.setLong(ordinal, timestampRebaseFunc(micros))
+ case other => throw new IncompatibleSchemaException(errorPrefix +
+ s"Avro logical type $other cannot be converted to SQL type ${TimestampType.sql}.")
+ }
+
+ // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date.
+ // For backward compatibility, we still keep this conversion.
+ case (LONG, DateType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt)
+
+ case (FLOAT, FloatType) => (updater, ordinal, value) =>
+ updater.setFloat(ordinal, value.asInstanceOf[Float])
+
+ case (DOUBLE, DoubleType) => (updater, ordinal, value) =>
+ updater.setDouble(ordinal, value.asInstanceOf[Double])
+
+ case (STRING, StringType) => (updater, ordinal, value) =>
+ val str = value match {
+ case s: String => UTF8String.fromString(s)
+ case s: Utf8 =>
+ val bytes = new Array[Byte](s.getByteLength)
+ System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength)
+ UTF8String.fromBytes(bytes)
+ }
+ updater.set(ordinal, str)
+
+ case (ENUM, StringType) => (updater, ordinal, value) =>
+ updater.set(ordinal, UTF8String.fromString(value.toString))
+
+ case (FIXED, BinaryType) => (updater, ordinal, value) =>
+ updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone())
+
+ case (BYTES, BinaryType) => (updater, ordinal, value) =>
+ val bytes = value match {
+ case b: ByteBuffer =>
+ val bytes = new Array[Byte](b.remaining)
+ b.get(bytes)
+ bytes
+ case b: Array[Byte] => b
+ case other =>
+ throw new RuntimeException(errorPrefix + s"$other is not a valid avro binary.")
+ }
+ updater.set(ordinal, bytes)
+
+ case (FIXED, _: DecimalType) => (updater, ordinal, value) =>
+ val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal]
+ val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d)
+ val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale)
+ updater.setDecimal(ordinal, decimal)
+
+ case (BYTES, _: DecimalType) => (updater, ordinal, value) =>
+ val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal]
+ val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d)
+ val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale)
+ updater.setDecimal(ordinal, decimal)
+
+ case (RECORD, st: StructType) =>
+ // Avro datasource doesn't accept filters with nested attributes. See SPARK-32328.
+ // We can always return `false` from `applyFilters` for nested records.
+ val writeRecord =
+ getRecordWriter(avroType, st, avroPath, catalystPath, applyFilters = _ => false)
+ (updater, ordinal, value) =>
+ val row = new SpecificInternalRow(st)
+ writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord])
+ updater.set(ordinal, row)
+
+ case (ARRAY, ArrayType(elementType, containsNull)) =>
+ val avroElementPath = avroPath :+ "element"
+ val elementWriter = newWriter(avroType.getElementType, elementType,
+ avroElementPath, catalystPath :+ "element")
+ (updater, ordinal, value) =>
+ val collection = value.asInstanceOf[java.util.Collection[Any]]
+ val result = createArrayData(elementType, collection.size())
+ val elementUpdater = new ArrayDataUpdater(result)
+
+ var i = 0
+ val iter = collection.iterator()
+ while (iter.hasNext) {
+ val element = iter.next()
+ if (element == null) {
+ if (!containsNull) {
+ throw new RuntimeException(
+ s"Array value at path ${toFieldStr(avroElementPath)} is not allowed to be null")
+ } else {
+ elementUpdater.setNullAt(i)
+ }
+ } else {
+ elementWriter(elementUpdater, i, element)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+
+ case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType =>
+ val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType,
+ avroPath :+ "key", catalystPath :+ "key")
+ val valueWriter = newWriter(avroType.getValueType, valueType,
+ avroPath :+ "value", catalystPath :+ "value")
+ (updater, ordinal, value) =>
+ val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]]
+ val keyArray = createArrayData(keyType, map.size())
+ val keyUpdater = new ArrayDataUpdater(keyArray)
+ val valueArray = createArrayData(valueType, map.size())
+ val valueUpdater = new ArrayDataUpdater(valueArray)
+ val iter = map.entrySet().iterator()
+ var i = 0
+ while (iter.hasNext) {
+ val entry = iter.next()
+ assert(entry.getKey != null)
+ keyWriter(keyUpdater, i, entry.getKey)
+ if (entry.getValue == null) {
+ if (!valueContainsNull) {
+ throw new RuntimeException(
+ s"Map value at path ${toFieldStr(avroPath :+ "value")} is not allowed to be null")
+ } else {
+ valueUpdater.setNullAt(i)
+ }
+ } else {
+ valueWriter(valueUpdater, i, entry.getValue)
+ }
+ i += 1
+ }
+
+ // The Avro map will never have null or duplicated map keys, it's safe to create a
+ // ArrayBasedMapData directly here.
+ updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
+
+ case (UNION, _) =>
+ val allTypes = avroType.getTypes.asScala
+ val nonNullTypes = allTypes.filter(_.getType != NULL)
+ val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava)
+ if (nonNullTypes.nonEmpty) {
+ if (nonNullTypes.length == 1) {
+ newWriter(nonNullTypes.head, catalystType, avroPath, catalystPath)
+ } else {
+ nonNullTypes.map(_.getType).toSeq match {
+ case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType =>
+ (updater, ordinal, value) =>
+ value match {
+ case null => updater.setNullAt(ordinal)
+ case l: java.lang.Long => updater.setLong(ordinal, l)
+ case i: java.lang.Integer => updater.setLong(ordinal, i.longValue())
+ }
+
+ case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType =>
+ (updater, ordinal, value) =>
+ value match {
+ case null => updater.setNullAt(ordinal)
+ case d: java.lang.Double => updater.setDouble(ordinal, d)
+ case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue())
+ }
+
+ case _ =>
+ catalystType match {
+ case st: StructType if st.length == nonNullTypes.size =>
+ val fieldWriters = nonNullTypes.zip(st.fields).map {
+ case (schema, field) =>
+ newWriter(schema, field.dataType, avroPath, catalystPath :+ field.name)
+ }.toArray
+ (updater, ordinal, value) => {
+ val row = new SpecificInternalRow(st)
+ val fieldUpdater = new RowUpdater(row)
+ val i = GenericData.get().resolveUnion(nonNullAvroType, value)
+ fieldWriters(i)(fieldUpdater, i, value)
+ updater.set(ordinal, row)
+ }
+
+ case _ => throw new IncompatibleSchemaException(incompatibleMsg)
+ }
+ }
+ }
+ } else {
+ (updater, ordinal, _) => updater.setNullAt(ordinal)
+ }
+
+ case _ => throw new IncompatibleSchemaException(incompatibleMsg)
+ }
+ }
+
+ // TODO: move the following method in Decimal object on creating Decimal from BigDecimal?
+ private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
+ if (precision <= Decimal.MAX_LONG_DIGITS) {
+ // Constructs a `Decimal` with an unscaled `Long` value if possible.
+ Decimal(decimal.unscaledValue().longValue(), precision, scale)
+ } else {
+ // Otherwise, resorts to an unscaled `BigInteger` instead.
+ Decimal(decimal, precision, scale)
+ }
+ }
+
+ private def getRecordWriter(avroType: Schema,
+ catalystType: StructType,
+ avroPath: Seq[String],
+ catalystPath: Seq[String],
+ applyFilters: Int => Boolean): (CatalystDataUpdater, GenericRecord) => Boolean = {
+ val validFieldIndexes = ArrayBuffer.empty[Int]
+ val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit]
+
+ val avroSchemaHelper =
+ new AvroUtils.AvroSchemaHelper(avroType, avroPath, positionalFieldMatch)
+ val length = catalystType.length
+ var i = 0
+ while (i < length) {
+ val catalystField = catalystType.fields(i)
+ avroSchemaHelper.getAvroField(catalystField.name, i) match {
+ case Some(avroField) =>
+ validFieldIndexes += avroField.pos()
+
+ val baseWriter = newWriter(avroField.schema(), catalystField.dataType,
+ avroPath :+ avroField.name, catalystPath :+ catalystField.name)
+ val ordinal = i
+ val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
+ if (value == null) {
+ fieldUpdater.setNullAt(ordinal)
+ } else {
+ baseWriter(fieldUpdater, ordinal, value)
+ }
+ }
+ fieldWriters += fieldWriter
+ case None if !catalystField.nullable =>
+ val fieldDescription =
+ toFieldDescription(catalystPath :+ catalystField.name, i, positionalFieldMatch)
+ throw new IncompatibleSchemaException(
+ s"Cannot find non-nullable $fieldDescription in Avro schema.")
+ case _ => // nothing to do
+ }
+ i += 1
+ }
+
+ (fieldUpdater, record) => {
+ var i = 0
+ var skipRow = false
+ while (i < validFieldIndexes.length && !skipRow) {
+ fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i)))
+ skipRow = applyFilters(i)
+ i += 1
+ }
+ skipRow
+ }
+ }
+
+ private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
+ case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
+ case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
+ case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
+ case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+ case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+ case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
+ case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
+ case _ => new GenericArrayData(new Array[Any](length))
+ }
+
+ /**
+ * A base interface for updating values inside catalyst data structure like `InternalRow` and
+ * `ArrayData`.
+ */
+ sealed trait CatalystDataUpdater {
+ def set(ordinal: Int, value: Any): Unit
+
+ def setNullAt(ordinal: Int): Unit = set(ordinal, null)
+
+ def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
+
+ def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
+
+ def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
+
+ def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
+
+ def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
+
+ def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
+
+ def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+
+ def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
+ }
+
+ final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
+ override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
+
+ override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
+
+ override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
+
+ override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
+
+ override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
+
+ override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
+
+ override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
+
+ override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
+
+ override def setDecimal(ordinal: Int, value: Decimal): Unit =
+ row.setDecimal(ordinal, value, value.precision)
+ }
+
+ final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
+ override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
+
+ override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
+
+ override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
+
+ override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
+
+ override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
+
+ override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
+
+ override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
+
+ override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
+
+ override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
+ }
+}
+
+object AvroDeserializer {
+
+ // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making [[AvroDeserializer]] implementation
+ // (which relies on it) compatible only with the exact same version of [[DataSourceUtils]].
+ // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.2.x branch,
+ // we preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.2.0 as well as
+ // w/ Spark >= 3.2.1
+ //
+ // [1] https://github.com/apache/spark/pull/34978
+
+ // Specification of rebase operation including `mode` and the time zone in which it is performed
+ case class RebaseSpec(mode: LegacyBehaviorPolicy.Value, originTimeZone: Option[String] = None) {
+ // Use the default JVM time zone for backward compatibility
+ def timeZone: String = originTimeZone.getOrElse(TimeZone.getDefault.getID)
+ }
+
+ def createDateRebaseFuncInRead(rebaseMode: LegacyBehaviorPolicy.Value,
+ format: String): Int => Int = rebaseMode match {
+ case LegacyBehaviorPolicy.EXCEPTION => days: Int =>
+ if (days < RebaseDateTime.lastSwitchJulianDay) {
+ throw DataSourceUtils.newRebaseExceptionInRead(format)
+ }
+ days
+ case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays
+ case LegacyBehaviorPolicy.CORRECTED => identity[Int]
+ }
+
+ def createTimestampRebaseFuncInRead(rebaseSpec: RebaseSpec,
+ format: String): Long => Long = rebaseSpec.mode match {
+ case LegacyBehaviorPolicy.EXCEPTION => micros: Long =>
+ if (micros < RebaseDateTime.lastSwitchJulianTs) {
+ throw DataSourceUtils.newRebaseExceptionInRead(format)
+ }
+ micros
+ case LegacyBehaviorPolicy.LEGACY => micros: Long =>
+ RebaseDateTime.rebaseJulianToGregorianMicros(TimeZone.getTimeZone(rebaseSpec.timeZone), micros)
+ case LegacyBehaviorPolicy.CORRECTED => identity[Long]
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
new file mode 100644
index 000000000..2fe51d367
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.nio.ByteBuffer
+import scala.collection.JavaConverters._
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.LogicalTypes
+import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema
+import org.apache.avro.Schema.Type
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
+import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.util.Utf8
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.avro.AvroSerializer.{createDateRebaseFuncInWrite, createTimestampRebaseFuncInWrite}
+import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime}
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types._
+
+import java.util.TimeZone
+
+/**
+ * A serializer to serialize data in catalyst format to data in avro format.
+ *
+ * NOTE: This code is borrowed from Spark 3.2.1
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (3.2.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+private[sql] class AvroSerializer(rootCatalystType: DataType,
+ rootAvroType: Schema,
+ nullable: Boolean,
+ positionalFieldMatch: Boolean,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging {
+
+ def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = {
+ this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false,
+ LegacyBehaviorPolicy.withName(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE)))
+ }
+
+ def serialize(catalystData: Any): Any = {
+ converter.apply(catalystData)
+ }
+
+ private val dateRebaseFunc = createDateRebaseFuncInWrite(
+ datetimeRebaseMode, "Avro")
+
+ private val timestampRebaseFunc = createTimestampRebaseFuncInWrite(
+ datetimeRebaseMode, "Avro")
+
+ private val converter: Any => Any = {
+ val actualAvroType = resolveNullableType(rootAvroType, nullable)
+ val baseConverter = try {
+ rootCatalystType match {
+ case st: StructType =>
+ newStructConverter(st, actualAvroType, Nil, Nil).asInstanceOf[Any => Any]
+ case _ =>
+ val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+ val converter = newConverter(rootCatalystType, actualAvroType, Nil, Nil)
+ (data: Any) =>
+ tmpRow.update(0, data)
+ converter.apply(tmpRow, 0)
+ }
+ } catch {
+ case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException(
+ s"Cannot convert SQL type ${rootCatalystType.sql} to Avro type $rootAvroType.", ise)
+ }
+ if (nullable) {
+ (data: Any) =>
+ if (data == null) {
+ null
+ } else {
+ baseConverter.apply(data)
+ }
+ } else {
+ baseConverter
+ }
+ }
+
+ private type Converter = (SpecializedGetters, Int) => Any
+
+ private lazy val decimalConversions = new DecimalConversion()
+
+ private def newConverter(catalystType: DataType,
+ avroType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): Converter = {
+ val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
+ s"to Avro ${toFieldStr(avroPath)} because "
+ (catalystType, avroType.getType) match {
+ case (NullType, NULL) =>
+ (getter, ordinal) => null
+ case (BooleanType, BOOLEAN) =>
+ (getter, ordinal) => getter.getBoolean(ordinal)
+ case (ByteType, INT) =>
+ (getter, ordinal) => getter.getByte(ordinal).toInt
+ case (ShortType, INT) =>
+ (getter, ordinal) => getter.getShort(ordinal).toInt
+ case (IntegerType, INT) =>
+ (getter, ordinal) => getter.getInt(ordinal)
+ case (LongType, LONG) =>
+ (getter, ordinal) => getter.getLong(ordinal)
+ case (FloatType, FLOAT) =>
+ (getter, ordinal) => getter.getFloat(ordinal)
+ case (DoubleType, DOUBLE) =>
+ (getter, ordinal) => getter.getDouble(ordinal)
+ case (d: DecimalType, FIXED)
+ if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
+ (getter, ordinal) =>
+ val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+ decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
+ LogicalTypes.decimal(d.precision, d.scale))
+
+ case (d: DecimalType, BYTES)
+ if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
+ (getter, ordinal) =>
+ val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+ decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
+ LogicalTypes.decimal(d.precision, d.scale))
+
+ case (StringType, ENUM) =>
+ val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
+ (getter, ordinal) =>
+ val data = getter.getUTF8String(ordinal).toString
+ if (!enumSymbols.contains(data)) {
+ throw new IncompatibleSchemaException(errorPrefix +
+ s""""$data" cannot be written since it's not defined in enum """ +
+ enumSymbols.mkString("\"", "\", \"", "\""))
+ }
+ new EnumSymbol(avroType, data)
+
+ case (StringType, STRING) =>
+ (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
+
+ case (BinaryType, FIXED) =>
+ val size = avroType.getFixedSize
+ (getter, ordinal) =>
+ val data: Array[Byte] = getter.getBinary(ordinal)
+ if (data.length != size) {
+ def len2str(len: Int): String = s"$len ${if (len > 1) "bytes" else "byte"}"
+
+ throw new IncompatibleSchemaException(errorPrefix + len2str(data.length) +
+ " of binary data cannot be written into FIXED type with size of " + len2str(size))
+ }
+ new Fixed(avroType, data)
+
+ case (BinaryType, BYTES) =>
+ (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+
+ case (DateType, INT) =>
+ (getter, ordinal) => dateRebaseFunc(getter.getInt(ordinal))
+
+ case (TimestampType, LONG) => avroType.getLogicalType match {
+ // For backward compatibility, if the Avro type is Long and it is not logical type
+ // (the `null` case), output the timestamp value as with millisecond precision.
+ case null | _: TimestampMillis => (getter, ordinal) =>
+ DateTimeUtils.microsToMillis(timestampRebaseFunc(getter.getLong(ordinal)))
+ case _: TimestampMicros => (getter, ordinal) =>
+ timestampRebaseFunc(getter.getLong(ordinal))
+ case other => throw new IncompatibleSchemaException(errorPrefix +
+ s"SQL type ${TimestampType.sql} cannot be converted to Avro logical type $other")
+ }
+
+ case (ArrayType(et, containsNull), ARRAY) =>
+ val elementConverter = newConverter(
+ et, resolveNullableType(avroType.getElementType, containsNull),
+ catalystPath :+ "element", avroPath :+ "element")
+ (getter, ordinal) => {
+ val arrayData = getter.getArray(ordinal)
+ val len = arrayData.numElements()
+ val result = new Array[Any](len)
+ var i = 0
+ while (i < len) {
+ if (containsNull && arrayData.isNullAt(i)) {
+ result(i) = null
+ } else {
+ result(i) = elementConverter(arrayData, i)
+ }
+ i += 1
+ }
+ // avro writer is expecting a Java Collection, so we convert it into
+ // `ArrayList` backed by the specified array without data copying.
+ java.util.Arrays.asList(result: _*)
+ }
+
+ case (st: StructType, RECORD) =>
+ val structConverter = newStructConverter(st, avroType, catalystPath, avroPath)
+ val numFields = st.length
+ (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+
+ case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
+ val valueConverter = newConverter(
+ vt, resolveNullableType(avroType.getValueType, valueContainsNull),
+ catalystPath :+ "value", avroPath :+ "value")
+ (getter, ordinal) =>
+ val mapData = getter.getMap(ordinal)
+ val len = mapData.numElements()
+ val result = new java.util.HashMap[String, Any](len)
+ val keyArray = mapData.keyArray()
+ val valueArray = mapData.valueArray()
+ var i = 0
+ while (i < len) {
+ val key = keyArray.getUTF8String(i).toString
+ if (valueContainsNull && valueArray.isNullAt(i)) {
+ result.put(key, null)
+ } else {
+ result.put(key, valueConverter(valueArray, i))
+ }
+ i += 1
+ }
+ result
+
+ case _ =>
+ throw new IncompatibleSchemaException(errorPrefix +
+ s"schema is incompatible (sqlType = ${catalystType.sql}, avroType = $avroType)")
+ }
+ }
+
+ private def newStructConverter(catalystStruct: StructType,
+ avroStruct: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Record = {
+
+ val avroPathStr = toFieldStr(avroPath)
+ if (avroStruct.getType != RECORD) {
+ throw new IncompatibleSchemaException(s"$avroPathStr was not a RECORD")
+ }
+ val avroFields = avroStruct.getFields.asScala
+ if (avroFields.size != catalystStruct.length) {
+ throw new IncompatibleSchemaException(
+ s"Avro $avroPathStr schema length (${avroFields.size}) doesn't match " +
+ s"SQL ${toFieldStr(catalystPath)} schema length (${catalystStruct.length})")
+ }
+ val avroSchemaHelper =
+ new AvroUtils.AvroSchemaHelper(avroStruct, avroPath, positionalFieldMatch)
+
+ val (avroIndices: Array[Int], fieldConverters: Array[Converter]) =
+ catalystStruct.zipWithIndex.map { case (catalystField, catalystPos) =>
+ val avroField = avroSchemaHelper.getAvroField(catalystField.name, catalystPos) match {
+ case Some(f) => f
+ case None =>
+ val fieldDescription = toFieldDescription(
+ catalystPath :+ catalystField.name, catalystPos, positionalFieldMatch)
+ throw new IncompatibleSchemaException(
+ s"Cannot find $fieldDescription in Avro schema at $avroPathStr")
+ }
+ val converter = newConverter(catalystField.dataType,
+ resolveNullableType(avroField.schema(), catalystField.nullable),
+ catalystPath :+ catalystField.name, avroPath :+ avroField.name)
+ (avroField.pos(), converter)
+ }.toArray.unzip
+
+ val numFields = catalystStruct.length
+ row: InternalRow =>
+ val result = new Record(avroStruct)
+ var i = 0
+ while (i < numFields) {
+ if (row.isNullAt(i)) {
+ result.put(avroIndices(i), null)
+ } else {
+ result.put(avroIndices(i), fieldConverters(i).apply(row, i))
+ }
+ i += 1
+ }
+ result
+ }
+
+ /**
+ * Resolve a possibly nullable Avro Type.
+ *
+ * An Avro type is nullable when it is a [[UNION]] of two types: one null type and another
+ * non-null type. This method will check the nullability of the input Avro type and return the
+ * non-null type within when it is nullable. Otherwise it will return the input Avro type
+ * unchanged. It will throw an [[UnsupportedAvroTypeException]] when the input Avro type is an
+ * unsupported nullable type.
+ *
+ * It will also log a warning message if the nullability for Avro and catalyst types are
+ * different.
+ */
+ private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
+ val (avroNullable, resolvedAvroType) = resolveAvroType(avroType)
+ warnNullabilityDifference(avroNullable, nullable)
+ resolvedAvroType
+ }
+
+ /**
+ * Check the nullability of the input Avro type and resolve it when it is nullable. The first
+ * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second
+ * return value is the possibly resolved type.
+ */
+ private def resolveAvroType(avroType: Schema): (Boolean, Schema) = {
+ if (avroType.getType == Type.UNION) {
+ val fields = avroType.getTypes.asScala
+ val actualType = fields.filter(_.getType != Type.NULL)
+ if (fields.length != 2 || actualType.length != 1) {
+ throw new UnsupportedAvroTypeException(
+ s"Unsupported Avro UNION type $avroType: Only UNION of a null type and a non-null " +
+ "type is supported")
+ }
+ (true, actualType.head)
+ } else {
+ (false, avroType)
+ }
+ }
+
+ /**
+ * log a warning message if the nullability for Avro and catalyst types are different.
+ */
+ private def warnNullabilityDifference(avroNullable: Boolean, catalystNullable: Boolean): Unit = {
+ if (avroNullable && !catalystNullable) {
+ logWarning("Writing Avro files with nullable Avro schema and non-nullable catalyst schema.")
+ }
+ if (!avroNullable && catalystNullable) {
+ logWarning("Writing Avro files with non-nullable Avro schema and nullable catalyst " +
+ "schema will throw runtime exception if there is a record with null value.")
+ }
+ }
+}
+
+object AvroSerializer {
+
+ // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making [[AvroSerializer]] implementation
+ // (which relies on it) compatible only with the exact same version of [[DataSourceUtils]].
+ // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.2.x branch,
+ // we preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.2.0 as well as
+ // w/ Spark >= 3.2.1
+ //
+ // [1] https://github.com/apache/spark/pull/34978
+
+ def createDateRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value,
+ format: String): Int => Int = rebaseMode match {
+ case LegacyBehaviorPolicy.EXCEPTION => days: Int =>
+ if (days < RebaseDateTime.lastSwitchGregorianDay) {
+ throw DataSourceUtils.newRebaseExceptionInWrite(format)
+ }
+ days
+ case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays
+ case LegacyBehaviorPolicy.CORRECTED => identity[Int]
+ }
+
+ def createTimestampRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value,
+ format: String): Long => Long = rebaseMode match {
+ case LegacyBehaviorPolicy.EXCEPTION => micros: Long =>
+ if (micros < RebaseDateTime.lastSwitchGregorianTs) {
+ throw DataSourceUtils.newRebaseExceptionInWrite(format)
+ }
+ micros
+ case LegacyBehaviorPolicy.LEGACY =>
+ val timeZone = SQLConf.get.sessionLocalTimeZone
+ RebaseDateTime.rebaseGregorianToJulianMicros(TimeZone.getTimeZone(timeZone), _)
+ case LegacyBehaviorPolicy.CORRECTED => identity[Long]
+ }
+
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
new file mode 100644
index 000000000..f63133795
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Schema
+import org.apache.spark.sql.internal.SQLConf
+
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+/**
+ * NOTE: This code is borrowed from Spark 3.2.1
+ * This code is borrowed, so that we can better control compatibility w/in Spark minor
+ * branches (3.2.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+private[avro] object AvroUtils {
+
+ /**
+ * Wraps an Avro Schema object so that field lookups are faster.
+ *
+ * @param avroSchema The schema in which to search for fields. Must be of type RECORD.
+ * @param avroPath The seq of parent field names leading to `avroSchema`.
+ * @param positionalFieldMatch If true, perform field matching in a positional fashion
+ * (structural comparison between schemas, ignoring names);
+ * otherwise, perform field matching using field names.
+ */
+ class AvroSchemaHelper(avroSchema: Schema,
+ avroPath: Seq[String],
+ positionalFieldMatch: Boolean) {
+ if (avroSchema.getType != Schema.Type.RECORD) {
+ throw new IncompatibleSchemaException(
+ s"Attempting to treat ${avroSchema.getName} as a RECORD, but it was: ${avroSchema.getType}")
+ }
+
+ private[this] val avroFieldArray = avroSchema.getFields.asScala.toArray
+ private[this] val fieldMap = avroSchema.getFields.asScala
+ .groupBy(_.name.toLowerCase(Locale.ROOT))
+ .mapValues(_.toSeq) // toSeq needed for scala 2.13
+
+ /**
+ * Extract a single field from the contained avro schema which has the desired field name,
+ * performing the matching with proper case sensitivity according to SQLConf.resolver.
+ *
+ * @param name The name of the field to search for.
+ * @return `Some(match)` if a matching Avro field is found, otherwise `None`.
+ */
+ private[avro] def getFieldByName(name: String): Option[Schema.Field] = {
+
+ // get candidates, ignoring case of field name
+ val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty)
+
+ // search candidates, taking into account case sensitivity settings
+ candidates.filter(f => SQLConf.get.resolver(f.name(), name)) match {
+ case Seq(avroField) => Some(avroField)
+ case Seq() => None
+ case matches => throw new IncompatibleSchemaException(s"Searching for '$name' in Avro " +
+ s"schema at ${toFieldStr(avroPath)} gave ${matches.size} matches. Candidates: " +
+ matches.map(_.name()).mkString("[", ", ", "]")
+ )
+ }
+ }
+
+ /** Get the Avro field corresponding to the provided Catalyst field name/position, if any. */
+ def getAvroField(fieldName: String, catalystPos: Int): Option[Schema.Field] = {
+ if (positionalFieldMatch) {
+ avroFieldArray.lift(catalystPos)
+ } else {
+ getFieldByName(fieldName)
+ }
+ }
+ }
+
+
+ /**
+ * Take a field's hierarchical names (see [[toFieldStr]]) and position, and convert it to a
+ * human-readable description of the field. Depending on the value of `positionalFieldMatch`,
+ * either the position or name will be emphasized (for true and false, respectively); both will
+ * be included in either case.
+ */
+ private[avro] def toFieldDescription(
+ names: Seq[String],
+ position: Int,
+ positionalFieldMatch: Boolean): String = if (positionalFieldMatch) {
+ s"field at position $position (${toFieldStr(names)})"
+ } else {
+ s"${toFieldStr(names)} (at position $position)"
+ }
+
+ /**
+ * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable
+ * string representing the field, like "field 'foo.bar'". If `names` is empty, the string
+ * "top-level record" is returned.
+ */
+ private[avro] def toFieldStr(names: Seq[String]): String = names match {
+ case Seq() => "top-level record"
+ case n => s"field '${n.mkString(".")}'"
+ }
+
+}
diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_2AvroDeserializer.scala
similarity index 60%
rename from hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3AvroDeserializer.scala
rename to hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_2AvroDeserializer.scala
index bd9ead5a7..0275e2f63 100644
--- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_2AvroDeserializer.scala
@@ -21,18 +21,10 @@ import org.apache.avro.Schema
import org.apache.hudi.HoodieSparkUtils
import org.apache.spark.sql.types.DataType
-class HoodieSpark3AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType)
+class HoodieSpark3_2AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType)
extends HoodieAvroDeserializer {
- // SPARK-34404: As of Spark3.2, there is no AvroDeserializer's constructor with Schema and DataType arguments.
- // So use the reflection to get AvroDeserializer instance.
- private val avroDeserializer = if (HoodieSparkUtils.isSpark3_2) {
- val constructor = classOf[AvroDeserializer].getConstructor(classOf[Schema], classOf[DataType], classOf[String])
- constructor.newInstance(rootAvroType, rootCatalystType, "EXCEPTION")
- } else {
- val constructor = classOf[AvroDeserializer].getConstructor(classOf[Schema], classOf[DataType])
- constructor.newInstance(rootAvroType, rootCatalystType)
- }
+ private val avroDeserializer = new AvroDeserializer(rootAvroType, rootCatalystType, "EXCEPTION")
def deserialize(data: Any): Option[Any] = avroDeserializer.deserialize(data)
}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_2AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_2AvroSerializer.scala
new file mode 100644
index 000000000..6e76ba68f
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_2AvroSerializer.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Schema
+import org.apache.spark.sql.types.DataType
+
+class HoodieSpark3_2AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean)
+ extends HoodieAvroSerializer {
+
+ val avroSerializer = new AvroSerializer(rootCatalystType, rootAvroType, nullable)
+
+ override def serialize(catalystData: Any): Any = avroSerializer.serialize(catalystData)
+}
diff --git a/hudi-utilities/pom.xml b/hudi-utilities/pom.xml
index 39510537b..7d297bb4c 100644
--- a/hudi-utilities/pom.xml
+++ b/hudi-utilities/pom.xml
@@ -233,12 +233,6 @@
-
- org.apache.spark
- spark-avro_${scala.binary.version}
- provided
-
-
org.apache.spark
spark-streaming_${scala.binary.version}
diff --git a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieDataTableValidator.java b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieDataTableValidator.java
index 0180fa0af..ef05bdc03 100644
--- a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieDataTableValidator.java
+++ b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieDataTableValidator.java
@@ -69,7 +69,6 @@ import java.util.stream.Stream;
* ```
* spark-submit \
* --class org.apache.hudi.utilities.HoodieDataTableValidator \
- * --packages org.apache.spark:spark-avro_2.11:2.4.4 \
* --master spark://xxxx:7077 \
* --driver-memory 1g \
* --executor-memory 1g \
@@ -85,7 +84,6 @@ import java.util.stream.Stream;
* ```
* spark-submit \
* --class org.apache.hudi.utilities.HoodieDataTableValidator \
- * --packages org.apache.spark:spark-avro_2.11:2.4.4 \
* --master spark://xxxx:7077 \
* --driver-memory 1g \
* --executor-memory 1g \
diff --git a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieMetadataTableValidator.java b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieMetadataTableValidator.java
index af0c10099..bed4c812c 100644
--- a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieMetadataTableValidator.java
+++ b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieMetadataTableValidator.java
@@ -92,7 +92,6 @@ import java.util.stream.Collectors;
* ```
* spark-submit \
* --class org.apache.hudi.utilities.HoodieMetadataTableValidator \
- * --packages org.apache.spark:spark-avro_2.11:2.4.4 \
* --master spark://xxxx:7077 \
* --driver-memory 1g \
* --executor-memory 1g \
@@ -111,7 +110,6 @@ import java.util.stream.Collectors;
* ```
* spark-submit \
* --class org.apache.hudi.utilities.HoodieMetadataTableValidator \
- * --packages org.apache.spark:spark-avro_2.11:2.4.4 \
* --master spark://xxxx:7077 \
* --driver-memory 1g \
* --executor-memory 1g \
diff --git a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieRepairTool.java b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieRepairTool.java
index 7d725ed6a..14b637d5e 100644
--- a/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieRepairTool.java
+++ b/hudi-utilities/src/main/java/org/apache/hudi/utilities/HoodieRepairTool.java
@@ -65,7 +65,6 @@ import java.util.stream.Collectors;
* --conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
* --conf spark.sql.catalogImplementation=hive \
* --conf spark.sql.extensions=org.apache.spark.sql.hudi.HoodieSparkSessionExtension \
- * --packages org.apache.spark:spark-avro_2.12:3.1.2 \
* $HUDI_DIR/packaging/hudi-utilities-bundle/target/hudi-utilities-bundle_2.12-0.11.0-SNAPSHOT.jar \
* --mode dry_run \
* --base-path base_path \
@@ -89,7 +88,6 @@ import java.util.stream.Collectors;
* --conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
* --conf spark.sql.catalogImplementation=hive \
* --conf spark.sql.extensions=org.apache.spark.sql.hudi.HoodieSparkSessionExtension \
- * --packages org.apache.spark:spark-avro_2.12:3.1.2 \
* $HUDI_DIR/packaging/hudi-utilities-bundle/target/hudi-utilities-bundle_2.12-0.11.0-SNAPSHOT.jar \
* --mode repair \
* --base-path base_path \
@@ -112,7 +110,6 @@ import java.util.stream.Collectors;
* --conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
* --conf spark.sql.catalogImplementation=hive \
* --conf spark.sql.extensions=org.apache.spark.sql.hudi.HoodieSparkSessionExtension \
- * --packages org.apache.spark:spark-avro_2.12:3.1.2 \
* $HUDI_DIR/packaging/hudi-utilities-bundle/target/hudi-utilities-bundle_2.12-0.11.0-SNAPSHOT.jar \
* --mode dry_run \
* --base-path base_path \
@@ -133,7 +130,6 @@ import java.util.stream.Collectors;
* --conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
* --conf spark.sql.catalogImplementation=hive \
* --conf spark.sql.extensions=org.apache.spark.sql.hudi.HoodieSparkSessionExtension \
- * --packages org.apache.spark:spark-avro_2.12:3.1.2 \
* $HUDI_DIR/packaging/hudi-utilities-bundle/target/hudi-utilities-bundle_2.12-0.11.0-SNAPSHOT.jar \
* --mode undo \
* --base-path base_path \
diff --git a/packaging/hudi-integ-test-bundle/pom.xml b/packaging/hudi-integ-test-bundle/pom.xml
index 78e76a32d..06c679c2a 100644
--- a/packaging/hudi-integ-test-bundle/pom.xml
+++ b/packaging/hudi-integ-test-bundle/pom.xml
@@ -176,6 +176,13 @@
+
+
+ org.apache.spark.sql.avro.
+ org.apache.hudi.org.apache.spark.sql.avro.
+
com.beust.jcommander.
org.apache.hudi.com.beust.jcommander.
diff --git a/packaging/hudi-spark-bundle/pom.xml b/packaging/hudi-spark-bundle/pom.xml
index fd79a3460..e77f2e6f9 100644
--- a/packaging/hudi-spark-bundle/pom.xml
+++ b/packaging/hudi-spark-bundle/pom.xml
@@ -108,7 +108,6 @@
com.yammer.metrics:metrics-core
com.google.guava:guava
- org.apache.spark:spark-avro_${scala.binary.version}
org.apache.hive:hive-common
org.apache.hive:hive-service
org.apache.hive:hive-service-rpc
@@ -135,6 +134,13 @@
+
+
+ org.apache.spark.sql.avro.
+ org.apache.hudi.org.apache.spark.sql.avro.
+
com.yammer.metrics.
org.apache.hudi.com.yammer.metrics.
@@ -162,10 +168,6 @@
org.apache.htrace.
org.apache.hudi.org.apache.htrace.
-
- org.apache.spark.sql.avro.
- ${spark.bundle.spark.shade.prefix}org.apache.spark.sql.avro.
-
org.apache.hive.jdbc.
${spark.bundle.hive.shade.prefix}org.apache.hive.jdbc.
@@ -208,7 +210,7 @@
com.google.common.
- ${spark.bundle.spark.shade.prefix}com.google.common.
+ org.apache.hudi.com.google.common.
-
- org.apache.spark
- spark-avro_${scala.binary.version}
- ${spark.bundle.avro.scope}
-
-
org.apache.parquet
@@ -441,12 +436,5 @@
org.apache.hudi.
-
- spark-shade-unbundle-avro
-
- provided
-
-
-
diff --git a/packaging/hudi-utilities-bundle/pom.xml b/packaging/hudi-utilities-bundle/pom.xml
index c46a6d7d6..2c025a955 100644
--- a/packaging/hudi-utilities-bundle/pom.xml
+++ b/packaging/hudi-utilities-bundle/pom.xml
@@ -168,6 +168,13 @@
+
+
+ org.apache.spark.sql.avro.
+ org.apache.hudi.org.apache.spark.sql.avro.
+
com.yammer.metrics.
org.apache.hudi.com.yammer.metrics.
diff --git a/pom.xml b/pom.xml
index 0c6708571..86a42160f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -154,8 +154,6 @@
${project.basedir}
provided
- compile
- org.apache.hudi.spark.
provided
-Xmx2g
@@ -603,14 +601,6 @@
test
-
-
- org.apache.spark
- spark-avro_${scala.binary.version}
- ${spark.version}
- provided
-
-
org.apache.flink
@@ -1678,7 +1668,7 @@
spark3.1.x
- 3.1.2
+ 3.1.3
${spark3.version}
${spark3.version}
${scala12.version}