1
0

[HUDI-4186] Support Hudi with Spark 3.3.0 (#5943)

Co-authored-by: Shawn Chang <yxchang@amazon.com>
This commit is contained in:
Shawn Chang
2022-07-27 14:47:49 -07:00
committed by GitHub
parent 924c30c7ea
commit cdaec5a8da
99 changed files with 10865 additions and 104 deletions

View File

@@ -21,8 +21,9 @@ This repo contains the code that integrate Hudi with Spark. The repo is split in
`hudi-spark`
`hudi-spark2`
`hudi-spark3`
`hudi-spark3.1.x`
`hudi-spark3.2.x`
`hudi-spark3.3.x`
`hudi-spark2-common`
`hudi-spark3-common`
`hudi-spark-common`
@@ -30,8 +31,9 @@ This repo contains the code that integrate Hudi with Spark. The repo is split in
* hudi-spark is the module that contains the code that both spark2 & spark3 version would share, also contains the antlr4
file that supports spark sql on spark 2.x version.
* hudi-spark2 is the module that contains the code that compatible with spark 2.x versions.
* hudi-spark3 is the module that contains the code that is compatible with Spark 3.2.0 (and above) versions.
* hudi-spark3.1.x is the module that contains the code that compatible with spark3.1.x and spark3.0.x version.
* hudi-spark3.1.x is the module that contains the code that compatible with spark3.1.x and spark3.0.x version.
* hudi-spark3.2.x is the module that contains the code that compatible with spark 3.2.x versions.
* hudi-spark3.3.x is the module that contains the code that compatible with spark 3.3.x+ versions.
* hudi-spark2-common is the module that contains the code that would be reused between spark2.x versions, right now the module
has no class since hudi only supports spark 2.4.4 version, and it acts as the placeholder when packaging hudi-spark-bundle module.
* hudi-spark3-common is the module that contains the code that would be reused between spark3.x versions.
@@ -50,7 +52,12 @@ has no class since hudi only supports spark 2.4.4 version, and it acts as the pl
| 3.1.2 | No |
| 3.2.0 | Yes |
### About upgrading Time Travel:
### To improve:
Spark 3.3 supports time travel syntax; see [SPARK-37219](https://issues.apache.org/jira/browse/SPARK-37219).
Once Spark 3.3 is released, the files in the following list will be removed:
* hudi-spark3's `HoodieSpark3_2ExtendedSqlAstBuilder.scala`, `HoodieSpark3_2ExtendedSqlParser.scala`, `TimeTravelRelation.scala`, `SqlBase.g4`, `HoodieSqlBase.g4`
* hudi-spark3.3.x's `HoodieSpark3_3ExtendedSqlAstBuilder.scala`, `HoodieSpark3_3ExtendedSqlParser.scala`, `TimeTravelRelation.scala`, `SqlBase.g4`, `HoodieSqlBase.g4`
Tracking Jira: [HUDI-4468](https://issues.apache.org/jira/browse/HUDI-4468)
Some other improvements undergoing:
* Port borrowed classes from Spark 3.3 [HUDI-4467](https://issues.apache.org/jira/browse/HUDI-4467)

View File

@@ -52,6 +52,8 @@ class BaseFileOnlyRelation(sqlContext: SQLContext,
globPaths: Seq[Path])
extends HoodieBaseRelation(sqlContext, metaClient, optParams, userSchema) with SparkAdapterSupport {
case class HoodieBaseFileSplit(filePartition: FilePartition) extends HoodieFileSplit
override type FileSplit = HoodieBaseFileSplit
// TODO(HUDI-3204) this is to override behavior (exclusively) for COW tables to always extract
@@ -97,7 +99,9 @@ class BaseFileOnlyRelation(sqlContext: SQLContext,
// back into the one expected by the caller
val projectedReader = projectReader(baseFileReader, requiredSchema.structTypeSchema)
new HoodieFileScanRDD(sparkSession, projectedReader.apply, fileSplits)
// SPARK-37273 FileScanRDD constructor changed in SPARK 3.3
sparkAdapter.createHoodieFileScanRDD(sparkSession, projectedReader.apply, fileSplits.map(_.filePartition), requiredSchema.structTypeSchema)
.asInstanceOf[HoodieUnsafeRDD]
}
protected def collectFileSplits(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[HoodieBaseFileSplit] = {

View File

@@ -44,15 +44,24 @@ import scala.collection.mutable.ListBuffer
object HoodieAnalysis {
type RuleBuilder = SparkSession => Rule[LogicalPlan]
def customOptimizerRules: Seq[RuleBuilder] =
def customOptimizerRules: Seq[RuleBuilder] = {
if (HoodieSparkUtils.gteqSpark3_1) {
val nestedSchemaPruningClass = "org.apache.spark.sql.execution.datasources.NestedSchemaPruning"
val nestedSchemaPruningRule = ReflectionUtils.loadClass(nestedSchemaPruningClass).asInstanceOf[Rule[LogicalPlan]]
val nestedSchemaPruningClass =
if (HoodieSparkUtils.gteqSpark3_3) {
"org.apache.spark.sql.execution.datasources.Spark33NestedSchemaPruning"
} else if (HoodieSparkUtils.gteqSpark3_2) {
"org.apache.spark.sql.execution.datasources.Spark32NestedSchemaPruning"
} else {
// spark 3.1
"org.apache.spark.sql.execution.datasources.Spark31NestedSchemaPruning"
}
val nestedSchemaPruningRule = ReflectionUtils.loadClass(nestedSchemaPruningClass).asInstanceOf[Rule[LogicalPlan]]
Seq(_ => nestedSchemaPruningRule)
} else {
Seq.empty
}
}
def customResolutionRules: Seq[RuleBuilder] = {
val rules: ListBuffer[RuleBuilder] = ListBuffer(
@@ -74,18 +83,21 @@ object HoodieAnalysis {
val spark3ResolveReferences: RuleBuilder =
session => ReflectionUtils.loadClass(spark3ResolveReferencesClass, session).asInstanceOf[Rule[LogicalPlan]]
val spark32ResolveAlterTableCommandsClass = "org.apache.spark.sql.hudi.ResolveHudiAlterTableCommandSpark32"
val spark32ResolveAlterTableCommands: RuleBuilder =
session => ReflectionUtils.loadClass(spark32ResolveAlterTableCommandsClass, session).asInstanceOf[Rule[LogicalPlan]]
val resolveAlterTableCommandsClass =
if (HoodieSparkUtils.gteqSpark3_3)
"org.apache.spark.sql.hudi.Spark33ResolveHudiAlterTableCommand"
else "org.apache.spark.sql.hudi.Spark32ResolveHudiAlterTableCommand"
val resolveAlterTableCommands: RuleBuilder =
session => ReflectionUtils.loadClass(resolveAlterTableCommandsClass, session).asInstanceOf[Rule[LogicalPlan]]
// NOTE: PLEASE READ CAREFULLY
//
// It's critical for this rules to follow in this order, so that DataSource V2 to V1 fallback
// is performed prior to other rules being evaluated
rules ++= Seq(dataSourceV2ToV1Fallback, spark3Analysis, spark3ResolveReferences, spark32ResolveAlterTableCommands)
rules ++= Seq(dataSourceV2ToV1Fallback, spark3Analysis, spark3ResolveReferences, resolveAlterTableCommands)
} else if (HoodieSparkUtils.gteqSpark3_1) {
val spark31ResolveAlterTableCommandsClass = "org.apache.spark.sql.hudi.ResolveHudiAlterTableCommand312"
val spark31ResolveAlterTableCommandsClass = "org.apache.spark.sql.hudi.Spark312ResolveHudiAlterTableCommand"
val spark31ResolveAlterTableCommands: RuleBuilder =
session => ReflectionUtils.loadClass(spark31ResolveAlterTableCommandsClass, session).asInstanceOf[Rule[LogicalPlan]]
@@ -421,12 +433,10 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
UpdateTable(table, resolvedAssignments, resolvedCondition)
// Resolve Delete Table
case DeleteFromTable(table, condition)
case dft @ DeleteFromTable(table, condition)
if sparkAdapter.isHoodieTable(table, sparkSession) && table.resolved =>
// Resolve condition
val resolvedCondition = condition.map(resolveExpressionFrom(table)(_))
// Return the resolved DeleteTable
DeleteFromTable(table, resolvedCondition)
val resolveExpression = resolveExpressionFrom(table, None)_
sparkAdapter.resolveDeleteFromTable(dft, resolveExpression)
// Append the meta field to the insert query to walk through the validate for the
// number of insert fields with the number of the target table fields.

View File

@@ -21,6 +21,7 @@ import org.apache.hudi.SparkAdapterSupport
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._
import org.apache.spark.sql.hudi.ProvidesHoodieConfig
@@ -36,9 +37,9 @@ case class DeleteHoodieTableCommand(deleteTable: DeleteFromTable) extends Hoodie
// Remove meta fields from the data frame
var df = removeMetaFields(Dataset.ofRows(sparkSession, table))
if (deleteTable.condition.isDefined) {
df = df.filter(Column(deleteTable.condition.get))
}
// SPARK-38626 DeleteFromTable.condition is changed from Option[Expression] to Expression in Spark 3.3
val condition = sparkAdapter.extractCondition(deleteTable)
if (condition != null) df = df.filter(Column(condition))
val hoodieCatalogTable = HoodieCatalogTable(sparkSession, tableId)
val config = buildHoodieDeleteTableConfig(hoodieCatalogTable, sparkSession)

View File

@@ -57,6 +57,14 @@ class HoodieCommonSqlParser(session: SparkSession, delegate: ParserInterface)
override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)
/* SPARK-37266 added parseQuery to ParserInterface in Spark 3.3.0. This is a patch to prevent
attackers from tampering with the text of a persisted view; it won't be called in older Spark versions.
Don't mark this as override, for backward compatibility.
Can't use sparkExtendedParser directly here for the same reason. */
def parseQuery(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
sparkAdapter.getQueryParserFromExtendedSqlParser(session, delegate, sqlText)
}
def parseRawDataType(sqlText : String) : DataType = {
throw new UnsupportedOperationException(s"Unsupported parseRawDataType method")
}

View File

@@ -139,9 +139,8 @@ class HoodieSparkSqlTestBase extends FunSuite with BeforeAndAfterAll {
try {
spark.sql(sql)
} catch {
case e: Throwable =>
assertResult(true)(e.getMessage.contains(errorMsg))
hasException = true
case e: Throwable if e.getMessage.contains(errorMsg) => hasException = true
case f: Throwable => fail("Exception should contain: " + errorMsg + ", error message: " + f.getMessage, f)
}
assertResult(true)(hasException)
}

View File

@@ -18,6 +18,7 @@
package org.apache.spark.sql.hudi
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.HoodieSparkUtils
import org.apache.hudi.common.util.PartitionPathEncodeUtils
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.keygen.{ComplexKeyGenerator, SimpleKeyGenerator}
@@ -210,8 +211,14 @@ class TestAlterTableDropPartition extends HoodieSparkSqlTestBase {
spark.sql(s"""insert into $tableName values (1, "z3", "v1", "2021-10-01"), (2, "l4", "v1", "2021-10-02")""")
// specify duplicate partition columns
checkExceptionContain(s"alter table $tableName drop partition (dt='2021-10-01', dt='2021-10-02')")(
"Found duplicate keys 'dt'")
if (HoodieSparkUtils.gteqSpark3_3) {
checkExceptionContain(s"alter table $tableName drop partition (dt='2021-10-01', dt='2021-10-02')")(
"Found duplicate keys `dt`")
} else {
checkExceptionContain(s"alter table $tableName drop partition (dt='2021-10-01', dt='2021-10-02')")(
"Found duplicate keys 'dt'")
}
// drop 2021-10-01 partition
spark.sql(s"alter table $tableName drop partition (dt='2021-10-01')")

View File

@@ -31,6 +31,7 @@ class TestNestedSchemaPruningOptimization extends HoodieSparkSqlTestBase with Sp
val explainCommand = sparkAdapter.getCatalystPlanUtils.createExplainCommand(plan, extended = true)
executePlan(explainCommand)
.executeCollect()
.map(_.getString(0))
.mkString("\n")
}

View File

@@ -82,7 +82,11 @@ class TestCallCommandParser extends HoodieSparkSqlTestBase {
}
test("Test Call Parse Error") {
checkParseExceptionContain("CALL cat.system radish kebab")("mismatched input 'CALL' expecting")
if (HoodieSparkUtils.gteqSpark3_3) {
checkParseExceptionContain("CALL cat.system radish kebab")("Syntax error at or near 'CALL'")
} else {
checkParseExceptionContain("CALL cat.system radish kebab")("mismatched input 'CALL' expecting")
}
}
test("Test Call Produce with semicolon") {
@@ -110,9 +114,8 @@ class TestCallCommandParser extends HoodieSparkSqlTestBase {
try {
parser.parsePlan(sql)
} catch {
case e: Throwable =>
assertResult(true)(e.getMessage.contains(errorMsg))
hasException = true
case e: Throwable if e.getMessage.contains(errorMsg) => hasException = true
case f: Throwable => fail("Exception should contain: " + errorMsg + ", error message: " + f.getMessage, f)
}
assertResult(true)(hasException)
}

View File

@@ -18,16 +18,17 @@
package org.apache.hudi
import org.apache.hudi.HoodieUnsafeRDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.types.StructType
case class HoodieBaseFileSplit(filePartition: FilePartition) extends HoodieFileSplit
class HoodieFileScanRDD(@transient private val sparkSession: SparkSession,
read: PartitionedFile => Iterator[InternalRow],
@transient fileSplits: Seq[HoodieBaseFileSplit])
extends FileScanRDD(sparkSession, read, fileSplits.map(_.filePartition))
class Spark2HoodieFileScanRDD(@transient private val sparkSession: SparkSession,
read: PartitionedFile => Iterator[InternalRow],
@transient filePartitions: Seq[FilePartition])
extends FileScanRDD(sparkSession, read, filePartitions)
with HoodieUnsafeRDD {
override final def collect(): Array[InternalRow] = super[HoodieUnsafeRDD].collect()

View File

@@ -19,22 +19,23 @@
package org.apache.spark.sql.adapter
import org.apache.avro.Schema
import org.apache.hudi.Spark2RowSerDe
import org.apache.hudi.{Spark2HoodieFileScanRDD, Spark2RowSerDe}
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.spark.sql.avro._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedPredicate}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedPredicate}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Join, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{Command, InsertIntoTable, Join, LogicalPlan, DeleteFromTable}
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark24HoodieParquetFileFormat}
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile, Spark2ParsePartitionUtil, SparkParsePartitionUtil}
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile, Spark2ParsePartitionUtil, SparkParsePartitionUtil}
import org.apache.spark.sql.hudi.SparkAdapter
import org.apache.spark.sql.hudi.parser.HoodieSpark2ExtendedSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieCatalystPlansUtils, HoodieSpark2CatalystExpressionUtils, HoodieSpark2CatalystPlanUtils, Row, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel._
@@ -122,6 +123,30 @@ class Spark2Adapter extends SparkAdapter {
InterpretedPredicate.create(e)
}
override def createHoodieFileScanRDD(sparkSession: SparkSession,
readFunction: PartitionedFile => Iterator[InternalRow],
filePartitions: Seq[FilePartition],
readDataSchema: StructType,
metadataColumns: Seq[AttributeReference] = Seq.empty): FileScanRDD = {
new Spark2HoodieFileScanRDD(sparkSession, readFunction, filePartitions)
}
override def resolveDeleteFromTable(deleteFromTable: Command,
resolveExpression: Expression => Expression): DeleteFromTable = {
val deleteFromTableCommand = deleteFromTable.asInstanceOf[DeleteFromTable]
val resolvedCondition = deleteFromTableCommand.condition.map(resolveExpression)
DeleteFromTable(deleteFromTableCommand.table, resolvedCondition)
}
override def extractCondition(deleteFromTable: Command): Expression = {
deleteFromTable.asInstanceOf[DeleteFromTable].condition.getOrElse(null)
}
override def getQueryParserFromExtendedSqlParser(session: SparkSession, delegate: ParserInterface,
sqlText: String): LogicalPlan = {
throw new UnsupportedOperationException(s"Unsupported parseQuery method in Spark earlier than Spark 3.3.0")
}
override def convertStorageLevelToString(level: StorageLevel): String = level match {
case NONE => "NONE"
case DISK_ONLY => "DISK_ONLY"

View File

@@ -52,7 +52,7 @@ public class ReflectUtil {
public static DateFormatter getDateFormatter(ZoneId zoneId) {
try {
ClassLoader loader = Thread.currentThread().getContextClassLoader();
if (HoodieSparkUtils.isSpark3_2()) {
if (HoodieSparkUtils.gteqSpark3_2()) {
Class clazz = loader.loadClass(DateFormatter.class.getName());
Method applyMethod = clazz.getDeclaredMethod("apply");
applyMethod.setAccessible(true);

View File

@@ -19,7 +19,6 @@ package org.apache.spark.sql.adapter
import org.apache.hudi.Spark3RowSerDe
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.spark.SPARK_VERSION
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.{HoodieAvroSchemaConverters, HoodieSparkAvroSchemaConverters}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -83,22 +82,6 @@ abstract class BaseSpark3Adapter extends SparkAdapter with Logging {
}
}
override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = {
// Since Spark 3.2.1 supports DataSource V2, we need a new SqlParser to handle DDL statements.
if (SPARK_VERSION.startsWith("3.1")) {
val loadClassName = "org.apache.spark.sql.parser.HoodieSpark312ExtendedSqlParser"
Some {
(spark: SparkSession, delegate: ParserInterface) => {
val clazz = Class.forName(loadClassName, true, Thread.currentThread().getContextClassLoader)
val ctor = clazz.getConstructors.head
ctor.newInstance(spark, delegate).asInstanceOf[ParserInterface]
}
}
} else {
None
}
}
override def createInterpretedPredicate(e: Expression): InterpretedPredicate = {
Predicate.createInterpreted(e)
}

View File

@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi
import org.apache.hudi.HoodieUnsafeRDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.types.StructType
/**
 * Spark 3.1-specific [[FileScanRDD]] used by Hudi.
 *
 * Extends Spark's [[FileScanRDD]] and mixes in [[HoodieUnsafeRDD]], explicitly routing
 * `collect()` through the [[HoodieUnsafeRDD]] implementation rather than the default
 * inherited one.
 *
 * @param sparkSession   the active [[SparkSession]] (transient: not shipped to executors)
 * @param read           function producing the row iterator for a single [[PartitionedFile]]
 * @param filePartitions the file partitions this RDD scans (transient: not shipped to executors)
 */
class Spark31HoodieFileScanRDD(@transient private val sparkSession: SparkSession,
read: PartitionedFile => Iterator[InternalRow],
@transient filePartitions: Seq[FilePartition])
extends FileScanRDD(sparkSession, read, filePartitions)
with HoodieUnsafeRDD {
// Disambiguate between the two parents: force HoodieUnsafeRDD's collect().
override final def collect(): Array[InternalRow] = super[HoodieUnsafeRDD].collect()
}

View File

@@ -18,12 +18,19 @@
package org.apache.spark.sql.adapter
import org.apache.hudi.Spark31HoodieFileScanRDD
import org.apache.avro.Schema
import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieCatalystPlansUtils, HoodieSpark31CatalystExpressionUtils, HoodieSpark31CatalystPlanUtils}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSerializer, HoodieSpark3_1AvroDeserializer, HoodieSpark3_1AvroSerializer}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.parser.HoodieSpark3_1ExtendedSqlParser
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark31HoodieParquetFileFormat}
import org.apache.spark.sql.hudi.SparkAdapter
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieCatalystPlansUtils, HoodieSpark31CatalystExpressionUtils, HoodieSpark31CatalystPlanUtils, SparkSession}
/**
* Implementation of [[SparkAdapter]] for Spark 3.1.x
@@ -40,7 +47,33 @@ class Spark3_1Adapter extends BaseSpark3Adapter {
override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
new HoodieSpark3_1AvroDeserializer(rootAvroType, rootCatalystType)
override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = {
// Since Spark 3.2.1 supports DataSource V2, we need a new SqlParser to handle DDL statements.
Some(
(spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_1ExtendedSqlParser(spark, delegate)
)
}
override def createHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] = {
Some(new Spark31HoodieParquetFileFormat(appendPartitionValues))
}
override def createHoodieFileScanRDD(sparkSession: SparkSession,
readFunction: PartitionedFile => Iterator[InternalRow],
filePartitions: Seq[FilePartition],
readDataSchema: StructType,
metadataColumns: Seq[AttributeReference] = Seq.empty): FileScanRDD = {
new Spark31HoodieFileScanRDD(sparkSession, readFunction, filePartitions)
}
override def resolveDeleteFromTable(deleteFromTable: Command,
resolveExpression: Expression => Expression): DeleteFromTable = {
val deleteFromTableCommand = deleteFromTable.asInstanceOf[DeleteFromTable]
val resolvedCondition = deleteFromTableCommand.condition.map(resolveExpression)
DeleteFromTable(deleteFromTableCommand.table, resolvedCondition)
}
override def extractCondition(deleteFromTable: Command): Expression = {
deleteFromTable.asInstanceOf[DeleteFromTable].condition.getOrElse(null)
}
}

View File

@@ -37,7 +37,7 @@ import org.apache.spark.sql.util.SchemaUtils.restoreOriginalOutputNames
* NOTE: This class is borrowed from Spark 3.2.1, with modifications adapting it to handle [[HoodieBaseRelation]],
* instead of [[HadoopFsRelation]]
*/
class NestedSchemaPruning extends Rule[LogicalPlan] {
class Spark31NestedSchemaPruning extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalyst.expressions.SchemaPruning._
override def apply(plan: LogicalPlan): LogicalPlan =

View File

@@ -39,7 +39,7 @@ import scala.collection.mutable
* for alter table column commands.
* TODO: we should remove this file when we support datasourceV2 for hoodie on spark3.1x
*/
case class ResolveHudiAlterTableCommand312(sparkSession: SparkSession) extends Rule[LogicalPlan] {
case class Spark312ResolveHudiAlterTableCommand(sparkSession: SparkSession) extends Rule[LogicalPlan] {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case add @ HoodieAlterTableAddColumnsStatement(asTable(table), cols) =>

View File

@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.execution.{SparkSqlAstBuilder, SparkSqlParser}
// TODO: we should remove this file when we support datasourceV2 for hoodie on spark3.1x
class HoodieSpark312ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) extends SparkSqlParser with Logging {
class HoodieSpark3_1ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) extends SparkSqlParser with Logging {
override val astBuilder: SparkSqlAstBuilder = new HoodieSpark312SqlAstBuilder(session)
}

View File

@@ -0,0 +1,335 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>hudi-spark-datasource</artifactId>
<groupId>org.apache.hudi</groupId>
<version>0.12.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>hudi-spark3.2.x_2.12</artifactId>
<version>0.12.0-SNAPSHOT</version>
<name>hudi-spark3.2.x_2.12</name>
<packaging>jar</packaging>
<properties>
<main.basedir>${project.parent.parent.basedir}</main.basedir>
</properties>
<build>
<resources>
<resource>
<directory>src/main/resources</directory>
</resource>
</resources>
<pluginManagement>
<plugins>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>${scala-maven-plugin.version}</version>
<configuration>
<args>
<arg>-nobootcp</arg>
</args>
<checkMultipleScalaVersions>false</checkMultipleScalaVersions>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
</plugins>
</pluginManagement>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<executions>
<execution>
<id>copy-dependencies</id>
<phase>prepare-package</phase>
<goals>
<goal>copy-dependencies</goal>
</goals>
<configuration>
<outputDirectory>${project.build.directory}/lib</outputDirectory>
<overWriteReleases>true</overWriteReleases>
<overWriteSnapshots>true</overWriteSnapshots>
<overWriteIfNewer>true</overWriteIfNewer>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<args>
<arg>-nobootcp</arg>
<arg>-target:jvm-1.8</arg>
</args>
</configuration>
<executions>
<execution>
<id>scala-compile-first</id>
<phase>process-resources</phase>
<goals>
<goal>add-source</goal>
<goal>compile</goal>
</goals>
</execution>
<execution>
<id>scala-test-compile</id>
<phase>process-test-resources</phase>
<goals>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<executions>
<execution>
<phase>compile</phase>
<goals>
<goal>compile</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>test-jar</goal>
</goals>
<phase>test-compile</phase>
</execution>
</executions>
<configuration>
<skip>false</skip>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>${skip.hudi-spark3.unit.tests}</skipTests>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.rat</groupId>
<artifactId>apache-rat-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.antlr</groupId>
<artifactId>antlr4-maven-plugin</artifactId>
<version>${antlr.version}</version>
<executions>
<execution>
<goals>
<goal>antlr4</goal>
</goals>
</execution>
</executions>
<configuration>
<visitor>true</visitor>
<listener>true</listener>
<sourceDirectory>../hudi-spark3.2.x/src/main/antlr4</sourceDirectory>
<libDirectory>../hudi-spark3.2.x/src/main/antlr4/imports</libDirectory>
</configuration>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala12.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>${spark32.version}</version>
<scope>provided</scope>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_2.12</artifactId>
<version>${spark32.version}</version>
<scope>provided</scope>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>${spark32.version}</version>
<scope>provided</scope>
<optional>true</optional>
<exclusions>
<exclusion>
<groupId>*</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${fasterxml.spark3.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${fasterxml.spark3.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${fasterxml.spark3.version}</version>
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-client</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.json4s</groupId>
<artifactId>json4s-jackson_${scala.binary.version}</artifactId>
<version>3.7.0-M11</version>
<exclusions>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark3-common</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- Hoodie - Test -->
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-client-common</artifactId>
<version>${project.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-client</artifactId>
<version>${project.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-common</artifactId>
<version>${project.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi
import org.apache.hudi.HoodieUnsafeRDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.types.StructType
/**
 * Spark 3.2-specific [[FileScanRDD]] used by Hudi relations.
 *
 * Extends Spark's own [[FileScanRDD]] (so all partition/file iteration behavior is
 * inherited unchanged) and mixes in [[HoodieUnsafeRDD]] solely to route `collect()`
 * through Hudi's implementation via the explicit `super[HoodieUnsafeRDD]` qualifier.
 *
 * @param sparkSession   active session; `@transient` because it must not be serialized
 *                       with the RDD to executors
 * @param read           function producing a row iterator for a single [[PartitionedFile]]
 * @param filePartitions file partitions to scan; `@transient` for the same reason
 */
class Spark32HoodieFileScanRDD(@transient private val sparkSession: SparkSession,
                               read: PartitionedFile => Iterator[InternalRow],
                               @transient filePartitions: Seq[FilePartition])
  extends FileScanRDD(sparkSession, read, filePartitions)
    with HoodieUnsafeRDD {

  // Deliberately bypass FileScanRDD.collect in favor of HoodieUnsafeRDD's version;
  // `final` prevents subclasses from undoing this dispatch choice.
  override final def collect(): Array[InternalRow] = super[HoodieUnsafeRDD].collect()
}

View File

@@ -17,12 +17,17 @@
package org.apache.spark.sql.adapter
import org.apache.hudi.Spark32HoodieFileScanRDD
import org.apache.avro.Schema
import org.apache.spark.sql.avro._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable, LogicalPlan}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark32HoodieParquetFileFormat}
import org.apache.spark.sql.parser.HoodieSpark3_2ExtendedSqlParser
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql._
/**
@@ -30,16 +35,16 @@ import org.apache.spark.sql._
*/
class Spark3_2Adapter extends BaseSpark3Adapter {
override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark32CatalystExpressionUtils
override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark32CatalystPlanUtils
override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
new HoodieSpark3_2AvroSerializer(rootCatalystType, rootAvroType, nullable)
override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
new HoodieSpark3_2AvroDeserializer(rootAvroType, rootCatalystType)
override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark32CatalystExpressionUtils
override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark32CatalystPlanUtils
override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = {
Some(
(spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_2ExtendedSqlParser(spark, delegate)
@@ -49,4 +54,23 @@ class Spark3_2Adapter extends BaseSpark3Adapter {
override def createHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] = {
Some(new Spark32HoodieParquetFileFormat(appendPartitionValues))
}
  /**
   * Builds the Spark 3.2 flavor of Hudi's [[FileScanRDD]].
   *
   * NOTE: `readDataSchema` and `metadataColumns` are accepted to satisfy the adapter
   * interface but are NOT forwarded — [[Spark32HoodieFileScanRDD]] only takes the
   * session, the read function and the partitions (metadata columns are presumably
   * a Spark 3.3+ feature; confirm against the base adapter contract).
   */
  override def createHoodieFileScanRDD(sparkSession: SparkSession,
                                       readFunction: PartitionedFile => Iterator[InternalRow],
                                       filePartitions: Seq[FilePartition],
                                       readDataSchema: StructType,
                                       metadataColumns: Seq[AttributeReference] = Seq.empty): FileScanRDD = {
    new Spark32HoodieFileScanRDD(sparkSession, readFunction, filePartitions)
  }
  /**
   * Resolves the optional WHERE condition of a DELETE FROM command.
   *
   * Downcasts the generic [[Command]] to [[DeleteFromTable]] (callers are expected to
   * pass only that node type), applies `resolveExpression` to the condition if present,
   * and rebuilds the plan node with the resolved condition.
   */
  override def resolveDeleteFromTable(deleteFromTable: Command,
                                      resolveExpression: Expression => Expression): DeleteFromTable = {
    val deleteFromTableCommand = deleteFromTable.asInstanceOf[DeleteFromTable]
    val resolvedCondition = deleteFromTableCommand.condition.map(resolveExpression)
    DeleteFromTable(deleteFromTableCommand.table, resolvedCondition)
  }
override def extractCondition(deleteFromTable: Command): Expression = {
deleteFromTable.asInstanceOf[DeleteFromTable].condition.getOrElse(null)
}
}

View File

@@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources
import org.apache.hudi.HoodieBaseRelation
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, Expression, NamedExpression, ProjectionOverSchema}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.util.SchemaUtils.restoreOriginalOutputNames
/**
* Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation.
* By "physical column", we mean a column as defined in the data source format like Parquet format
* or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL
* column, and a nested Parquet column corresponds to a [[StructField]].
*
* NOTE: This class is borrowed from Spark 3.2.1, with modifications adapting it to handle [[HoodieBaseRelation]],
* instead of [[HadoopFsRelation]]
*/
class Spark32NestedSchemaPruning extends Rule[LogicalPlan] {
  // Pulls in identifyRootFields / pruneDataSchema / RootField from Spark's own
  // SchemaPruning machinery so we reuse the stock field-requirement analysis.
  import org.apache.spark.sql.catalyst.expressions.SchemaPruning._

  // Entry point: pruning is gated on spark.sql.optimizer.nestedSchemaPruning.enabled.
  override def apply(plan: LogicalPlan): LogicalPlan =
    if (conf.nestedSchemaPruningEnabled) {
      apply0(plan)
    } else {
      plan
    }

  private def apply0(plan: LogicalPlan): LogicalPlan =
    plan transformDown {
      case op @ PhysicalOperation(projects, filters,
        // NOTE: This is modified to accommodate for Hudi's custom relations, given that original
        //       [[NestedSchemaPruning]] rule is tightly coupled w/ [[HadoopFsRelation]]
        // TODO generalize to any file-based relation
        l @ LogicalRelation(relation: HoodieBaseRelation, _, _, _))
        if relation.canPruneRelationSchema =>

        prunePhysicalColumns(l.output, projects, filters, relation.dataSchema,
          prunedDataSchema => {
            val prunedRelation =
              relation.updatePrunedDataSchema(prunedSchema = prunedDataSchema)
            buildPrunedRelation(l, prunedRelation)
          }).getOrElse(op)
    }

  /**
   * This method returns optional logical plan. `None` is returned if no nested field is required or
   * all nested fields are required.
   *
   * @param output                attributes exposed by the relation being pruned
   * @param projects              projection expressions above the relation
   * @param filters               filter predicates above the relation
   * @param dataSchema            full (un-pruned) data schema of the relation
   * @param outputRelationBuilder callback producing a new LogicalRelation for a pruned schema
   */
  private def prunePhysicalColumns(output: Seq[AttributeReference],
                                   projects: Seq[NamedExpression],
                                   filters: Seq[Expression],
                                   dataSchema: StructType,
                                   outputRelationBuilder: StructType => LogicalRelation): Option[LogicalPlan] = {
    val (normalizedProjects, normalizedFilters) =
      normalizeAttributeRefNames(output, projects, filters)

    val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)

    // If requestedRootFields includes a nested field, continue. Otherwise,
    // return op
    if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
      val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields)

      // If the data schema is different from the pruned data schema, continue. Otherwise,
      // return op. We effect this comparison by counting the number of "leaf" fields in
      // each schemata, assuming the fields in prunedDataSchema are a subset of the fields
      // in dataSchema.
      if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
        val prunedRelation = outputRelationBuilder(prunedDataSchema)
        val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)

        Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
          prunedRelation, projectionOverSchema))
      } else {
        None
      }
    } else {
      None
    }
  }

  /**
   * Normalizes the names of the attribute references in the given projects and filters to reflect
   * the names in the given logical relation. This makes it possible to compare attributes and
   * fields by name. Returns a tuple with the normalized projects and filters, respectively.
   */
  private def normalizeAttributeRefNames(output: Seq[AttributeReference],
                                         projects: Seq[NamedExpression],
                                         filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
    // Map from expression id to the relation's canonical attribute name.
    val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
    val normalizedProjects = projects.map(_.transform {
      case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
        att.withName(normalizedAttNameMap(att.exprId))
      // Renaming preserves NamedExpression-ness; the match below re-asserts the type.
    }).map { case expr: NamedExpression => expr }
    val normalizedFilters = filters.map(_.transform {
      case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
        att.withName(normalizedAttNameMap(att.exprId))
    })
    (normalizedProjects, normalizedFilters)
  }

  /**
   * Builds the new output [[Project]] Spark SQL operator that has the `leafNode`.
   */
  private def buildNewProjection(projects: Seq[NamedExpression],
                                 normalizedProjects: Seq[NamedExpression],
                                 filters: Seq[Expression],
                                 prunedRelation: LogicalRelation,
                                 projectionOverSchema: ProjectionOverSchema): Project = {
    // Construct a new target for our projection by rewriting and
    // including the original filters where available
    val projectionChild =
      if (filters.nonEmpty) {
        val projectedFilters = filters.map(_.transformDown {
          case projectionOverSchema(expr) => expr
        })
        val newFilterCondition = projectedFilters.reduce(And)
        Filter(newFilterCondition, prunedRelation)
      } else {
        prunedRelation
      }

    // Construct the new projections of our Project by
    // rewriting the original projections
    val newProjects = normalizedProjects.map(_.transformDown {
      case projectionOverSchema(expr) => expr
    }).map { case expr: NamedExpression => expr }

    if (log.isDebugEnabled) {
      logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}")
    }

    // Restore the user-visible output names (normalization above may have changed them).
    Project(restoreOriginalOutputNames(newProjects, projects.map(_.name)), projectionChild)
  }

  /**
   * Builds a pruned logical relation from the output of the output relation and the schema of the
   * pruned base relation.
   */
  private def buildPrunedRelation(outputRelation: LogicalRelation,
                                  prunedBaseRelation: BaseRelation): LogicalRelation = {
    val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
    outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
  }

  // Prune the given output to make it consistent with `requiredSchema`.
  private def getPrunedOutput(output: Seq[AttributeReference],
                              requiredSchema: StructType): Seq[AttributeReference] = {
    // We need to replace the expression ids of the pruned relation output attributes
    // with the expression ids of the original relation output attributes so that
    // references to the original relation's output are not broken
    val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
    requiredSchema
      .toAttributes
      .map {
        case att if outputIdMap.contains(att.name) =>
          att.withExprId(outputIdMap(att.name))
        case att => att
      }
  }

  /**
   * Counts the "leaf" fields of the given dataType. Informally, this is the
   * number of fields of non-complex data type in the tree representation of
   * [[DataType]].
   */
  private def countLeaves(dataType: DataType): Int = {
    dataType match {
      case array: ArrayType => countLeaves(array.elementType)
      case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType)
      case struct: StructType =>
        struct.map(field => countLeaves(field.dataType)).sum
      case _ => 1
    }
  }
}

View File

@@ -31,7 +31,7 @@ import org.apache.spark.sql.hudi.command.{AlterTableCommand => HudiAlterTableCom
* Rule to mostly resolve, normalize and rewrite column names based on case sensitivity.
* for alter table column commands.
*/
class ResolveHudiAlterTableCommandSpark32(sparkSession: SparkSession) extends Rule[LogicalPlan] {
class Spark32ResolveHudiAlterTableCommand(sparkSession: SparkSession) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
if (schemaEvolutionEnabled) {

View File

@@ -0,0 +1,367 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.catalog
import org.apache.hadoop.fs.Path
import org.apache.hudi.exception.HoodieException
import org.apache.hudi.sql.InsertMode
import org.apache.hudi.sync.common.util.ConfigUtils
import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, SparkAdapterSupport}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable.needFilterProps
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, HoodieCatalogTable}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange, UpdateColumnComment, UpdateColumnType}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.hudi.analysis.HoodieV1OrV2Table
import org.apache.spark.sql.hudi.catalog.HoodieCatalog.convertTransforms
import org.apache.spark.sql.hudi.command._
import org.apache.spark.sql.hudi.{HoodieSqlCommonUtils, ProvidesHoodieConfig}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{Dataset, SaveMode, SparkSession, _}
import java.net.URI
import java.util
import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter}
import scala.collection.mutable
/**
 * Spark DataSource V2 catalog for Hudi tables.
 *
 * Delegates to the session catalog for non-Hudi tables and intercepts Hudi tables to
 * route DDL through Hudi's own SQL commands. Also implements [[StagingTableCatalog]]
 * so CTAS / REPLACE TABLE can be staged atomically.
 */
class HoodieCatalog extends DelegatingCatalogExtension
  with StagingTableCatalog
  with SparkAdapterSupport
  with ProvidesHoodieConfig {

  // NOTE(review): relies on an active SparkSession existing at catalog construction time.
  val spark: SparkSession = SparkSession.active

  /** Stages a CREATE TABLE: Hudi tables get a [[HoodieStagedTable]]; others delegate. */
  override def stageCreate(ident: Identifier, schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): StagedTable = {
    if (sparkAdapter.isHoodieTable(properties)) {
      val locUriAndTableType = deduceTableLocationURIAndTableType(ident, properties)
      HoodieStagedTable(ident, locUriAndTableType, this, schema, partitions,
        properties, TableCreationMode.STAGE_CREATE)
    } else {
      BasicStagedTable(
        ident,
        super.createTable(ident, schema, partitions, properties),
        this)
    }
  }

  /** Stages a REPLACE TABLE; for non-Hudi tables the old table is dropped eagerly. */
  override def stageReplace(ident: Identifier, schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): StagedTable = {
    if (sparkAdapter.isHoodieTable(properties)) {
      val locUriAndTableType = deduceTableLocationURIAndTableType(ident, properties)
      HoodieStagedTable(ident, locUriAndTableType, this, schema, partitions,
        properties, TableCreationMode.STAGE_REPLACE)
    } else {
      super.dropTable(ident)
      BasicStagedTable(
        ident,
        super.createTable(ident, schema, partitions, properties),
        this)
    }
  }

  /** Stages CREATE OR REPLACE TABLE; a missing old table is tolerated. */
  override def stageCreateOrReplace(ident: Identifier,
                                    schema: StructType,
                                    partitions: Array[Transform],
                                    properties: util.Map[String, String]): StagedTable = {
    if (sparkAdapter.isHoodieTable(properties)) {
      val locUriAndTableType = deduceTableLocationURIAndTableType(ident, properties)
      HoodieStagedTable(ident, locUriAndTableType, this, schema, partitions,
        properties, TableCreationMode.CREATE_OR_REPLACE)
    } else {
      try super.dropTable(ident) catch {
        case _: NoSuchTableException => // ignore the exception
      }
      BasicStagedTable(
        ident,
        super.createTable(ident, schema, partitions, properties),
        this)
    }
  }

  /**
   * Loads a table, wrapping Hudi V1 tables into [[HoodieInternalV2Table]].
   * The table comment (if any) is copied into the properties map so the V2 layer sees it.
   */
  override def loadTable(ident: Identifier): Table = {
    super.loadTable(ident) match {
      case V1Table(catalogTable0) if sparkAdapter.isHoodieTable(catalogTable0) =>
        val catalogTable = catalogTable0.comment match {
          case Some(v) =>
            val newProps = catalogTable0.properties + (TableCatalog.PROP_COMMENT -> v)
            catalogTable0.copy(properties = newProps)
          case _ =>
            catalogTable0
        }

        val v2Table = HoodieInternalV2Table(
          spark = spark,
          path = catalogTable.location.toString,
          catalogTable = Some(catalogTable),
          tableIdentifier = Some(ident.toString))

        val schemaEvolutionEnabled: Boolean = spark.sessionState.conf.getConfString(DataSourceReadOptions.SCHEMA_EVOLUTION_ENABLED.key,
          DataSourceReadOptions.SCHEMA_EVOLUTION_ENABLED.defaultValue.toString).toBoolean

        // NOTE: PLEASE READ CAREFULLY
        //
        // Since Hudi relations don't currently implement DS V2 Read API, we by default fallback to V1 here.
        // Such fallback will have considerable performance impact, therefore it's only performed in cases
        // where V2 API have to be used. Currently only such use-case is using of Schema Evolution feature
        //
        // Check out HUDI-4178 for more details
        if (schemaEvolutionEnabled) {
          v2Table
        } else {
          v2Table.v1TableWrapper
        }

      case t => t
    }
  }

  /** Creates a table; Hudi tables go through [[createHoodieTable]], others delegate. */
  override def createTable(ident: Identifier,
                           schema: StructType,
                           partitions: Array[Transform],
                           properties: util.Map[String, String]): Table = {
    if (sparkAdapter.isHoodieTable(properties)) {
      val locUriAndTableType = deduceTableLocationURIAndTableType(ident, properties)
      createHoodieTable(ident, schema, locUriAndTableType, partitions, properties,
        Map.empty, Option.empty, TableCreationMode.CREATE)
    } else {
      super.createTable(ident, schema, partitions, properties)
    }
  }

  override def tableExists(ident: Identifier): Boolean = super.tableExists(ident)

  /** Drops a table; Hudi tables are removed via Hudi's DROP command (no purge). */
  override def dropTable(ident: Identifier): Boolean = {
    val table = loadTable(ident)
    table match {
      case HoodieV1OrV2Table(_) =>
        DropHoodieTableCommand(ident.asTableIdentifier, ifExists = true, isView = false, purge = false).run(spark)
        true
      case _ => super.dropTable(ident)
    }
  }

  /** Like [[dropTable]] but purges the underlying data as well. */
  override def purgeTable(ident: Identifier): Boolean = {
    val table = loadTable(ident)
    table match {
      case HoodieV1OrV2Table(_) =>
        DropHoodieTableCommand(ident.asTableIdentifier, ifExists = true, isView = false, purge = true).run(spark)
        true
      case _ => super.purgeTable(ident)
    }
  }

  @throws[NoSuchTableException]
  @throws[TableAlreadyExistsException]
  override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
    loadTable(oldIdent) match {
      case HoodieV1OrV2Table(_) =>
        AlterHoodieTableRenameCommand(oldIdent.asTableIdentifier, newIdent.asTableIdentifier, false).run(spark)
      case _ => super.renameTable(oldIdent, newIdent)
    }
  }

  /**
   * Applies ALTER TABLE changes to a Hudi table by translating V2 [[TableChange]]s
   * into Hudi's V1 ALTER commands. Changes are grouped by class: AddColumn changes
   * are applied in one batch; column type/comment changes are applied one at a time.
   * Any other change type is rejected.
   */
  override def alterTable(ident: Identifier, changes: TableChange*): Table = {
    loadTable(ident) match {
      case HoodieV1OrV2Table(table) => {
        val tableIdent = TableIdentifier(ident.name(), ident.namespace().lastOption)
        changes.groupBy(c => c.getClass).foreach {
          case (t, newColumns) if t == classOf[AddColumn] =>
            AlterHoodieTableAddColumnsCommand(
              tableIdent,
              newColumns.asInstanceOf[Seq[AddColumn]].map { col =>
                StructField(
                  col.fieldNames()(0),
                  col.dataType(),
                  col.isNullable)
              }).run(spark)

          case (t, columnChanges) if classOf[ColumnChange].isAssignableFrom(t) =>
            columnChanges.foreach {
              // NOTE(review): the binder name `dataType` is reused for both change kinds
              // below; it is the TableChange instance, not a DataType.
              case dataType: UpdateColumnType =>
                val colName = UnresolvedAttribute(dataType.fieldNames()).name
                val newDataType = dataType.newDataType()
                val structField = StructField(colName, newDataType)
                AlterHoodieTableChangeColumnCommand(tableIdent, colName, structField).run(spark)
              case dataType: UpdateColumnComment =>
                val newComment = dataType.newComment()
                val colName = UnresolvedAttribute(dataType.fieldNames()).name
                val fieldOpt = table.schema.findNestedField(dataType.fieldNames(), includeCollections = true,
                  spark.sessionState.conf.resolver).map(_._2)
                val field = fieldOpt.getOrElse {
                  throw new AnalysisException(
                    s"Couldn't find column $colName in:\n${table.schema.treeString}")
                }
                AlterHoodieTableChangeColumnCommand(tableIdent, colName, field.withComment(newComment)).run(spark)
            }

          case (t, _) =>
            throw new UnsupportedOperationException(s"not supported table change: ${t.getClass}")
        }
        // Re-load to reflect the just-applied changes.
        loadTable(ident)
      }
      case _ => super.alterTable(ident, changes: _*)
    }
  }

  /**
   * Derives the table location URI and whether the table is EXTERNAL (explicit location
   * or path identifier) or MANAGED (default table path in the warehouse).
   */
  private def deduceTableLocationURIAndTableType(
      ident: Identifier, properties: util.Map[String, String]): (URI, CatalogTableType) = {
    val locOpt = if (isPathIdentifier(ident)) {
      Option(ident.name())
    } else {
      Option(properties.get("location"))
    }
    val tableType = if (locOpt.nonEmpty) {
      CatalogTableType.EXTERNAL
    } else {
      CatalogTableType.MANAGED
    }
    val locUriOpt = locOpt.map(CatalogUtils.stringToURI)
    val tableIdent = ident.asTableIdentifier
    val existingTableOpt = getExistingTableIfExists(tableIdent)
    val locURI = locUriOpt
      .orElse(existingTableOpt.flatMap(_.storage.locationUri))
      .getOrElse(spark.sessionState.catalog.defaultTablePath(tableIdent))
    (locURI, tableType)
  }

  /**
   * Creates a Hudi table (optionally from a CTAS source query).
   *
   * For STAGE_CREATE the target path must be empty, the Hudi table is initialized on disk
   * first, then the source data (if any) is written and the table registered in the catalog.
   * Otherwise a plain CREATE command runs, writing the source query first when present.
   */
  def createHoodieTable(ident: Identifier,
                        schema: StructType,
                        locUriAndTableType: (URI, CatalogTableType),
                        partitions: Array[Transform],
                        allTableProperties: util.Map[String, String],
                        writeOptions: Map[String, String],
                        sourceQuery: Option[DataFrame],
                        operation: TableCreationMode): Table = {
    val (partitionColumns, maybeBucketSpec) = HoodieCatalog.convertTransforms(partitions)
    val newSchema = schema
    val newPartitionColumns = partitionColumns
    val newBucketSpec = maybeBucketSpec

    // Filter out Hudi-internal props from the storage options.
    val storage = DataSource.buildStorageFormatFromOptions(writeOptions.--(needFilterProps))
      .copy(locationUri = Option(locUriAndTableType._1))
    val commentOpt = Option(allTableProperties.get("comment"))

    val tablePropertiesNew = new util.HashMap[String, String](allTableProperties)
    // put path to table properties.
    tablePropertiesNew.put("path", locUriAndTableType._1.getPath)

    val tableDesc = new CatalogTable(
      identifier = ident.asTableIdentifier,
      tableType = locUriAndTableType._2,
      storage = storage,
      schema = newSchema,
      provider = Option("hudi"),
      partitionColumnNames = newPartitionColumns,
      bucketSpec = newBucketSpec,
      properties = tablePropertiesNew.asScala.toMap.--(needFilterProps),
      comment = commentOpt)

    val hoodieCatalogTable = HoodieCatalogTable(spark, tableDesc)

    if (operation == TableCreationMode.STAGE_CREATE) {
      val tablePath = hoodieCatalogTable.tableLocation
      val hadoopConf = spark.sessionState.newHadoopConf()
      assert(HoodieSqlCommonUtils.isEmptyPath(tablePath, hadoopConf),
        s"Path '$tablePath' should be empty for CTAS")
      hoodieCatalogTable.initHoodieTable()

      val tblProperties = hoodieCatalogTable.catalogProperties
      val options = Map(
        DataSourceWriteOptions.HIVE_CREATE_MANAGED_TABLE.key -> (tableDesc.tableType == CatalogTableType.MANAGED).toString,
        DataSourceWriteOptions.HIVE_TABLE_SERDE_PROPERTIES.key -> ConfigUtils.configToString(tblProperties.asJava),
        DataSourceWriteOptions.HIVE_TABLE_PROPERTIES.key -> ConfigUtils.configToString(tableDesc.properties.asJava),
        DataSourceWriteOptions.SQL_INSERT_MODE.key -> InsertMode.NON_STRICT.value(),
        DataSourceWriteOptions.SQL_ENABLE_BULK_INSERT.key -> "true"
      )
      saveSourceDF(sourceQuery, tableDesc.properties ++ buildHoodieInsertConfig(hoodieCatalogTable, spark, isOverwrite = false, Map.empty, options))
      CreateHoodieTableCommand.createTableInCatalog(spark, hoodieCatalogTable, ignoreIfExists = false)
    } else if (sourceQuery.isEmpty) {
      // No CTAS source: saveSourceDF is a no-op here (mapping over None).
      saveSourceDF(sourceQuery, tableDesc.properties)
      new CreateHoodieTableCommand(tableDesc, false).run(spark)
    } else {
      saveSourceDF(sourceQuery, tableDesc.properties ++ buildHoodieInsertConfig(hoodieCatalogTable, spark, isOverwrite = false, Map.empty, Map.empty))
      new CreateHoodieTableCommand(tableDesc, false).run(spark)
    }

    loadTable(ident)
  }

  // An identifier whose name is an absolute filesystem path denotes a path-based table.
  private def isPathIdentifier(ident: Identifier) = new Path(ident.name()).isAbsolute

  protected def isPathIdentifier(table: CatalogTable): Boolean = {
    isPathIdentifier(table.identifier)
  }

  protected def isPathIdentifier(tableIdentifier: TableIdentifier): Boolean = {
    isPathIdentifier(HoodieIdentifier(tableIdentifier.database.toArray, tableIdentifier.table))
  }

  /**
   * Returns the existing catalog table, if any, validating it is a Hudi table and not a view.
   * Path identifiers always return None (the create path checks the filesystem itself).
   */
  private def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = {
    // If this is a path identifier, we cannot return an existing CatalogTable. The Create command
    // will check the file system itself
    val catalog = spark.sessionState.catalog
    // scalastyle:off
    if (isPathIdentifier(table)) return None
    // scalastyle:on
    val tableExists = catalog.tableExists(table)
    if (tableExists) {
      val oldTable = catalog.getTableMetadata(table)
      if (oldTable.tableType == CatalogTableType.VIEW) throw new HoodieException(
        s"$table is a view. You may not write data into a view.")
      if (!sparkAdapter.isHoodieTable(oldTable)) throw new HoodieException(s"$table is not a Hoodie table.")
      Some(oldTable)
    } else None
  }

  // Appends the source DataFrame (if present) into the Hudi table using the given write props.
  private def saveSourceDF(sourceQuery: Option[Dataset[_]],
                           properties: Map[String, String]): Unit = {
    sourceQuery.map(df => {
      df.write.format("org.apache.hudi")
        .options(properties)
        .mode(SaveMode.Append)
        .save()
      df
    })
  }
}
object HoodieCatalog {

  /**
   * Splits DS V2 partitioning [[Transform]]s into (a) the identity-partitioned column
   * names and (b) an optional [[BucketSpec]] derived from a bucket transform.
   *
   * @param partitions transforms declared on the table
   * @return (identity partition column names, optional bucket spec)
   * @throws HoodieException if any transform is an arbitrary expression (unsupported)
   */
  def convertTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = {
    val identityCols = new mutable.ArrayBuffer[String]
    var bucketSpec = Option.empty[BucketSpec]

    // Iterated purely for side effects, so `foreach` (the original used `map` and
    // discarded its result).
    partitions.foreach {
      case IdentityTransform(FieldReference(Seq(col))) =>
        identityCols += col

      case BucketTransform(numBuckets, FieldReference(Seq(col))) =>
        bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil))

      case _ =>
        // Plain literal — the needless s-interpolator was dropped.
        throw new HoodieException("Partitioning by expressions is not supported.")
    }

    (identityCols.toSeq, bucketSpec)
  }
}

View File

@@ -21,10 +21,10 @@
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>hudi-spark3_2.12</artifactId>
<artifactId>hudi-spark3.3.x_2.12</artifactId>
<version>0.12.0-SNAPSHOT</version>
<name>hudi-spark3_2.12</name>
<name>hudi-spark3.3.x_2.12</name>
<packaging>jar</packaging>
<properties>
@@ -164,8 +164,8 @@
<configuration>
<visitor>true</visitor>
<listener>true</listener>
<sourceDirectory>../hudi-spark3/src/main/antlr4</sourceDirectory>
<libDirectory>../hudi-spark3/src/main/antlr4/imports</libDirectory>
<sourceDirectory>../hudi-spark3.3.x/src/main/antlr4</sourceDirectory>
<libDirectory>../hudi-spark3.3.x/src/main/antlr4/imports</libDirectory>
</configuration>
</plugin>
</plugins>
@@ -181,7 +181,7 @@
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>${spark32.version}</version>
<version>${spark33.version}</version>
<scope>provided</scope>
<optional>true</optional>
</dependency>
@@ -189,7 +189,7 @@
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_2.12</artifactId>
<version>${spark32.version}</version>
<version>${spark33.version}</version>
<scope>provided</scope>
<optional>true</optional>
</dependency>
@@ -197,7 +197,7 @@
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>${spark32.version}</version>
<version>${spark33.version}</version>
<scope>provided</scope>
<optional>true</optional>
<exclusions>

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Hudi's SQL extension grammar: a thin wrapper over Spark's SqlBase grammar that
// intercepts CREATE TABLE (for Hudi-specific handling) and forwards everything
// else to the delegate parser via the passThrough alternative.
grammar HoodieSqlBase;

// Inherit all lexer/parser rules (query, ctes, createTableHeader, ...) from Spark's grammar.
import SqlBase;

singleStatement
    : statement EOF
    ;

statement
    : query                                  #queryStatement
    | ctes? dmlStatementNoWith               #dmlStatement
    | createTableHeader ('(' colTypeList ')')? tableProvider?
        createTableClauses
        (AS? query)?                         #createTable
    // Anything not matched above is consumed non-greedily and handed to the delegate parser.
    | .*? #passThrough
    ;

View File

@@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.parquet;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hudi.client.utils.SparkInternalSchemaConverter;
import org.apache.hudi.common.util.collection.Pair;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import java.io.IOException;
import java.time.ZoneId;
import java.util.HashMap;
import java.util.Map;
public class Spark33HoodieVectorizedParquetRecordReader extends VectorizedParquetRecordReader {
  // save the col type change info: column ordinal -> (file type, requested type).
  private Map<Integer, Pair<DataType, DataType>> typeChangeInfos;

  // Last batch assembled for type-changed columns (lazily built; see resultBatch()).
  private ColumnarBatch columnarBatch;

  // Replacement vectors, one per type-changed column ordinal.
  private Map<Integer, WritableColumnVector> idToColumnVectors;

  // Full set of output vectors (data + partition columns).
  private WritableColumnVector[] columnVectors;

  // The capacity of vectorized batch.
  private int capacity;

  // If true, this class returns batches instead of rows.
  private boolean returnColumnarBatch;

  // The memory mode of the columnarBatch.
  private final MemoryMode memoryMode;

  /**
   * Batch of rows that we assemble and the current index we've returned. Every time this
   * batch is used up (batchIdx == numBatched), we populated the batch.
   */
  private int batchIdx = 0;
  private int numBatched = 0;
  /**
   * Creates a vectorized Parquet reader that can additionally rewrite columns whose
   * type changed via Hudi schema evolution.
   *
   * <p>The rebase-mode/timezone and capacity arguments are forwarded verbatim to
   * Spark's {@code VectorizedParquetRecordReader}.
   *
   * @param typeChangeInfos map of column ordinal to (target type, original file type)
   *                        for columns requiring conversion
   */
  public Spark33HoodieVectorizedParquetRecordReader(
      ZoneId convertTz,
      String datetimeRebaseMode,
      String datetimeRebaseTz,
      String int96RebaseMode,
      String int96RebaseTz,
      boolean useOffHeap,
      int capacity,
      Map<Integer, Pair<DataType, DataType>> typeChangeInfos) {
    super(convertTz, datetimeRebaseMode, datetimeRebaseTz, int96RebaseMode, int96RebaseTz, useOffHeap, capacity);
    // Capacity is kept locally because the superclass field is not accessible here.
    memoryMode = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
    this.typeChangeInfos = typeChangeInfos;
    this.capacity = capacity;
  }
@Override
public void initBatch(StructType partitionColumns, InternalRow partitionValues) {
super.initBatch(partitionColumns, partitionValues);
if (columnVectors == null) {
columnVectors = new WritableColumnVector[sparkSchema.length() + partitionColumns.length()];
}
if (idToColumnVectors == null) {
idToColumnVectors = new HashMap<>();
typeChangeInfos.entrySet()
.stream()
.forEach(f -> {
WritableColumnVector vector =
memoryMode == MemoryMode.OFF_HEAP ? new OffHeapColumnVector(capacity, f.getValue().getLeft()) : new OnHeapColumnVector(capacity, f.getValue().getLeft());
idToColumnVectors.put(f.getKey(), vector);
});
}
}
@Override
public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException, InterruptedException, UnsupportedOperationException {
super.initialize(inputSplit, taskAttemptContext);
}
@Override
public void close() throws IOException {
super.close();
for (Map.Entry<Integer, WritableColumnVector> e : idToColumnVectors.entrySet()) {
e.getValue().close();
}
idToColumnVectors = null;
columnarBatch = null;
columnVectors = null;
}
@Override
public ColumnarBatch resultBatch() {
ColumnarBatch currentColumnBatch = super.resultBatch();
boolean changed = false;
for (Map.Entry<Integer, Pair<DataType, DataType>> entry : typeChangeInfos.entrySet()) {
boolean rewrite = SparkInternalSchemaConverter
.convertColumnVectorType((WritableColumnVector) currentColumnBatch.column(entry.getKey()),
idToColumnVectors.get(entry.getKey()), currentColumnBatch.numRows());
if (rewrite) {
changed = true;
columnVectors[entry.getKey()] = idToColumnVectors.get(entry.getKey());
}
}
if (changed) {
if (columnarBatch == null) {
// fill other vector
for (int i = 0; i < columnVectors.length; i++) {
if (columnVectors[i] == null) {
columnVectors[i] = (WritableColumnVector) currentColumnBatch.column(i);
}
}
columnarBatch = new ColumnarBatch(columnVectors);
}
columnarBatch.setNumRows(currentColumnBatch.numRows());
return columnarBatch;
} else {
return currentColumnBatch;
}
}
@Override
public boolean nextBatch() throws IOException {
boolean result = super.nextBatch();
if (idToColumnVectors != null) {
idToColumnVectors.entrySet().stream().forEach(e -> e.getValue().reset());
}
numBatched = resultBatch().numRows();
batchIdx = 0;
return result;
}
@Override
public void enableReturningBatches() {
returnColumnarBatch = true;
super.enableReturningBatches();
}
@Override
public Object getCurrentValue() {
if (typeChangeInfos == null || typeChangeInfos.isEmpty()) {
return super.getCurrentValue();
}
if (returnColumnarBatch) {
return columnarBatch == null ? super.getCurrentValue() : columnarBatch;
}
return columnarBatch == null ? super.getCurrentValue() : columnarBatch.getRow(batchIdx - 1);
}
@Override
public boolean nextKeyValue() throws IOException {
resultBatch();
if (returnColumnarBatch) {
return nextBatch();
}
if (batchIdx >= numBatched) {
if (!nextBatch()) {
return false;
}
}
++batchIdx;
return true;
}
}

View File

@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
org.apache.hudi.Spark3DefaultSource

View File

@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.types.StructType
/**
 * Spark 3.3 flavor of [[FileScanRDD]] used by Hudi relations.
 *
 * Mixes in [[HoodieUnsafeRDD]] so that `collect()` is routed through Hudi's implementation
 * instead of the stock `RDD.collect()`.
 *
 * @param sparkSession    active Spark session (not serialized with the RDD)
 * @param read            function decoding a [[PartitionedFile]] into rows
 * @param filePartitions  file splits to scan (not serialized with the RDD)
 * @param readDataSchema  schema of the data columns being read
 * @param metadataColumns hidden metadata columns requested by the query, if any
 */
class Spark33HoodieFileScanRDD(@transient private val sparkSession: SparkSession,
                               read: PartitionedFile => Iterator[InternalRow],
                               @transient filePartitions: Seq[FilePartition],
                               readDataSchema: StructType, metadataColumns: Seq[AttributeReference] = Seq.empty)
  extends FileScanRDD(sparkSession, read, filePartitions, readDataSchema, metadataColumns)
    with HoodieUnsafeRDD {

  // Explicit super-call disambiguates in favor of HoodieUnsafeRDD's collect().
  override final def collect(): Array[InternalRow] = super[HoodieUnsafeRDD].collect()
}

View File

@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi
import org.apache.spark.sql.sources.DataSourceRegister
/**
* NOTE: PLEASE READ CAREFULLY
* All of Spark DataSourceV2 APIs are deliberately disabled to make sure
* there are no regressions in performance
* Please check out HUDI-4178 for more details
*/
class Spark3DefaultSource extends DefaultSource with DataSourceRegister /* with TableProvider */ {

  // Short name registered with DataSourceRegister: enables `spark.read.format("hudi")`.
  override def shortName(): String = "hudi"

  // Kept for reference: the DataSourceV2 TableProvider implementation, deliberately
  // disabled (see HUDI-4178 in the scaladoc above this class).
  /*
  def inferSchema: StructType = new StructType()

  override def inferSchema(options: CaseInsensitiveStringMap): StructType = inferSchema

  override def getTable(schema: StructType,
                        partitioning: Array[Transform],
                        properties: java.util.Map[String, String]): Table = {
    val options = new CaseInsensitiveStringMap(properties)
    val path = options.get("path")
    if (path == null) throw new HoodieException("'path' cannot be null, missing 'path' from table properties")
    HoodieInternalV2Table(SparkSession.active, path)
  }
  */
}

View File

@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import HoodieSparkTypeUtils.isCastPreservingOrdering
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
/**
 * Spark 3.3 implementation of [[HoodieCatalystExpressionUtils]].
 *
 * NOTE: the extractor arities below (the number of `_` placeholders per case) are tied to
 * the Spark 3.3 constructor signatures of these Catalyst expressions and differ across
 * Spark minor versions; do not copy them verbatim to other version-specific modules.
 */
object HoodieSpark33CatalystExpressionUtils extends HoodieCatalystExpressionUtils {

  /**
   * Returns the [[AttributeReference]] at the root of `expr` when the whole expression is
   * a chain of order-preserving transformations of that single attribute; None otherwise.
   */
  override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = {
    expr match {
      case OrderPreservingTransformation(attrRef) => Some(attrRef)
      case _ => None
    }
  }

  /**
   * Extractor matching (possibly nested) expressions that preserve the relative ordering
   * of their input attribute (e.g. `upper(col)`, `col + 1`, `cast(col as ...)`).
   */
  private object OrderPreservingTransformation {
    def unapply(expr: Expression): Option[AttributeReference] = {
      expr match {
        // Date/Time Expressions
        case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
        case DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
        case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
        case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
        case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
        case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
        case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
        case ParseToDate(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
        case ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef)
        case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef)
        case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef)

        // String Expressions
        case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef)
        case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef)
        // NOTE: `Left`'s API changed in Spark 3.3 ("Improve RuntimeReplaceable",
        // https://issues.apache.org/jira/browse/SPARK-38240), hence the dedicated,
        // fully-qualified case below
        case org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _) => Some(attrRef)

        // Math Expressions
        // Binary
        case Add(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
        case Add(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef)
        case Multiply(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
        case Multiply(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef)
        case Divide(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
        case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
        case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
        // Unary
        case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef)
        case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef)
        case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef)
        case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef)
        case Log1p(OrderPreservingTransformation(attrRef)) => Some(attrRef)
        case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef)
        case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
        case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef)

        // Other
        case cast @ Cast(OrderPreservingTransformation(attrRef), _, _, _)
          if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef)

        // Identity transformation
        case attrRef: AttributeReference => Some(attrRef)
        // No match
        case _ => None
      }
    }
  }
}

View File

@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TimeTravelRelation}
object HoodieSpark33CatalystPlanUtils extends HoodieSpark3CatalystPlanUtils {

  /**
   * Whether the given plan is Spark 3.3's native time-travel relation node
   * (introduced by SPARK-37219).
   */
  override def isRelationTimeTravel(plan: LogicalPlan): Boolean =
    plan match {
      case _: TimeTravelRelation => true
      case _ => false
    }

  /**
   * Destructures a [[TimeTravelRelation]] into its (table, timestamp, version)
   * components, or returns None when the plan is not a time-travel relation.
   */
  override def getRelationTimeTravel(plan: LogicalPlan): Option[(LogicalPlan, Option[Expression], Option[String])] =
    plan match {
      case timeTravel: TimeTravelRelation =>
        Some((timeTravel.table, timeTravel.timestamp, timeTravel.version))
      case _ =>
        None
    }
}

View File

@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.adapter
import org.apache.avro.Schema
import org.apache.hudi.Spark33HoodieFileScanRDD
import org.apache.spark.sql.avro._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark33HoodieParquetFileFormat}
import org.apache.spark.sql.parser.HoodieSpark3_3ExtendedSqlParser
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieCatalystPlansUtils, HoodieSpark33CatalystPlanUtils, HoodieSpark33CatalystExpressionUtils, SparkSession}
/**
 * Implementation of [[SparkAdapter]] for the Spark 3.3.x branch: wires Hudi's
 * version-specific Catalyst utilities, Avro (de)serializers, Parquet file format,
 * file-scan RDD and extended SQL parser.
 */
class Spark3_3Adapter extends BaseSpark3Adapter {

  override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark33CatalystExpressionUtils

  override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark33CatalystPlanUtils

  /** Builds the Spark-3.3 serializer converting Catalyst rows into Avro records. */
  override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
    new HoodieSpark3_3AvroSerializer(rootCatalystType, rootAvroType, nullable)

  /** Builds the Spark-3.3 deserializer converting Avro records into Catalyst rows. */
  override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
    new HoodieSpark3_3AvroDeserializer(rootAvroType, rootCatalystType)

  /** Factory for Hudi's extended SQL parser wrapping Spark's own parser. */
  override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] =
    Some { (spark: SparkSession, delegate: ParserInterface) =>
      new HoodieSpark3_3ExtendedSqlParser(spark, delegate)
    }

  override def createHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] =
    Some(new Spark33HoodieParquetFileFormat(appendPartitionValues))

  override def createHoodieFileScanRDD(sparkSession: SparkSession,
                                       readFunction: PartitionedFile => Iterator[InternalRow],
                                       filePartitions: Seq[FilePartition],
                                       readDataSchema: StructType,
                                       metadataColumns: Seq[AttributeReference] = Seq.empty): FileScanRDD =
    new Spark33HoodieFileScanRDD(sparkSession, readFunction, filePartitions, readDataSchema, metadataColumns)

  /** Rebuilds a DELETE FROM command with its condition run through the supplied resolver. */
  override def resolveDeleteFromTable(deleteFromTable: Command,
                                      resolveExpression: Expression => Expression): DeleteFromTable = {
    val delete = deleteFromTable.asInstanceOf[DeleteFromTable]
    DeleteFromTable(delete.table, resolveExpression(delete.condition))
  }

  /** Extracts the (non-optional in Spark 3.3) condition of a DELETE FROM command. */
  override def extractCondition(deleteFromTable: Command): Expression =
    deleteFromTable.asInstanceOf[DeleteFromTable].condition

  override def getQueryParserFromExtendedSqlParser(session: SparkSession, delegate: ParserInterface,
                                                   sqlText: String): LogicalPlan =
    new HoodieSpark3_3ExtendedSqlParser(session, delegate).parseQuery(sqlText)
}

View File

@@ -0,0 +1,499 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.avro
import java.math.BigDecimal
import java.nio.ByteBuffer
import scala.collection.JavaConverters._
import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes.{LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._
import org.apache.avro.generic._
import org.apache.avro.util.Utf8
import org.apache.spark.sql.avro.AvroDeserializer.{RebaseSpec, createDateRebaseFuncInRead, createTimestampRebaseFuncInRead}
import org.apache.spark.sql.avro.AvroUtils.{AvroMatchedField, toFieldStr}
import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData, RebaseDateTime}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import java.util.TimeZone
/**
* A deserializer to deserialize data in avro format to data in catalyst format.
*
* NOTE: This code is borrowed from Spark 3.3.0
* This code is borrowed, so that we can better control compatibility w/in Spark minor
* branches (3.2.x, 3.1.x, etc)
*
* PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
*/
private[sql] class AvroDeserializer(
rootAvroType: Schema,
rootCatalystType: DataType,
positionalFieldMatch: Boolean,
datetimeRebaseSpec: RebaseSpec,
filters: StructFilters) {
def this(
rootAvroType: Schema,
rootCatalystType: DataType,
datetimeRebaseMode: String) = {
this(
rootAvroType,
rootCatalystType,
positionalFieldMatch = false,
RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)),
new NoopFilters)
}
private lazy val decimalConversions = new DecimalConversion()
private val dateRebaseFunc = createDateRebaseFuncInRead(
datetimeRebaseSpec.mode, "Avro")
private val timestampRebaseFunc = createTimestampRebaseFuncInRead(
datetimeRebaseSpec, "Avro")
private val converter: Any => Option[Any] = try {
rootCatalystType match {
// A shortcut for empty schema.
case st: StructType if st.isEmpty =>
(_: Any) => Some(InternalRow.empty)
case st: StructType =>
val resultRow = new SpecificInternalRow(st.map(_.dataType))
val fieldUpdater = new RowUpdater(resultRow)
val applyFilters = filters.skipRow(resultRow, _)
val writer = getRecordWriter(rootAvroType, st, Nil, Nil, applyFilters)
(data: Any) => {
val record = data.asInstanceOf[GenericRecord]
val skipRow = writer(fieldUpdater, record)
if (skipRow) None else Some(resultRow)
}
case _ =>
val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
val fieldUpdater = new RowUpdater(tmpRow)
val writer = newWriter(rootAvroType, rootCatalystType, Nil, Nil)
(data: Any) => {
writer(fieldUpdater, 0, data)
Some(tmpRow.get(0, rootCatalystType))
}
}
} catch {
case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException(
s"Cannot convert Avro type $rootAvroType to SQL type ${rootCatalystType.sql}.", ise)
}
def deserialize(data: Any): Option[Any] = converter(data)
/**
* Creates a writer to write avro values to Catalyst values at the given ordinal with the given
* updater.
*/
private def newWriter(
avroType: Schema,
catalystType: DataType,
avroPath: Seq[String],
catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
val errorPrefix = s"Cannot convert Avro ${toFieldStr(avroPath)} to " +
s"SQL ${toFieldStr(catalystPath)} because "
val incompatibleMsg = errorPrefix +
s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})"
(avroType.getType, catalystType) match {
case (NULL, NullType) => (updater, ordinal, _) =>
updater.setNullAt(ordinal)
// TODO: we can avoid boxing if future version of avro provide primitive accessors.
case (BOOLEAN, BooleanType) => (updater, ordinal, value) =>
updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
case (INT, IntegerType) => (updater, ordinal, value) =>
updater.setInt(ordinal, value.asInstanceOf[Int])
case (INT, DateType) => (updater, ordinal, value) =>
updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int]))
case (LONG, LongType) => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])
case (LONG, TimestampType) => avroType.getLogicalType match {
// For backward compatibility, if the Avro type is Long and it is not logical type
// (the `null` case), the value is processed as timestamp type with millisecond precision.
case null | _: TimestampMillis => (updater, ordinal, value) =>
val millis = value.asInstanceOf[Long]
val micros = DateTimeUtils.millisToMicros(millis)
updater.setLong(ordinal, timestampRebaseFunc(micros))
case _: TimestampMicros => (updater, ordinal, value) =>
val micros = value.asInstanceOf[Long]
updater.setLong(ordinal, timestampRebaseFunc(micros))
case other => throw new IncompatibleSchemaException(errorPrefix +
s"Avro logical type $other cannot be converted to SQL type ${TimestampType.sql}.")
}
case (LONG, TimestampNTZType) => avroType.getLogicalType match {
// To keep consistent with TimestampType, if the Avro type is Long and it is not
// logical type (the `null` case), the value is processed as TimestampNTZ
// with millisecond precision.
case null | _: LocalTimestampMillis => (updater, ordinal, value) =>
val millis = value.asInstanceOf[Long]
val micros = DateTimeUtils.millisToMicros(millis)
updater.setLong(ordinal, micros)
case _: LocalTimestampMicros => (updater, ordinal, value) =>
val micros = value.asInstanceOf[Long]
updater.setLong(ordinal, micros)
case other => throw new IncompatibleSchemaException(errorPrefix +
s"Avro logical type $other cannot be converted to SQL type ${TimestampNTZType.sql}.")
}
// Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date.
// For backward compatibility, we still keep this conversion.
case (LONG, DateType) => (updater, ordinal, value) =>
updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt)
case (FLOAT, FloatType) => (updater, ordinal, value) =>
updater.setFloat(ordinal, value.asInstanceOf[Float])
case (DOUBLE, DoubleType) => (updater, ordinal, value) =>
updater.setDouble(ordinal, value.asInstanceOf[Double])
case (STRING, StringType) => (updater, ordinal, value) =>
val str = value match {
case s: String => UTF8String.fromString(s)
case s: Utf8 =>
val bytes = new Array[Byte](s.getByteLength)
System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength)
UTF8String.fromBytes(bytes)
}
updater.set(ordinal, str)
case (ENUM, StringType) => (updater, ordinal, value) =>
updater.set(ordinal, UTF8String.fromString(value.toString))
case (FIXED, BinaryType) => (updater, ordinal, value) =>
updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone())
case (BYTES, BinaryType) => (updater, ordinal, value) =>
val bytes = value match {
case b: ByteBuffer =>
val bytes = new Array[Byte](b.remaining)
b.get(bytes)
// Do not forget to reset the position
b.rewind()
bytes
case b: Array[Byte] => b
case other =>
throw new RuntimeException(errorPrefix + s"$other is not a valid avro binary.")
}
updater.set(ordinal, bytes)
case (FIXED, _: DecimalType) => (updater, ordinal, value) =>
val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal]
val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d)
val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale)
updater.setDecimal(ordinal, decimal)
case (BYTES, _: DecimalType) => (updater, ordinal, value) =>
val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal]
val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d)
val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale)
updater.setDecimal(ordinal, decimal)
case (RECORD, st: StructType) =>
// Avro datasource doesn't accept filters with nested attributes. See SPARK-32328.
// We can always return `false` from `applyFilters` for nested records.
val writeRecord =
getRecordWriter(avroType, st, avroPath, catalystPath, applyFilters = _ => false)
(updater, ordinal, value) =>
val row = new SpecificInternalRow(st)
writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord])
updater.set(ordinal, row)
case (ARRAY, ArrayType(elementType, containsNull)) =>
val avroElementPath = avroPath :+ "element"
val elementWriter = newWriter(avroType.getElementType, elementType,
avroElementPath, catalystPath :+ "element")
(updater, ordinal, value) =>
val collection = value.asInstanceOf[java.util.Collection[Any]]
val result = createArrayData(elementType, collection.size())
val elementUpdater = new ArrayDataUpdater(result)
var i = 0
val iter = collection.iterator()
while (iter.hasNext) {
val element = iter.next()
if (element == null) {
if (!containsNull) {
throw new RuntimeException(
s"Array value at path ${toFieldStr(avroElementPath)} is not allowed to be null")
} else {
elementUpdater.setNullAt(i)
}
} else {
elementWriter(elementUpdater, i, element)
}
i += 1
}
updater.set(ordinal, result)
case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType =>
val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType,
avroPath :+ "key", catalystPath :+ "key")
val valueWriter = newWriter(avroType.getValueType, valueType,
avroPath :+ "value", catalystPath :+ "value")
(updater, ordinal, value) =>
val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]]
val keyArray = createArrayData(keyType, map.size())
val keyUpdater = new ArrayDataUpdater(keyArray)
val valueArray = createArrayData(valueType, map.size())
val valueUpdater = new ArrayDataUpdater(valueArray)
val iter = map.entrySet().iterator()
var i = 0
while (iter.hasNext) {
val entry = iter.next()
assert(entry.getKey != null)
keyWriter(keyUpdater, i, entry.getKey)
if (entry.getValue == null) {
if (!valueContainsNull) {
throw new RuntimeException(
s"Map value at path ${toFieldStr(avroPath :+ "value")} is not allowed to be null")
} else {
valueUpdater.setNullAt(i)
}
} else {
valueWriter(valueUpdater, i, entry.getValue)
}
i += 1
}
// The Avro map will never have null or duplicated map keys, it's safe to create a
// ArrayBasedMapData directly here.
updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
case (UNION, _) =>
val allTypes = avroType.getTypes.asScala
val nonNullTypes = allTypes.filter(_.getType != NULL)
val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava)
if (nonNullTypes.nonEmpty) {
if (nonNullTypes.length == 1) {
newWriter(nonNullTypes.head, catalystType, avroPath, catalystPath)
} else {
nonNullTypes.map(_.getType).toSeq match {
case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType =>
(updater, ordinal, value) => value match {
case null => updater.setNullAt(ordinal)
case l: java.lang.Long => updater.setLong(ordinal, l)
case i: java.lang.Integer => updater.setLong(ordinal, i.longValue())
}
case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType =>
(updater, ordinal, value) => value match {
case null => updater.setNullAt(ordinal)
case d: java.lang.Double => updater.setDouble(ordinal, d)
case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue())
}
case _ =>
catalystType match {
case st: StructType if st.length == nonNullTypes.size =>
val fieldWriters = nonNullTypes.zip(st.fields).map {
case (schema, field) =>
newWriter(schema, field.dataType, avroPath, catalystPath :+ field.name)
}.toArray
(updater, ordinal, value) => {
val row = new SpecificInternalRow(st)
val fieldUpdater = new RowUpdater(row)
val i = GenericData.get().resolveUnion(nonNullAvroType, value)
fieldWriters(i)(fieldUpdater, i, value)
updater.set(ordinal, row)
}
case _ => throw new IncompatibleSchemaException(incompatibleMsg)
}
}
}
} else {
(updater, ordinal, _) => updater.setNullAt(ordinal)
}
case (INT, _: YearMonthIntervalType) => (updater, ordinal, value) =>
updater.setInt(ordinal, value.asInstanceOf[Int])
case (LONG, _: DayTimeIntervalType) => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])
case _ => throw new IncompatibleSchemaException(incompatibleMsg)
}
}
// TODO: move the following method in Decimal object on creating Decimal from BigDecimal?
private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
if (precision <= Decimal.MAX_LONG_DIGITS) {
// Constructs a `Decimal` with an unscaled `Long` value if possible.
Decimal(decimal.unscaledValue().longValue(), precision, scale)
} else {
// Otherwise, resorts to an unscaled `BigInteger` instead.
Decimal(decimal, precision, scale)
}
}
private def getRecordWriter(
avroType: Schema,
catalystType: StructType,
avroPath: Seq[String],
catalystPath: Seq[String],
applyFilters: Int => Boolean): (CatalystDataUpdater, GenericRecord) => Boolean = {
val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(
avroType, catalystType, avroPath, catalystPath, positionalFieldMatch)
avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true)
// no need to validateNoExtraAvroFields since extra Avro fields are ignored
val (validFieldIndexes, fieldWriters) = avroSchemaHelper.matchedFields.map {
case AvroMatchedField(catalystField, ordinal, avroField) =>
val baseWriter = newWriter(avroField.schema(), catalystField.dataType,
avroPath :+ avroField.name, catalystPath :+ catalystField.name)
val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
if (value == null) {
fieldUpdater.setNullAt(ordinal)
} else {
baseWriter(fieldUpdater, ordinal, value)
}
}
(avroField.pos(), fieldWriter)
}.toArray.unzip
(fieldUpdater, record) => {
var i = 0
var skipRow = false
while (i < validFieldIndexes.length && !skipRow) {
fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i)))
skipRow = applyFilters(i)
i += 1
}
skipRow
}
}
/**
 * Preallocates an [[ArrayData]] of `length` elements for the given element type.
 *
 * Primitive element types are backed by an `UnsafeArrayData` over a primitive Java
 * array (avoids per-element boxing); all other types fall back to a
 * `GenericArrayData` of `null`s to be filled in by the caller.
 */
private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
  case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
  case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
  case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
  case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
  case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
  case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
  case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
  case _ => new GenericArrayData(new Array[Any](length))
}
/**
 * A base interface for updating values inside catalyst data structure like `InternalRow` and
 * `ArrayData`.
 *
 * Only `set` is abstract; every typed setter defaults to delegating to `set` with a boxed
 * value, so implementations may override the typed variants purely as an unboxing
 * optimization.
 */
sealed trait CatalystDataUpdater {
  /** Stores `value` (possibly null/boxed) at position `ordinal`. */
  def set(ordinal: Int, value: Any): Unit

  def setNullAt(ordinal: Int): Unit = set(ordinal, null)
  def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
  def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
  def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
  def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
  def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
  def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
  def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
  def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
}
/**
 * [[CatalystDataUpdater]] writing into a mutable [[InternalRow]], forwarding each typed
 * setter to the row's specialized (unboxed) mutator.
 */
final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
  override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
  override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
  override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
  override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
  override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
  override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
  override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
  override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
  override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
  // InternalRow needs the precision to pick the decimal storage format.
  override def setDecimal(ordinal: Int, value: Decimal): Unit =
    row.setDecimal(ordinal, value, value.precision)
}
/**
 * [[CatalystDataUpdater]] writing into a mutable [[ArrayData]].
 *
 * Unlike [[RowUpdater]], decimals go through the generic `update` since `ArrayData`
 * exposes no specialized decimal setter.
 */
final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
  override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
  override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
  override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
  override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
  override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
  override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
  override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
  override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
  override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
  override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
}
}
object AvroDeserializer {
  // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making [[AvroDeserializer]] implementation
  // (which relies on it) be only compatible with the exact same version of [[DataSourceUtils]].
  // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.2.x branch,
  // we're preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.2.0 as well as
  // w/ Spark >= 3.2.1
  //
  // [1] https://github.com/apache/spark/pull/34978

  // Specification of rebase operation including `mode` and the time zone in which it is performed
  case class RebaseSpec(mode: LegacyBehaviorPolicy.Value, originTimeZone: Option[String] = None) {
    // Use the default JVM time zone for backward compatibility
    def timeZone: String = originTimeZone.getOrElse(TimeZone.getDefault.getID)
  }

  /**
   * Returns a function rebasing Avro DATE values (days since epoch) on read:
   *  - EXCEPTION: throws for days before the Julian/Gregorian calendar switch
   *  - LEGACY: rebases from the hybrid Julian to the proleptic Gregorian calendar
   *  - CORRECTED: identity (no rebasing)
   * `format` is only used in the error message of the EXCEPTION mode.
   */
  def createDateRebaseFuncInRead(rebaseMode: LegacyBehaviorPolicy.Value,
                                 format: String): Int => Int = rebaseMode match {
    case LegacyBehaviorPolicy.EXCEPTION => days: Int =>
      if (days < RebaseDateTime.lastSwitchJulianDay) {
        throw DataSourceUtils.newRebaseExceptionInRead(format)
      }
      days
    case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays
    case LegacyBehaviorPolicy.CORRECTED => identity[Int]
  }

  /**
   * Same as [[createDateRebaseFuncInRead]] but for timestamps in microseconds; LEGACY
   * rebasing is performed in the time zone carried by `rebaseSpec`.
   */
  def createTimestampRebaseFuncInRead(rebaseSpec: RebaseSpec,
                                      format: String): Long => Long = rebaseSpec.mode match {
    case LegacyBehaviorPolicy.EXCEPTION => micros: Long =>
      if (micros < RebaseDateTime.lastSwitchJulianTs) {
        throw DataSourceUtils.newRebaseExceptionInRead(format)
      }
      micros
    case LegacyBehaviorPolicy.LEGACY => micros: Long =>
      RebaseDateTime.rebaseJulianToGregorianMicros(TimeZone.getTimeZone(rebaseSpec.timeZone), micros)
    case LegacyBehaviorPolicy.CORRECTED => identity[Long]
  }
}

View File

@@ -0,0 +1,381 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.avro
import java.nio.ByteBuffer
import scala.collection.JavaConverters._
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes
import org.apache.avro.LogicalTypes.{LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis}
import org.apache.avro.Schema
import org.apache.avro.Schema.Type
import org.apache.avro.Schema.Type._
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.avro.generic.GenericData.Record
import org.apache.avro.util.Utf8
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.AvroUtils.{AvroMatchedField, toFieldStr}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.types._
import java.util.TimeZone
/**
 * A serializer to serialize data in catalyst format to data in avro format.
 *
 * NOTE: This code is borrowed from Spark 3.3.0
 * This code is borrowed, so that we can better control compatibility w/in Spark minor
 * branches (3.2.x, 3.1.x, etc)
 *
 * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
 *
 * @param rootCatalystType     Catalyst type of the values passed to [[serialize]]
 * @param rootAvroType         target Avro schema (possibly a nullable UNION)
 * @param nullable             whether the root value may be null
 * @param positionalFieldMatch if true, struct fields are matched to Avro fields by
 *                             position instead of by name
 * @param datetimeRebaseMode   calendar rebase policy applied to dates/timestamps on write
 */
private[sql] class AvroSerializer(
    rootCatalystType: DataType,
    rootAvroType: Schema,
    nullable: Boolean,
    positionalFieldMatch: Boolean,
    datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging {

  // Convenience constructor: name-based field matching, rebase mode taken from the
  // session's `AVRO_REBASE_MODE_IN_WRITE` configuration.
  def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = {
    this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false,
      LegacyBehaviorPolicy.withName(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE)))
  }

  /** Serializes a single Catalyst value (an `InternalRow` or primitive) into Avro form. */
  def serialize(catalystData: Any): Any = {
    converter.apply(catalystData)
  }

  // Rebase functions translating dates/timestamps between proleptic Gregorian and hybrid
  // Julian calendars on write, per `datetimeRebaseMode`.
  // NOTE(review): these call [[DataSourceUtils]] directly even though the companion object
  // below clones the very same methods for cross-version compatibility — the clones look
  // unused from here; confirm which is intended.
  private val dateRebaseFunc = DataSourceUtils.createDateRebaseFuncInWrite(
    datetimeRebaseMode, "Avro")

  private val timestampRebaseFunc = DataSourceUtils.createTimestampRebaseFuncInWrite(
    datetimeRebaseMode, "Avro")

  // Top-level converter built once per serializer instance. Handles root-level
  // nullability and wraps non-struct root values in a temporary one-field row so the
  // per-field converter machinery can be reused.
  private val converter: Any => Any = {
    val actualAvroType = resolveNullableType(rootAvroType, nullable)
    val baseConverter = try {
      rootCatalystType match {
        case st: StructType =>
          newStructConverter(st, actualAvroType, Nil, Nil).asInstanceOf[Any => Any]
        case _ =>
          val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
          val converter = newConverter(rootCatalystType, actualAvroType, Nil, Nil)
          (data: Any) =>
            tmpRow.update(0, data)
            converter.apply(tmpRow, 0)
      }
    } catch {
      case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException(
        s"Cannot convert SQL type ${rootCatalystType.sql} to Avro type $rootAvroType.", ise)
    }
    if (nullable) {
      (data: Any) =>
        if (data == null) {
          null
        } else {
          baseConverter.apply(data)
        }
    } else {
      baseConverter
    }
  }

  // A converter reads the value at `ordinal` from the getter and returns its Avro form.
  private type Converter = (SpecializedGetters, Int) => Any

  private lazy val decimalConversions = new DecimalConversion()

  /**
   * Builds a [[Converter]] for one (Catalyst type, Avro type) pair, recursing into
   * arrays, maps and structs. Throws [[IncompatibleSchemaException]] (with the field
   * paths in the message) when the pair cannot be bridged.
   */
  private def newConverter(
      catalystType: DataType,
      avroType: Schema,
      catalystPath: Seq[String],
      avroPath: Seq[String]): Converter = {
    val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
      s"to Avro ${toFieldStr(avroPath)} because "
    (catalystType, avroType.getType) match {
      case (NullType, NULL) =>
        (getter, ordinal) => null
      case (BooleanType, BOOLEAN) =>
        (getter, ordinal) => getter.getBoolean(ordinal)
      case (ByteType, INT) =>
        (getter, ordinal) => getter.getByte(ordinal).toInt
      case (ShortType, INT) =>
        (getter, ordinal) => getter.getShort(ordinal).toInt
      case (IntegerType, INT) =>
        (getter, ordinal) => getter.getInt(ordinal)
      case (LongType, LONG) =>
        (getter, ordinal) => getter.getLong(ordinal)
      case (FloatType, FLOAT) =>
        (getter, ordinal) => getter.getFloat(ordinal)
      case (DoubleType, DOUBLE) =>
        (getter, ordinal) => getter.getDouble(ordinal)
      // Decimals: the Avro logical type must match the Catalyst precision/scale exactly.
      case (d: DecimalType, FIXED)
        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
        (getter, ordinal) =>
          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
          decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
            LogicalTypes.decimal(d.precision, d.scale))
      case (d: DecimalType, BYTES)
        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
        (getter, ordinal) =>
          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
          decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
            LogicalTypes.decimal(d.precision, d.scale))
      // Strings may only be written into an ENUM if the value is a declared symbol.
      case (StringType, ENUM) =>
        val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
        (getter, ordinal) =>
          val data = getter.getUTF8String(ordinal).toString
          if (!enumSymbols.contains(data)) {
            throw new IncompatibleSchemaException(errorPrefix +
              s""""$data" cannot be written since it's not defined in enum """ +
              enumSymbols.mkString("\"", "\", \"", "\""))
          }
          new EnumSymbol(avroType, data)
      case (StringType, STRING) =>
        (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
      // Binary into FIXED requires an exact length match.
      case (BinaryType, FIXED) =>
        val size = avroType.getFixedSize
        (getter, ordinal) =>
          val data: Array[Byte] = getter.getBinary(ordinal)
          if (data.length != size) {
            def len2str(len: Int): String = s"$len ${if (len > 1) "bytes" else "byte"}"

            throw new IncompatibleSchemaException(errorPrefix + len2str(data.length) +
              " of binary data cannot be written into FIXED type with size of " + len2str(size))
          }
          new Fixed(avroType, data)
      case (BinaryType, BYTES) =>
        (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
      case (DateType, INT) =>
        (getter, ordinal) => dateRebaseFunc(getter.getInt(ordinal))
      case (TimestampType, LONG) => avroType.getLogicalType match {
        // For backward compatibility, if the Avro type is Long and it is not logical type
        // (the `null` case), output the timestamp value as with millisecond precision.
        case null | _: TimestampMillis => (getter, ordinal) =>
          DateTimeUtils.microsToMillis(timestampRebaseFunc(getter.getLong(ordinal)))
        case _: TimestampMicros => (getter, ordinal) =>
          timestampRebaseFunc(getter.getLong(ordinal))
        case other => throw new IncompatibleSchemaException(errorPrefix +
          s"SQL type ${TimestampType.sql} cannot be converted to Avro logical type $other")
      }
      case (TimestampNTZType, LONG) => avroType.getLogicalType match {
        // To keep consistent with TimestampType, if the Avro type is Long and it is not
        // logical type (the `null` case), output the TimestampNTZ as long value
        // in millisecond precision.
        case null | _: LocalTimestampMillis => (getter, ordinal) =>
          DateTimeUtils.microsToMillis(getter.getLong(ordinal))
        case _: LocalTimestampMicros => (getter, ordinal) =>
          getter.getLong(ordinal)
        case other => throw new IncompatibleSchemaException(errorPrefix +
          s"SQL type ${TimestampNTZType.sql} cannot be converted to Avro logical type $other")
      }
      case (ArrayType(et, containsNull), ARRAY) =>
        val elementConverter = newConverter(
          et, resolveNullableType(avroType.getElementType, containsNull),
          catalystPath :+ "element", avroPath :+ "element")
        (getter, ordinal) => {
          val arrayData = getter.getArray(ordinal)
          val len = arrayData.numElements()
          val result = new Array[Any](len)
          var i = 0
          while (i < len) {
            if (containsNull && arrayData.isNullAt(i)) {
              result(i) = null
            } else {
              result(i) = elementConverter(arrayData, i)
            }
            i += 1
          }
          // avro writer is expecting a Java Collection, so we convert it into
          // `ArrayList` backed by the specified array without data copying.
          java.util.Arrays.asList(result: _*)
        }
      case (st: StructType, RECORD) =>
        val structConverter = newStructConverter(st, avroType, catalystPath, avroPath)
        val numFields = st.length
        (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
      // Avro MAP keys are always strings, hence the StringType guard.
      case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
        val valueConverter = newConverter(
          vt, resolveNullableType(avroType.getValueType, valueContainsNull),
          catalystPath :+ "value", avroPath :+ "value")
        (getter, ordinal) =>
          val mapData = getter.getMap(ordinal)
          val len = mapData.numElements()
          val result = new java.util.HashMap[String, Any](len)
          val keyArray = mapData.keyArray()
          val valueArray = mapData.valueArray()
          var i = 0
          while (i < len) {
            val key = keyArray.getUTF8String(i).toString
            if (valueContainsNull && valueArray.isNullAt(i)) {
              result.put(key, null)
            } else {
              result.put(key, valueConverter(valueArray, i))
            }
            i += 1
          }
          result
      case (_: YearMonthIntervalType, INT) =>
        (getter, ordinal) => getter.getInt(ordinal)
      case (_: DayTimeIntervalType, LONG) =>
        (getter, ordinal) => getter.getLong(ordinal)
      case _ =>
        throw new IncompatibleSchemaException(errorPrefix +
          s"schema is incompatible (sqlType = ${catalystType.sql}, avroType = $avroType)")
    }
  }

  /**
   * Builds a row-to-[[Record]] converter for a struct/RECORD pair, validating that the
   * two schemas line up (no extra Catalyst fields; no extra required Avro fields).
   */
  private def newStructConverter(
      catalystStruct: StructType,
      avroStruct: Schema,
      catalystPath: Seq[String],
      avroPath: Seq[String]): InternalRow => Record = {
    val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(
      avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch)

    avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false)
    avroSchemaHelper.validateNoExtraRequiredAvroFields()

    // For each matched field, pre-compute the target Avro position and converter.
    val (avroIndices, fieldConverters) = avroSchemaHelper.matchedFields.map {
      case AvroMatchedField(catalystField, _, avroField) =>
        val converter = newConverter(catalystField.dataType,
          resolveNullableType(avroField.schema(), catalystField.nullable),
          catalystPath :+ catalystField.name, avroPath :+ avroField.name)
        (avroField.pos(), converter)
    }.toArray.unzip

    val numFields = catalystStruct.length
    row: InternalRow =>
      val result = new Record(avroStruct)
      var i = 0
      while (i < numFields) {
        if (row.isNullAt(i)) {
          result.put(avroIndices(i), null)
        } else {
          result.put(avroIndices(i), fieldConverters(i).apply(row, i))
        }
        i += 1
      }
      result
  }

  /**
   * Resolve a possibly nullable Avro Type.
   *
   * An Avro type is nullable when it is a [[UNION]] of two types: one null type and another
   * non-null type. This method will check the nullability of the input Avro type and return the
   * non-null type within when it is nullable. Otherwise it will return the input Avro type
   * unchanged. It will throw an [[UnsupportedAvroTypeException]] when the input Avro type is an
   * unsupported nullable type.
   *
   * It will also log a warning message if the nullability for Avro and catalyst types are
   * different.
   */
  private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
    val (avroNullable, resolvedAvroType) = resolveAvroType(avroType)
    warnNullabilityDifference(avroNullable, nullable)
    resolvedAvroType
  }

  /**
   * Check the nullability of the input Avro type and resolve it when it is nullable. The first
   * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second
   * return value is the possibly resolved type.
   */
  private def resolveAvroType(avroType: Schema): (Boolean, Schema) = {
    if (avroType.getType == Type.UNION) {
      val fields = avroType.getTypes.asScala
      val actualType = fields.filter(_.getType != Type.NULL)
      if (fields.length != 2 || actualType.length != 1) {
        throw new UnsupportedAvroTypeException(
          s"Unsupported Avro UNION type $avroType: Only UNION of a null type and a non-null " +
            "type is supported")
      }
      (true, actualType.head)
    } else {
      (false, avroType)
    }
  }

  /**
   * log a warning message if the nullability for Avro and catalyst types are different.
   */
  private def warnNullabilityDifference(avroNullable: Boolean, catalystNullable: Boolean): Unit = {
    if (avroNullable && !catalystNullable) {
      logWarning("Writing Avro files with nullable Avro schema and non-nullable catalyst schema.")
    }
    if (!avroNullable && catalystNullable) {
      logWarning("Writing Avro files with non-nullable Avro schema and nullable catalyst " +
        "schema will throw runtime exception if there is a record with null value.")
    }
  }
}
object AvroSerializer {
  // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making [[AvroSerializer]] implementation
  // (which relies on it) be only compatible with the exact same version of [[DataSourceUtils]].
  // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.2.x branch,
  // we're preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.2.0 as well as
  // w/ Spark >= 3.2.1
  //
  // [1] https://github.com/apache/spark/pull/34978

  /**
   * Returns a function rebasing Catalyst DATE values (days since epoch) on write:
   *  - EXCEPTION: throws for days before the Gregorian calendar switch
   *  - LEGACY: rebases from the proleptic Gregorian to the hybrid Julian calendar
   *  - CORRECTED: identity (no rebasing)
   * `format` is only used in the error message of the EXCEPTION mode.
   */
  def createDateRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value,
                                  format: String): Int => Int = rebaseMode match {
    case LegacyBehaviorPolicy.EXCEPTION => days: Int =>
      if (days < RebaseDateTime.lastSwitchGregorianDay) {
        throw DataSourceUtils.newRebaseExceptionInWrite(format)
      }
      days
    case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays
    case LegacyBehaviorPolicy.CORRECTED => identity[Int]
  }

  /**
   * Same as [[createDateRebaseFuncInWrite]] but for timestamps in microseconds.
   * Note: in LEGACY mode the session-local time zone is captured once, when the
   * rebase function is constructed, not per invocation.
   */
  def createTimestampRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value,
                                       format: String): Long => Long = rebaseMode match {
    case LegacyBehaviorPolicy.EXCEPTION => micros: Long =>
      if (micros < RebaseDateTime.lastSwitchGregorianTs) {
        throw DataSourceUtils.newRebaseExceptionInWrite(format)
      }
      micros
    case LegacyBehaviorPolicy.LEGACY =>
      val timeZone = SQLConf.get.sessionLocalTimeZone
      RebaseDateTime.rebaseGregorianToJulianMicros(TimeZone.getTimeZone(timeZone), _)
    case LegacyBehaviorPolicy.CORRECTED => identity[Long]
  }
}

View File

@@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.avro
import java.util.Locale
import scala.collection.JavaConverters._
import org.apache.avro.Schema
import org.apache.avro.file. FileReader
import org.apache.avro.generic.GenericRecord
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
/**
 * Shared helpers for the borrowed Avro serializer/deserializer: supported-type checks,
 * an iterator-style Avro file reader, and Avro-to-Catalyst schema field matching.
 *
 * NOTE: This code is borrowed from Spark 3.3.0
 * This code is borrowed, so that we can better control compatibility w/in Spark minor
 * branches (3.2.x, 3.1.x, etc)
 *
 * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
 */
private[sql] object AvroUtils extends Logging {

  /** Returns true iff `dataType` can be mapped to an Avro schema by this data source. */
  def supportsDataType(dataType: DataType): Boolean = dataType match {
    case _: AtomicType => true

    case st: StructType => st.forall { f => supportsDataType(f.dataType) }

    case ArrayType(elementType, _) => supportsDataType(elementType)

    case MapType(keyType, valueType, _) =>
      supportsDataType(keyType) && supportsDataType(valueType)

    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)

    case _: NullType => true

    case _ => false
  }

  // The trait provides iterator-like interface for reading records from an Avro file,
  // deserializing and returning them as internal rows.
  trait RowReader {
    protected val fileReader: FileReader[GenericRecord]
    protected val deserializer: AvroDeserializer
    // Sync marker position past which no further records are read (split boundary).
    protected val stopPosition: Long

    private[this] var completed = false
    // Buffered next row; populated by hasNextRow, consumed by nextRow.
    private[this] var currentRow: Option[InternalRow] = None

    /**
     * Advances until a row survives deserialization (rows may be filtered out) or the
     * reader is exhausted; closes the underlying file reader on exhaustion.
     */
    def hasNextRow: Boolean = {
      while (!completed && currentRow.isEmpty) {
        val r = fileReader.hasNext && !fileReader.pastSync(stopPosition)
        if (!r) {
          fileReader.close()
          completed = true
          currentRow = None
        } else {
          val record = fileReader.next()
          // the row must be deserialized in hasNextRow, because AvroDeserializer#deserialize
          // potentially filters rows
          currentRow = deserializer.deserialize(record).asInstanceOf[Option[InternalRow]]
        }
      }
      currentRow.isDefined
    }

    /** Returns the buffered row, throwing `NoSuchElementException` when exhausted. */
    def nextRow: InternalRow = {
      if (currentRow.isEmpty) {
        hasNextRow
      }
      val returnRow = currentRow
      currentRow = None // free up hasNextRow to consume more Avro records, if not exhausted
      returnRow.getOrElse {
        throw new NoSuchElementException("next on empty iterator")
      }
    }
  }

  /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Avro field. */
  private[sql] case class AvroMatchedField(
      catalystField: StructField,
      catalystPosition: Int,
      avroField: Schema.Field)

  /**
   * Helper class to perform field lookup/matching on Avro schemas.
   *
   * This will match `avroSchema` against `catalystSchema`, attempting to find a matching field in
   * the Avro schema for each field in the Catalyst schema and vice-versa, respecting settings for
   * case sensitivity. The match results can be accessed using the getter methods.
   *
   * @param avroSchema The schema in which to search for fields. Must be of type RECORD.
   * @param catalystSchema The Catalyst schema to use for matching.
   * @param avroPath The seq of parent field names leading to `avroSchema`.
   * @param catalystPath The seq of parent field names leading to `catalystSchema`.
   * @param positionalFieldMatch If true, perform field matching in a positional fashion
   *                             (structural comparison between schemas, ignoring names);
   *                             otherwise, perform field matching using field names.
   */
  class AvroSchemaHelper(
      avroSchema: Schema,
      catalystSchema: StructType,
      avroPath: Seq[String],
      catalystPath: Seq[String],
      positionalFieldMatch: Boolean) {
    if (avroSchema.getType != Schema.Type.RECORD) {
      throw new IncompatibleSchemaException(
        s"Attempting to treat ${avroSchema.getName} as a RECORD, but it was: ${avroSchema.getType}")
    }

    private[this] val avroFieldArray = avroSchema.getFields.asScala.toArray
    // Lower-cased name -> candidate fields; case sensitivity is applied at lookup time.
    private[this] val fieldMap = avroSchema.getFields.asScala
      .groupBy(_.name.toLowerCase(Locale.ROOT))
      .mapValues(_.toSeq) // toSeq needed for scala 2.13

    /** The fields which have matching equivalents in both Avro and Catalyst schemas. */
    val matchedFields: Seq[AvroMatchedField] = catalystSchema.zipWithIndex.flatMap {
      case (sqlField, sqlPos) =>
        getAvroField(sqlField.name, sqlPos).map(AvroMatchedField(sqlField, sqlPos, _))
    }

    /**
     * Validate that there are no Catalyst fields which don't have a matching Avro field, throwing
     * [[IncompatibleSchemaException]] if such extra fields are found. If `ignoreNullable` is false,
     * consider nullable Catalyst fields to be eligible to be an extra field; otherwise,
     * ignore nullable Catalyst fields when checking for extras.
     */
    def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit =
      catalystSchema.zipWithIndex.foreach { case (sqlField, sqlPos) =>
        if (getAvroField(sqlField.name, sqlPos).isEmpty &&
          (!ignoreNullable || !sqlField.nullable)) {
          if (positionalFieldMatch) {
            throw new IncompatibleSchemaException("Cannot find field at position " +
              s"$sqlPos of ${toFieldStr(avroPath)} from Avro schema (using positional matching)")
          } else {
            throw new IncompatibleSchemaException(
              s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Avro schema")
          }
        }
      }

    /**
     * Validate that there are no Avro fields which don't have a matching Catalyst field, throwing
     * [[IncompatibleSchemaException]] if such extra fields are found. Only required (non-nullable)
     * fields are checked; nullable fields are ignored.
     */
    def validateNoExtraRequiredAvroFields(): Unit = {
      val extraFields = avroFieldArray.toSet -- matchedFields.map(_.avroField)
      extraFields.filterNot(isNullable).foreach { extraField =>
        if (positionalFieldMatch) {
          throw new IncompatibleSchemaException(s"Found field '${extraField.name()}' at position " +
            s"${extraField.pos()} of ${toFieldStr(avroPath)} from Avro schema but there is no " +
            s"match in the SQL schema at ${toFieldStr(catalystPath)} (using positional matching)")
        } else {
          throw new IncompatibleSchemaException(
            s"Found ${toFieldStr(avroPath :+ extraField.name())} in Avro schema but there is no " +
              "match in the SQL schema")
        }
      }
    }

    /**
     * Extract a single field from the contained avro schema which has the desired field name,
     * performing the matching with proper case sensitivity according to SQLConf.resolver.
     *
     * @param name The name of the field to search for.
     * @return `Some(match)` if a matching Avro field is found, otherwise `None`.
     */
    private[avro] def getFieldByName(name: String): Option[Schema.Field] = {

      // get candidates, ignoring case of field name
      val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty)

      // search candidates, taking into account case sensitivity settings
      candidates.filter(f => SQLConf.get.resolver(f.name(), name)) match {
        case Seq(avroField) => Some(avroField)
        case Seq() => None
        case matches => throw new IncompatibleSchemaException(s"Searching for '$name' in Avro " +
          s"schema at ${toFieldStr(avroPath)} gave ${matches.size} matches. Candidates: " +
          matches.map(_.name()).mkString("[", ", ", "]")
        )
      }
    }

    /** Get the Avro field corresponding to the provided Catalyst field name/position, if any. */
    def getAvroField(fieldName: String, catalystPos: Int): Option[Schema.Field] = {
      if (positionalFieldMatch) {
        avroFieldArray.lift(catalystPos)
      } else {
        getFieldByName(fieldName)
      }
    }
  }

  /**
   * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable
   * string representing the field, like "field 'foo.bar'". If `names` is empty, the string
   * "top-level record" is returned.
   */
  private[avro] def toFieldStr(names: Seq[String]): String = names match {
    case Seq() => "top-level record"
    case n => s"field '${n.mkString(".")}'"
  }

  /** Return true iff `avroField` is nullable, i.e. `UNION` type and has `NULL` as an option. */
  private[avro] def isNullable(avroField: Schema.Field): Boolean =
    avroField.schema().getType == Schema.Type.UNION &&
      avroField.schema().getTypes.asScala.exists(_.getType == Schema.Type.NULL)
}

View File

@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.avro
import org.apache.avro.Schema
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
/**
 * Hudi adapter exposing the borrowed Spark 3.3 [[AvroDeserializer]] behind the
 * version-agnostic [[HoodieAvroDeserializer]] interface.
 */
class HoodieSpark3_3AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType)
  extends HoodieAvroDeserializer {

  // Underlying deserializer, configured with the session's Avro read rebase mode.
  private val underlying: AvroDeserializer = {
    val rebaseMode = SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ)
    new AvroDeserializer(rootAvroType, rootCatalystType, rebaseMode)
  }

  /** Deserializes an Avro datum; `None` means the row was filtered out. */
  def deserialize(data: Any): Option[Any] = underlying.deserialize(data)
}

View File

@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.avro
import org.apache.avro.Schema
import org.apache.spark.sql.types.DataType
/**
 * Hudi adapter exposing the borrowed Spark 3.3 [[AvroSerializer]] behind the
 * version-agnostic [[HoodieAvroSerializer]] interface.
 */
class HoodieSpark3_3AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean)
  extends HoodieAvroSerializer {

  // Uses the 3-arg constructor: name-based field matching, write rebase mode from SQLConf.
  val avroSerializer = new AvroSerializer(rootCatalystType, rootAvroType, nullable)

  override def serialize(catalystData: Any): Any = avroSerializer.serialize(catalystData)
}

View File

@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
/**
 * Logical plan node representing a time-travel query over `table`, carrying either a
 * `timestamp` expression or a `version` string (presumably exactly one is set — enforced
 * by the parser/analyzer, not here; TODO confirm).
 *
 * Deliberately reports no children and `resolved = false` so the analyzer must rewrite
 * this placeholder into a concrete relation before execution.
 */
case class TimeTravelRelation(
    table: LogicalPlan,
    timestamp: Option[Expression],
    version: Option[String]) extends Command {

  override def children: Seq[LogicalPlan] = Seq.empty

  override def output: Seq[Attribute] = Nil

  override lazy val resolved: Boolean = false

  // No children, so there is nothing to replace.
  def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this
}

View File

@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.catalog
import java.util
import java.util.Objects
/**
 * Case-class implementation of Spark's [[Identifier]].
 *
 * This class exists to keep the code compilable on Scala 2.11: calling the static factory
 * `Identifier.of(namespace, name)` from Scala 2.11 fails with a compile error
 * ("Static methods in interface require -target:jvm-1.8").
 *
 * NOTE: `equals`/`hashCode` are overridden because the compiler-generated case-class versions
 *       compare the `namespace` array by reference identity rather than element-wise.
 */
case class HoodieIdentifier(namespace: Array[String], name: String) extends Identifier {

  override def equals(o: Any): Boolean = o match {
    case other: HoodieIdentifier =>
      // Compare the (cheap) name first, then the namespace arrays element-wise
      name == other.name &&
        util.Arrays.equals(namespace.asInstanceOf[Array[Object]], other.namespace.asInstanceOf[Array[Object]])
    case _ => false
  }

  override def hashCode: Int = {
    // Hash the namespace by content (Seq#hashCode is element-based), keeping it consistent w/ equals
    val namespaceHash = namespace.toSeq.hashCode().asInstanceOf[Object]
    Objects.hash(namespaceHash, name)
  }
}

View File

@@ -0,0 +1,195 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources
import org.apache.hudi.HoodieBaseRelation
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, NamedExpression, ProjectionOverSchema}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.util.SchemaUtils.restoreOriginalOutputNames
/**
* Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation.
* By "physical column", we mean a column as defined in the data source format like Parquet format
* or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL
* column, and a nested Parquet column corresponds to a [[StructField]].
*
* NOTE: This class is borrowed from Spark 3.2.1, with modifications adapting it to handle [[HoodieBaseRelation]],
* instead of [[HadoopFsRelation]]
*/
class Spark33NestedSchemaPruning extends Rule[LogicalPlan] {
  import org.apache.spark.sql.catalyst.expressions.SchemaPruning._

  // No-op unless nested schema pruning has been enabled via the session conf
  override def apply(plan: LogicalPlan): LogicalPlan =
    if (conf.nestedSchemaPruningEnabled) {
      apply0(plan)
    } else {
      plan
    }

  // Walks the plan top-down, attempting to prune the data schema of every Hudi relation
  // sitting under a projection/filter stack that advertises a prunable schema
  private def apply0(plan: LogicalPlan): LogicalPlan =
    plan transformDown {
      case op @ PhysicalOperation(projects, filters,
        // NOTE: This is modified to accommodate for Hudi's custom relations, given that original
        //       [[NestedSchemaPruning]] rule is tightly coupled w/ [[HadoopFsRelation]]
        // TODO generalize to any file-based relation
        l @ LogicalRelation(relation: HoodieBaseRelation, _, _, _))
        if relation.canPruneRelationSchema =>

        prunePhysicalColumns(l.output, projects, filters, relation.dataSchema,
          prunedDataSchema => {
            val prunedRelation =
              relation.updatePrunedDataSchema(prunedSchema = prunedDataSchema)
            buildPrunedRelation(l, prunedRelation)
          }).getOrElse(op)
    }

  /**
   * This method returns optional logical plan. `None` is returned if no nested field is required or
   * all nested fields are required.
   *
   * @param output                the relation's output attributes
   * @param projects              project expressions on top of the relation (not yet normalized)
   * @param filters               filter expressions on top of the relation (not yet normalized)
   * @param dataSchema            the relation's full (un-pruned) data schema
   * @param outputRelationBuilder builds the pruned [[LogicalRelation]] from a pruned data schema
   */
  private def prunePhysicalColumns(output: Seq[AttributeReference],
                                   projects: Seq[NamedExpression],
                                   filters: Seq[Expression],
                                   dataSchema: StructType,
                                   outputRelationBuilder: StructType => LogicalRelation): Option[LogicalPlan] = {
    val (normalizedProjects, normalizedFilters) =
      normalizeAttributeRefNames(output, projects, filters)
    val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)

    // If requestedRootFields includes a nested field, continue. Otherwise,
    // return op
    if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
      val prunedDataSchema = pruneSchema(dataSchema, requestedRootFields)

      // If the data schema is different from the pruned data schema, continue. Otherwise,
      // return op. We effect this comparison by counting the number of "leaf" fields in
      // each schemata, assuming the fields in prunedDataSchema are a subset of the fields
      // in dataSchema.
      if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
        val prunedRelation = outputRelationBuilder(prunedDataSchema)
        val projectionOverSchema = ProjectionOverSchema(prunedDataSchema, AttributeSet(output))
        Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
          prunedRelation, projectionOverSchema))
      } else {
        None
      }
    } else {
      None
    }
  }

  /**
   * Normalizes the names of the attribute references in the given projects and filters to reflect
   * the names in the given logical relation. This makes it possible to compare attributes and
   * fields by name. Returns a tuple with the normalized projects and filters, respectively.
   */
  private def normalizeAttributeRefNames(output: Seq[AttributeReference],
                                         projects: Seq[NamedExpression],
                                         filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
    // Map from expression id to the relation's canonical attribute name
    val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
    val normalizedProjects = projects.map(_.transform {
      case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
        att.withName(normalizedAttNameMap(att.exprId))
    }).map { case expr: NamedExpression => expr }
    val normalizedFilters = filters.map(_.transform {
      case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
        att.withName(normalizedAttNameMap(att.exprId))
    })
    (normalizedProjects, normalizedFilters)
  }

  /**
   * Builds the new output [[Project]] Spark SQL operator that has the `leafNode`.
   */
  private def buildNewProjection(projects: Seq[NamedExpression],
                                 normalizedProjects: Seq[NamedExpression],
                                 filters: Seq[Expression],
                                 prunedRelation: LogicalRelation,
                                 projectionOverSchema: ProjectionOverSchema): Project = {
    // Construct a new target for our projection by rewriting and
    // including the original filters where available
    val projectionChild =
      if (filters.nonEmpty) {
        val projectedFilters = filters.map(_.transformDown {
          case projectionOverSchema(expr) => expr
        })
        val newFilterCondition = projectedFilters.reduce(And)
        Filter(newFilterCondition, prunedRelation)
      } else {
        prunedRelation
      }

    // Construct the new projections of our Project by
    // rewriting the original projections
    val newProjects = normalizedProjects.map(_.transformDown {
      case projectionOverSchema(expr) => expr
    }).map { case expr: NamedExpression => expr }

    if (log.isDebugEnabled) {
      logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}")
    }

    Project(restoreOriginalOutputNames(newProjects, projects.map(_.name)), projectionChild)
  }

  /**
   * Builds a pruned logical relation from the output of the output relation and the schema of the
   * pruned base relation.
   */
  private def buildPrunedRelation(outputRelation: LogicalRelation,
                                  prunedBaseRelation: BaseRelation): LogicalRelation = {
    val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
    outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
  }

  // Prune the given output to make it consistent with `requiredSchema`.
  private def getPrunedOutput(output: Seq[AttributeReference],
                              requiredSchema: StructType): Seq[AttributeReference] = {
    // We need to replace the expression ids of the pruned relation output attributes
    // with the expression ids of the original relation output attributes so that
    // references to the original relation's output are not broken
    val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
    requiredSchema
      .toAttributes
      .map {
        case att if outputIdMap.contains(att.name) =>
          att.withExprId(outputIdMap(att.name))
        case att => att
      }
  }

  /**
   * Counts the "leaf" fields of the given dataType. Informally, this is the
   * number of fields of non-complex data type in the tree representation of
   * [[DataType]].
   */
  private def countLeaves(dataType: DataType): Int = {
    dataType match {
      case array: ArrayType => countLeaves(array.elementType)
      case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType)
      case struct: StructType =>
        struct.map(field => countLeaves(field.dataType)).sum
      case _ => 1
    }
  }
}

View File

@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.util.Utils
object Spark33DataSourceUtils {

  /**
   * Resolves the [[LegacyBehaviorPolicy]] governing the rebase of INT96 timestamp values for a
   * single Parquet file, based on the writer version recorded in its key/value metadata.
   *
   * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime
   *       compatibility against Spark 3.2.0
   *
   * @param lookupFileMeta accessor into the Parquet file's key/value metadata
   * @param modeByConfig   fallback mode (config value) used when the file carries no writer version
   */
  // scalastyle:off
  def int96RebaseMode(lookupFileMeta: String => String,
                      modeByConfig: String): LegacyBehaviorPolicy.Value = {
    val rebaseForcedOff =
      Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true"
    if (rebaseForcedOff) {
      LegacyBehaviorPolicy.CORRECTED
    } else {
      Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)) match {
        case Some(version) =>
          // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to
          // rebase the INT96 timestamp values.
          // Files written by Spark 3.1 and later may also need the rebase if they were written with
          // the "LEGACY" rebase mode.
          if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) {
            LegacyBehaviorPolicy.LEGACY
          } else {
            LegacyBehaviorPolicy.CORRECTED
          }
        case None =>
          // No writer version recorded in the file: honor the mode specified by the config
          LegacyBehaviorPolicy.withName(modeByConfig)
      }
    }
  }
  // scalastyle:on

  /**
   * Resolves the [[LegacyBehaviorPolicy]] governing the rebase of DATE/TIMESTAMP values for a
   * single Parquet file, based on the writer version recorded in its key/value metadata.
   *
   * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime
   *       compatibility against Spark 3.2.0
   *
   * @param lookupFileMeta accessor into the Parquet file's key/value metadata
   * @param modeByConfig   fallback mode (config value) used when the file carries no writer version
   */
  // scalastyle:off
  def datetimeRebaseMode(lookupFileMeta: String => String,
                         modeByConfig: String): LegacyBehaviorPolicy.Value = {
    val rebaseForcedOff =
      Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true"
    if (rebaseForcedOff) {
      LegacyBehaviorPolicy.CORRECTED
    } else {
      Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)) match {
        case Some(version) =>
          // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to
          // rebase the datetime values.
          // Files written by Spark 3.0 and later may also need the rebase if they were written with
          // the "LEGACY" rebase mode.
          if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) {
            LegacyBehaviorPolicy.LEGACY
          } else {
            LegacyBehaviorPolicy.CORRECTED
          }
        case None =>
          // No writer version recorded in the file: honor the mode specified by the config
          LegacyBehaviorPolicy.withName(modeByConfig)
      }
    }
  }
  // scalastyle:on
}

View File

@@ -0,0 +1,505 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.parquet
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
import org.apache.hudi.HoodieSparkUtils
import org.apache.hudi.client.utils.SparkInternalSchemaConverter
import org.apache.hudi.common.fs.FSUtils
import org.apache.hudi.common.util.InternalSchemaCache
import org.apache.hudi.common.util.StringUtils.isNullOrEmpty
import org.apache.hudi.common.util.collection.Pair
import org.apache.hudi.internal.schema.InternalSchema
import org.apache.hudi.internal.schema.action.InternalSchemaMerger
import org.apache.hudi.internal.schema.utils.{InternalSchemaUtils, SerDeHelper}
import org.apache.parquet.filter2.compat.FilterCompat
import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS
import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.{Cast, JoinedRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.parquet.Spark33HoodieParquetFileFormat._
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{AtomicType, DataType, StructField, StructType}
import org.apache.spark.util.SerializableConfiguration
import java.net.URI
/**
* This class is an extension of [[ParquetFileFormat]] overriding Spark-specific behavior
* that's not possible to customize in any other way
*
 * NOTE: This is a version of [[ParquetFileFormat]] impl from Spark 3.2.1, w/ the following changes applied to it:
* <ol>
* <li>Avoiding appending partition values to the rows read from the data file</li>
* <li>Schema on-read</li>
* </ol>
*/
class Spark33HoodieParquetFileFormat(private val shouldAppendPartitionValues: Boolean) extends ParquetFileFormat {

  /**
   * Returns a function reading a single [[PartitionedFile]] into an [[InternalRow]] iterator.
   *
   * Mirrors `ParquetFileFormat.buildReaderWithPartitionValues` w/ two Hudi-specific deviations:
   * <ol>
   *   <li>Appending of the partition values to the rows read from the data file is configurable
   *       (controlled by the `shouldAppendPartitionValues` ctor flag)</li>
   *   <li>Schema on-read: when Hudi's internal schema is present in the conf, the file schema is
   *       merged against the (pruned) query schema so evolved columns can still be read</li>
   * </ol>
   */
  override def buildReaderWithPartitionValues(sparkSession: SparkSession,
                                              dataSchema: StructType,
                                              partitionSchema: StructType,
                                              requiredSchema: StructType,
                                              filters: Seq[Filter],
                                              options: Map[String, String],
                                              hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
    // Propagate session-level reader settings into the Hadoop conf; this conf is broadcast
    // below and is the only state available inside the per-file closure run on the executors
    hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
    hadoopConf.set(
      ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
      requiredSchema.json)
    hadoopConf.set(
      ParquetWriteSupport.SPARK_ROW_SCHEMA,
      requiredSchema.json)
    hadoopConf.set(
      SQLConf.SESSION_LOCAL_TIMEZONE.key,
      sparkSession.sessionState.conf.sessionLocalTimeZone)
    hadoopConf.setBoolean(
      SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
      sparkSession.sessionState.conf.nestedSchemaPruningEnabled)
    hadoopConf.setBoolean(
      SQLConf.CASE_SENSITIVE.key,
      sparkSession.sessionState.conf.caseSensitiveAnalysis)

    ParquetWriteSupport.setSchema(requiredSchema, hadoopConf)

    // Sets flags for `ParquetToSparkSchemaConverter`
    hadoopConf.setBoolean(
      SQLConf.PARQUET_BINARY_AS_STRING.key,
      sparkSession.sessionState.conf.isParquetBinaryAsString)
    hadoopConf.setBoolean(
      SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
      sparkSession.sessionState.conf.isParquetINT96AsTimestamp)

    val internalSchemaStr = hadoopConf.get(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA)
    // For Spark DataSource v1, there's no Physical Plan projection/schema pruning w/in Spark itself,
    // therefore it's safe to do schema projection here
    if (!isNullOrEmpty(internalSchemaStr)) {
      val prunedInternalSchemaStr =
        pruneInternalSchema(internalSchemaStr, requiredSchema)
      hadoopConf.set(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA, prunedInternalSchemaStr)
    }

    val broadcastedHadoopConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

    // TODO: if you move this into the closure it reverts to the default values.
    // If true, enable using the custom RecordReader for parquet. This only works for
    // a subset of the types (no complex types).
    val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields)
    val sqlConf = sparkSession.sessionState.conf
    val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled
    // Vectorized reading requires every column (incl. partition columns) to be of an atomic type
    val enableVectorizedReader: Boolean =
      sqlConf.parquetVectorizedReaderEnabled &&
        resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
    val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled
    val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion
    val capacity = sqlConf.parquetVectorizedReaderBatchSize
    val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown
    // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
    val returningBatch = supportBatch(sparkSession, resultSchema)
    val pushDownDate = sqlConf.parquetFilterPushDownDate
    val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp
    val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
    val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
    val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
    val isCaseSensitive = sqlConf.caseSensitiveAnalysis
    val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)
    val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead
    val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead

    // Per-file reader closure, executed on the executors
    (file: PartitionedFile) => {
      assert(!shouldAppendPartitionValues || file.partitionValues.numFields == partitionSchema.size)

      val filePath = new Path(new URI(file.filePath))
      val split = new FileSplit(filePath, file.start, file.length, Array.empty[String])

      val sharedConf = broadcastedHadoopConf.value.value

      // Fetch internal schema
      val internalSchemaStr = sharedConf.get(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA)
      // Internal schema has to be pruned at this point
      val querySchemaOption = SerDeHelper.fromJson(internalSchemaStr)

      val shouldUseInternalSchema = !isNullOrEmpty(internalSchemaStr) && querySchemaOption.isPresent

      val tablePath = sharedConf.get(SparkInternalSchemaConverter.HOODIE_TABLE_PATH)
      // Resolve the schema this particular file was written with, looked up from the
      // internal-schema cache by the commit instant encoded in the file name
      val fileSchema = if (shouldUseInternalSchema) {
        val commitInstantTime = FSUtils.getCommitTime(filePath.getName).toLong;
        val validCommits = sharedConf.get(SparkInternalSchemaConverter.HOODIE_VALID_COMMITS_LIST)
        InternalSchemaCache.getInternalSchemaByVersionId(commitInstantTime, tablePath, sharedConf, if (validCommits == null) "" else validCommits)
      } else {
        null
      }

      // Lazy: the footer is only read when filter push-down or reader construction needs it
      lazy val footerFileMetaData =
        ParquetFooterReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData
      // Try to push down filters when filter push-down is enabled.
      val pushed = if (enableParquetFilterPushDown) {
        val parquetSchema = footerFileMetaData.getSchema
        val parquetFilters = if (HoodieSparkUtils.gteqSpark3_2_1) {
          // NOTE: Below code could only be compiled against >= Spark 3.2.1,
          //       and unfortunately won't compile against Spark 3.2.0
          //       However this code is runtime-compatible w/ both Spark 3.2.0 and >= Spark 3.2.1
          val datetimeRebaseSpec =
            DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
          new ParquetFilters(
            parquetSchema,
            pushDownDate,
            pushDownTimestamp,
            pushDownDecimal,
            pushDownStringStartWith,
            pushDownInFilterThreshold,
            isCaseSensitive,
            datetimeRebaseSpec)
        } else {
          // Spark 3.2.0
          val datetimeRebaseMode =
            Spark33DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
          createParquetFilters(
            parquetSchema,
            pushDownDate,
            pushDownTimestamp,
            pushDownDecimal,
            pushDownStringStartWith,
            pushDownInFilterThreshold,
            isCaseSensitive,
            datetimeRebaseMode)
        }
        // Rewrite filter attribute names from the query schema into the file schema first, so
        // that predicates stay applicable to files written before a schema evolution
        filters.map(rebuildFilterFromParquet(_, fileSchema, querySchemaOption.orElse(null)))
          // Collects all converted Parquet filter predicates. Notice that not all predicates can be
          // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
          // is used here.
          .flatMap(parquetFilters.createFilter)
          .reduceOption(FilterApi.and)
      } else {
        None
      }

      // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps'
      // *only* if the file was created by something other than "parquet-mr", so check the actual
      // writer here for this file. We have to do this per-file, as each file in the table may
      // have different writers.
      // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads.
      def isCreatedByParquetMr: Boolean =
        footerFileMetaData.getCreatedBy().startsWith("parquet-mr")

      val convertTz =
        if (timestampConversion && !isCreatedByParquetMr) {
          Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key)))
        } else {
          None
        }

      val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)

      // Clone new conf
      val hadoopAttemptConf = new Configuration(broadcastedHadoopConf.value.value)
      // When schema-on-read applies, merge the file schema w/ the (pruned) query schema and
      // collect the columns whose types changed, so values can be cast back after reading
      val typeChangeInfos: java.util.Map[Integer, Pair[DataType, DataType]] = if (shouldUseInternalSchema) {
        val mergedInternalSchema = new InternalSchemaMerger(fileSchema, querySchemaOption.get(), true, true).mergeSchema()
        val mergedSchema = SparkInternalSchemaConverter.constructSparkSchemaFromInternalSchema(mergedInternalSchema)

        hadoopAttemptConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, mergedSchema.json)

        SparkInternalSchemaConverter.collectTypeChangedCols(querySchemaOption.get(), mergedInternalSchema)
      } else {
        new java.util.HashMap()
      }

      val hadoopAttemptContext =
        new TaskAttemptContextImpl(hadoopAttemptConf, attemptId)

      // Try to push down filters when filter push-down is enabled.
      // Notice: This push-down is RowGroups level, not individual records.
      if (pushed.isDefined) {
        ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get)
      }
      val taskContext = Option(TaskContext.get())
      if (enableVectorizedReader) {
        val vectorizedReader =
          if (shouldUseInternalSchema) {
            // Schema evolution in play: use Hudi's reader, which casts type-changed columns
            val int96RebaseSpec =
              DataSourceUtils.int96RebaseSpec(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead)
            val datetimeRebaseSpec =
              DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
            new Spark33HoodieVectorizedParquetRecordReader(
              convertTz.orNull,
              datetimeRebaseSpec.mode.toString,
              datetimeRebaseSpec.timeZone,
              int96RebaseSpec.mode.toString,
              int96RebaseSpec.timeZone,
              enableOffHeapColumnVector && taskContext.isDefined,
              capacity,
              typeChangeInfos)
          } else if (HoodieSparkUtils.gteqSpark3_2_1) {
            // NOTE: Below code could only be compiled against >= Spark 3.2.1,
            //       and unfortunately won't compile against Spark 3.2.0
            //       However this code is runtime-compatible w/ both Spark 3.2.0 and >= Spark 3.2.1
            val int96RebaseSpec =
              DataSourceUtils.int96RebaseSpec(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead)
            val datetimeRebaseSpec =
              DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
            new VectorizedParquetRecordReader(
              convertTz.orNull,
              datetimeRebaseSpec.mode.toString,
              datetimeRebaseSpec.timeZone,
              int96RebaseSpec.mode.toString,
              int96RebaseSpec.timeZone,
              enableOffHeapColumnVector && taskContext.isDefined,
              capacity)
          } else {
            // Spark 3.2.0
            val datetimeRebaseMode =
              Spark33DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
            val int96RebaseMode =
              Spark33DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead)
            createVectorizedParquetRecordReader(
              convertTz.orNull,
              datetimeRebaseMode.toString,
              int96RebaseMode.toString,
              enableOffHeapColumnVector && taskContext.isDefined,
              capacity)
          }

        // SPARK-37089: We cannot register a task completion listener to close this iterator here
        // because downstream exec nodes have already registered their listeners. Since listeners
        // are executed in reverse order of registration, a listener registered here would close the
        // iterator while downstream exec nodes are still running. When off-heap column vectors are
        // enabled, this can cause a use-after-free bug leading to a segfault.
        //
        // Instead, we use FileScanRDD's task completion listener to close this iterator.
        val iter = new RecordReaderIterator(vectorizedReader)
        try {
          vectorizedReader.initialize(split, hadoopAttemptContext)

          // NOTE: We're making appending of the partitioned values to the rows read from the
          //       data file configurable
          if (shouldAppendPartitionValues) {
            logDebug(s"Appending $partitionSchema ${file.partitionValues}")
            vectorizedReader.initBatch(partitionSchema, file.partitionValues)
          } else {
            vectorizedReader.initBatch(StructType(Nil), InternalRow.empty)
          }

          if (returningBatch) {
            vectorizedReader.enableReturningBatches()
          }

          // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
          iter.asInstanceOf[Iterator[InternalRow]]
        } catch {
          case e: Throwable =>
            // SPARK-23457: In case there is an exception in initialization, close the iterator to
            // avoid leaking resources.
            iter.close()
            throw e
        }
      } else {
        logDebug(s"Falling back to parquet-mr")
        val readSupport = if (HoodieSparkUtils.gteqSpark3_2_1) {
          // ParquetRecordReader returns InternalRow
          // NOTE: Below code could only be compiled against >= Spark 3.2.1,
          //       and unfortunately won't compile against Spark 3.2.0
          //       However this code is runtime-compatible w/ both Spark 3.2.0 and >= Spark 3.2.1
          val int96RebaseSpec =
            DataSourceUtils.int96RebaseSpec(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead)
          val datetimeRebaseSpec =
            DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
          new ParquetReadSupport(
            convertTz,
            enableVectorizedReader = false,
            datetimeRebaseSpec,
            int96RebaseSpec)
        } else {
          // Spark 3.2.0 (instantiated reflectively; see companion object helpers)
          val datetimeRebaseMode =
            Spark33DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead)
          val int96RebaseMode =
            Spark33DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead)
          createParquetReadSupport(
            convertTz,
            /* enableVectorizedReader = */ false,
            datetimeRebaseMode,
            int96RebaseMode)
        }

        val reader = if (pushed.isDefined && enableRecordFilter) {
          val parquetFilter = FilterCompat.get(pushed.get, null)
          new ParquetRecordReader[InternalRow](readSupport, parquetFilter)
        } else {
          new ParquetRecordReader[InternalRow](readSupport)
        }
        val iter = new RecordReaderIterator[InternalRow](reader)
        try {
          reader.initialize(split, hadoopAttemptContext)

          val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
          val unsafeProjection = if (typeChangeInfos.isEmpty) {
            GenerateUnsafeProjection.generate(fullSchema, fullSchema)
          } else {
            // find type changed.
            // Rebuild the schema w/ the types actually read from the file, then cast each
            // type-changed column back to the type the query expects
            val newFullSchema = new StructType(requiredSchema.fields.zipWithIndex.map { case (f, i) =>
              if (typeChangeInfos.containsKey(i)) {
                StructField(f.name, typeChangeInfos.get(i).getRight, f.nullable, f.metadata)
              } else f
            }).toAttributes ++ partitionSchema.toAttributes
            val castSchema = newFullSchema.zipWithIndex.map { case (attr, i) =>
              if (typeChangeInfos.containsKey(i)) {
                Cast(attr, typeChangeInfos.get(i).getLeft)
              } else attr
            }
            GenerateUnsafeProjection.generate(castSchema, newFullSchema)
          }

          // NOTE: We're making appending of the partitioned values to the rows read from the
          //       data file configurable
          if (!shouldAppendPartitionValues || partitionSchema.length == 0) {
            // There is no partition columns
            iter.map(unsafeProjection)
          } else {
            val joinedRow = new JoinedRow()
            iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues)))
          }
        } catch {
          case e: Throwable =>
            // SPARK-23457: In case there is an exception in initialization, close the iterator to
            // avoid leaking resources.
            iter.close()
            throw e
        }
      }
    }
  }
}
object Spark33HoodieParquetFileFormat {
  /**
   * Reflectively instantiates [[ParquetFilters]] by invoking its widest constructor, so the call
   * keeps working across Spark releases whose ctor signatures differ.
   *
   * NOTE: This method is specific to Spark 3.2.0
   */
  private def createParquetFilters(args: Any*): ParquetFilters = {
    // NOTE: ParquetFilters ctor args contain Scala enum, therefore we can't look it
    //       up by arg types, and have to instead rely on the number of args based on individual class;
    //       the ctor order is not guaranteed
    val ctor = classOf[ParquetFilters].getConstructors.maxBy(_.getParameterCount)
    ctor.newInstance(args.map(_.asInstanceOf[AnyRef]): _*)
      .asInstanceOf[ParquetFilters]
  }
  /**
   * Reflectively instantiates [[ParquetReadSupport]] by invoking its widest constructor, so the
   * call keeps working across Spark releases whose ctor signatures differ.
   *
   * NOTE: This method is specific to Spark 3.2.0
   */
  private def createParquetReadSupport(args: Any*): ParquetReadSupport = {
    // NOTE: ParquetReadSupport ctor args contain Scala enum, therefore we can't look it
    //       up by arg types, and have to instead rely on the number of args based on individual class;
    //       the ctor order is not guaranteed
    val ctor = classOf[ParquetReadSupport].getConstructors.maxBy(_.getParameterCount)
    ctor.newInstance(args.map(_.asInstanceOf[AnyRef]): _*)
      .asInstanceOf[ParquetReadSupport]
  }
  /**
   * Reflectively instantiates [[VectorizedParquetRecordReader]] by invoking its widest
   * constructor, so the call keeps working across Spark releases whose ctor signatures differ.
   *
   * NOTE: This method is specific to Spark 3.2.0
   */
  private def createVectorizedParquetRecordReader(args: Any*): VectorizedParquetRecordReader = {
    // NOTE: Looking the ctor up by exact arg types is impractical here (same reflective pattern
    //       as the helpers above), so we rely on the widest ctor of the class instead;
    //       the ctor order is not guaranteed
    val ctor = classOf[VectorizedParquetRecordReader].getConstructors.maxBy(_.getParameterCount)
    ctor.newInstance(args.map(_.asInstanceOf[AnyRef]): _*)
      .asInstanceOf[VectorizedParquetRecordReader]
  }
def pruneInternalSchema(internalSchemaStr: String, requiredSchema: StructType): String = {
val querySchemaOption = SerDeHelper.fromJson(internalSchemaStr)
if (querySchemaOption.isPresent && requiredSchema.nonEmpty) {
val prunedSchema = SparkInternalSchemaConverter.convertAndPruneStructTypeToInternalSchema(requiredSchema, querySchemaOption.get())
SerDeHelper.toJson(prunedSchema)
} else {
internalSchemaStr
}
}
private def rebuildFilterFromParquet(oldFilter: Filter, fileSchema: InternalSchema, querySchema: InternalSchema): Filter = {
if (fileSchema == null || querySchema == null) {
oldFilter
} else {
oldFilter match {
case eq: EqualTo =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(eq.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else eq.copy(attribute = newAttribute)
case eqs: EqualNullSafe =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(eqs.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else eqs.copy(attribute = newAttribute)
case gt: GreaterThan =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(gt.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else gt.copy(attribute = newAttribute)
case gtr: GreaterThanOrEqual =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(gtr.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else gtr.copy(attribute = newAttribute)
case lt: LessThan =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(lt.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else lt.copy(attribute = newAttribute)
case lte: LessThanOrEqual =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(lte.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else lte.copy(attribute = newAttribute)
case i: In =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(i.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else i.copy(attribute = newAttribute)
case isn: IsNull =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(isn.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else isn.copy(attribute = newAttribute)
case isnn: IsNotNull =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(isnn.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else isnn.copy(attribute = newAttribute)
case And(left, right) =>
And(rebuildFilterFromParquet(left, fileSchema, querySchema), rebuildFilterFromParquet(right, fileSchema, querySchema))
case Or(left, right) =>
Or(rebuildFilterFromParquet(left, fileSchema, querySchema), rebuildFilterFromParquet(right, fileSchema, querySchema))
case Not(child) =>
Not(rebuildFilterFromParquet(child, fileSchema, querySchema))
case ssw: StringStartsWith =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(ssw.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else ssw.copy(attribute = newAttribute)
case ses: StringEndsWith =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(ses.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else ses.copy(attribute = newAttribute)
case sc: StringContains =>
val newAttribute = InternalSchemaUtils.reBuildFilterName(sc.attribute, fileSchema, querySchema)
if (newAttribute.isEmpty) AlwaysTrue else sc.copy(attribute = newAttribute)
case AlwaysTrue =>
AlwaysTrue
case AlwaysFalse =>
AlwaysFalse
case _ =>
AlwaysTrue
}
}
}
}

View File

@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import org.apache.hudi.common.config.HoodieCommonConfig
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.internal.schema.action.TableChange.ColumnChangeID
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.ResolvedTable
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table
import org.apache.spark.sql.hudi.command.{AlterTableCommand => HudiAlterTableCommand}
/**
* Rule to mostly resolve, normalize and rewrite column names based on case sensitivity.
* for alter table column commands.
*/
class Spark33ResolveHudiAlterTableCommand(sparkSession: SparkSession) extends Rule[LogicalPlan] {

  /**
   * Rewrites resolved ALTER TABLE plans over Hudi V2 tables into Hudi's own
   * [[HudiAlterTableCommand]], tagging each with the matching schema-evolution
   * change kind. The rewrite only applies when schema evolution is enabled in
   * the session configuration.
   */
  def apply(plan: LogicalPlan): LogicalPlan = {
    if (!schemaEvolutionEnabled) {
      plan
    } else {
      plan.resolveOperatorsUp {
        case cmd @ SetTableProperties(ResolvedHoodieV2TablePlan(hoodieTable), _) if cmd.resolved =>
          HudiAlterTableCommand(hoodieTable.v1Table, cmd.changes, ColumnChangeID.PROPERTY_CHANGE)
        case cmd @ UnsetTableProperties(ResolvedHoodieV2TablePlan(hoodieTable), _, _) if cmd.resolved =>
          HudiAlterTableCommand(hoodieTable.v1Table, cmd.changes, ColumnChangeID.PROPERTY_CHANGE)
        case cmd @ DropColumns(ResolvedHoodieV2TablePlan(hoodieTable), _, _) if cmd.resolved =>
          HudiAlterTableCommand(hoodieTable.v1Table, cmd.changes, ColumnChangeID.DELETE)
        case cmd @ AddColumns(ResolvedHoodieV2TablePlan(hoodieTable), _) if cmd.resolved =>
          HudiAlterTableCommand(hoodieTable.v1Table, cmd.changes, ColumnChangeID.ADD)
        case cmd @ RenameColumn(ResolvedHoodieV2TablePlan(hoodieTable), _, _) if cmd.resolved =>
          HudiAlterTableCommand(hoodieTable.v1Table, cmd.changes, ColumnChangeID.UPDATE)
        case cmd @ AlterColumn(ResolvedHoodieV2TablePlan(hoodieTable), _, _, _, _, _) if cmd.resolved =>
          HudiAlterTableCommand(hoodieTable.v1Table, cmd.changes, ColumnChangeID.UPDATE)
        case cmd @ ReplaceColumns(ResolvedHoodieV2TablePlan(hoodieTable), _) if cmd.resolved =>
          HudiAlterTableCommand(hoodieTable.v1Table, cmd.changes, ColumnChangeID.REPLACE)
      }
    }
  }

  // Session-level toggle controlling Hudi schema-evolution support
  private def schemaEvolutionEnabled: Boolean =
    sparkSession.sessionState.conf.getConfString(HoodieCommonConfig.SCHEMA_EVOLUTION_ENABLE.key,
      HoodieCommonConfig.SCHEMA_EVOLUTION_ENABLE.defaultValue.toString).toBoolean

  /** Extractor matching a resolved plan node that wraps Hudi's internal V2 table. */
  object ResolvedHoodieV2TablePlan {
    def unapply(plan: LogicalPlan): Option[HoodieInternalV2Table] = plan match {
      case ResolvedTable(_, _, hoodieTable: HoodieInternalV2Table, _) => Some(hoodieTable)
      case _ => None
    }
  }
}

View File

@@ -0,0 +1,222 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.analysis
import org.apache.hudi.{DefaultSource, SparkAdapterSupport}
import org.apache.hudi.common.table.HoodieTableMetaClient
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{ResolvedTable, UnresolvedPartitionSpec}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HoodieCatalogTable}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
import org.apache.spark.sql.connector.catalog.{Table, V1Table}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.PreWriteCheck.failAnalysis
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, V2SessionCatalog}
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{castIfNeeded, getTableLocation, removeMetaFields, tableExistsInPath}
import org.apache.spark.sql.hudi.catalog.{HoodieCatalog, HoodieInternalV2Table}
import org.apache.spark.sql.hudi.command.{AlterHoodieTableDropPartitionCommand, ShowHoodieTablePartitionsCommand, TruncateHoodieTableCommand}
import org.apache.spark.sql.hudi.{HoodieSqlCommonUtils, ProvidesHoodieConfig}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{AnalysisException, SQLContext, SparkSession}
import scala.collection.JavaConverters.mapAsJavaMapConverter
/**
* NOTE: PLEASE READ CAREFULLY
*
* Since Hudi relations don't currently implement DS V2 Read API, we have to fallback to V1 here.
* Such fallback will have considerable performance impact, therefore it's only performed in cases
* where V2 API have to be used. Currently only such use-case is using of Schema Evolution feature
*
* Check out HUDI-4178 for more details
*/
class HoodieDataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan]
  with ProvidesHoodieConfig {

  /**
   * Replaces a [[DataSourceV2Relation]] over Hudi's internal V2 table with an
   * equivalent V1 [[LogicalRelation]], preserving the already-resolved output
   * attributes and (when present) the catalog table metadata.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
    case v2Relation @ DataSourceV2Relation(hoodieTable: HoodieInternalV2Table, _, _, _, _) =>
      val v1CatalogTable = hoodieTable.catalogTable.map(_ => hoodieTable.v1Table)
      val v1Relation = new DefaultSource().createRelation(
        new SQLContext(sparkSession),
        buildHoodieConfig(hoodieTable.hoodieCatalogTable),
        hoodieTable.hoodieCatalogTable.tableSchema)
      LogicalRelation(v1Relation, v2Relation.output, v1CatalogTable, isStreaming = false)
  }
}
/**
 * Analyzer rule aligning the output of an INSERT INTO query with the target Hudi V2
 * relation's schema: when the shapes differ, the query is wrapped in a by-ordinal
 * projection casting each column to the target attribute's type.
 */
class HoodieSpark3Analysis(sparkSession: SparkSession) extends Rule[LogicalPlan] {

  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
    case s @ InsertIntoStatement(r @ DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _), partitionSpec, _, _, _, _)
      if s.query.resolved && needsSchemaAdjustment(s.query, v2Table.hoodieCatalogTable.table, partitionSpec, r.schema) =>
      val projection = resolveQueryColumnsByOrdinal(s.query, r.output)
      // Only rewrite when the projection actually changed, keeping the rule idempotent
      if (projection != s.query) {
        s.copy(query = projection)
      } else {
        s
      }
  }

  /**
   * Need to adjust schema based on the query and relation schema, for example,
   * if using insert into xx select 1, 2 here need to map to column names.
   *
   * Validates that the number of static partition values plus query output columns
   * matches the table schema (minus meta fields), then reports whether the query
   * output's names or types deviate from the relation schema.
   */
  private def needsSchemaAdjustment(query: LogicalPlan,
                                    table: CatalogTable,
                                    partitionSpec: Map[String, Option[String]],
                                    schema: StructType): Boolean = {
    val output = query.output
    // Hudi meta columns are excluded from the arity check below
    val queryOutputWithoutMetaFields = removeMetaFields(output)
    val hoodieCatalogTable = HoodieCatalogTable(sparkSession, table)

    val partitionFields = hoodieCatalogTable.partitionFields
    val partitionSchema = hoodieCatalogTable.partitionSchema
    // Partitions with an explicit value in the INSERT statement (static partitions)
    val staticPartitionValues = partitionSpec.filter(p => p._2.isDefined).mapValues(_.get)

    // Static partition inserts must bind every partition column, or none
    assert(staticPartitionValues.isEmpty ||
      staticPartitionValues.size == partitionSchema.size,
      s"Required partition columns is: ${partitionSchema.json}, Current static partitions " +
        s"is: ${staticPartitionValues.mkString("," + "")}")

    assert(staticPartitionValues.size + queryOutputWithoutMetaFields.size
      == hoodieCatalogTable.tableSchemaWithoutMetaFields.size,
      s"Required select columns count: ${hoodieCatalogTable.tableSchemaWithoutMetaFields.size}, " +
        s"Current select columns(including static partition column) count: " +
        s"${staticPartitionValues.size + queryOutputWithoutMetaFields.size}columns: " +
        s"(${(queryOutputWithoutMetaFields.map(_.name) ++ staticPartitionValues.keys).mkString(",")})")

    // static partition insert.
    if (staticPartitionValues.nonEmpty) {
      // drop partition fields in origin schema to align fields.
      // NOTE(review): `dropWhile` returns a new collection and its result is discarded,
      // so this statement is a no-op — the comparison below still sees the full `schema`.
      // Presumably `schema.filterNot(...)` bound to a val was intended; confirm before
      // changing, since the rewrite decision for static-partition inserts depends on it.
      schema.dropWhile(p => partitionFields.contains(p.name))
    }

    // Compare the leading query columns against the relation schema by name and type
    val existingSchemaOutput = output.take(schema.length)
    existingSchemaOutput.map(_.name) != schema.map(_.name) ||
      existingSchemaOutput.map(_.dataType) != schema.map(_.dataType)
  }

  /**
   * Wraps the query in a Project that casts each output column (by position) to the
   * corresponding target attribute's type and name.
   */
  private def resolveQueryColumnsByOrdinal(query: LogicalPlan,
                                           targetAttrs: Seq[Attribute]): LogicalPlan = {
    // always add a Cast. it will be removed in the optimizer if it is unnecessary.
    val project = query.output.zipWithIndex.map { case (attr, i) =>
      if (i < targetAttrs.length) {
        val targetAttr = targetAttrs(i)
        val castAttr = castIfNeeded(attr.withNullability(targetAttr.nullable), targetAttr.dataType, conf)
        Alias(castAttr, targetAttr.name)()
      } else {
        // Extra columns (beyond the target arity) are passed through unchanged
        attr
      }
    }
    Project(project, query)
  }
}
/**
* Rule for resolve hoodie's extended syntax or rewrite some logical plan.
*/
/**
 * Rule for resolve hoodie's extended syntax or rewrite some logical plan.
 *
 * Currently fills in the table schema for CREATE TABLE statements over an existing
 * Hudi table path when the statement itself declares no schema.
 */
case class HoodieSpark3ResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan]
  with SparkAdapterSupport with ProvidesHoodieConfig {

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
    // Fill schema for Create Table without specify schema info
    // CreateTable / CreateTableAsSelect was migrated to v2 in Spark 3.3.0
    // https://issues.apache.org/jira/browse/SPARK-36850
    case c @ CreateTable(tableCatalog, schema, partitioning, tableSpec, _)
      if sparkAdapter.isHoodieTable(tableSpec.properties.asJava) =>

      // Without a declared schema the partition columns must be inferred too
      if (schema.isEmpty && partitioning.nonEmpty) {
        failAnalysis("It is not allowed to specify partition columns when the table schema is " +
          "not defined. When the table schema is not provided, schema and partition columns " +
          "will be inferred.")
      }

      val hoodieCatalog = tableCatalog match {
        case catalog: HoodieCatalog => catalog
        // Fall back to the session catalog when Hudi's catalog isn't configured
        case _ => tableCatalog.asInstanceOf[V2SessionCatalog]
      }
      val tablePath = getTableLocation(tableSpec.properties,
        TableIdentifier(c.tableName.name(), c.tableName.namespace().lastOption)
        , sparkSession)

      val tableExistInCatalog = hoodieCatalog.tableExists(c.tableName)
      // Only when the table has not exist in catalog, we need to fill the schema info for creating table.
      if (!tableExistInCatalog && tableExistsInPath(tablePath, sparkSession.sessionState.newHadoopConf())) {
        val metaClient = HoodieTableMetaClient.builder()
          .setBasePath(tablePath)
          .setConf(sparkSession.sessionState.newHadoopConf())
          .build()
        val tableSchema = HoodieSqlCommonUtils.getTableSqlSchema(metaClient)
        if (tableSchema.isDefined && schema.isEmpty) {
          // Fill the schema with the schema from the table
          c.copy(tableSchema = tableSchema.get)
        } else if (tableSchema.isDefined && schema != tableSchema.get) {
          // A user-declared schema that conflicts with the on-storage schema is rejected
          throw new AnalysisException(s"Specified schema in create table statement is not equal to the table schema." +
            s"You should not specify the schema for an existing table: ${c.tableName.name()} ")
        } else {
          c
        }
      } else {
        c
      }
    case p => p
  }
}
/**
* Rule replacing resolved Spark's commands (not working for Hudi tables out-of-the-box) with
* corresponding Hudi implementations
*/
case class HoodieSpark3PostAnalysisRule(sparkSession: SparkSession) extends Rule[LogicalPlan] {

  /**
   * Replaces resolved Spark commands that don't work for Hudi tables out of the box
   * with the corresponding Hudi command implementations.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
    case ShowPartitions(ResolvedTable(_, tableId, HoodieV1OrV2Table(_), _), maybeSpec, _) =>
      ShowHoodieTablePartitionsCommand(
        tableId.asTableIdentifier,
        maybeSpec.map(_.asInstanceOf[UnresolvedPartitionSpec].spec))

    // Rewrite TruncateTable / TruncatePartition to TruncateHoodieTableCommand
    case TruncateTable(ResolvedTable(_, tableId, HoodieV1OrV2Table(_), _)) =>
      TruncateHoodieTableCommand(tableId.asTableIdentifier, None)

    case TruncatePartition(ResolvedTable(_, tableId, HoodieV1OrV2Table(_), _), targetSpec: UnresolvedPartitionSpec) =>
      TruncateHoodieTableCommand(tableId.asTableIdentifier, Some(targetSpec.spec))

    case DropPartitions(ResolvedTable(_, tableId, HoodieV1OrV2Table(_), _), partSpecs, ifExists, purge) =>
      AlterHoodieTableDropPartitionCommand(
        tableId.asTableIdentifier,
        partSpecs.map(_.asInstanceOf[UnresolvedPartitionSpec]).map(_.spec),
        ifExists,
        purge,
        retainData = true
      )

    case _ => plan
  }
}
private[sql] object HoodieV1OrV2Table extends SparkAdapterSupport {
  /**
   * Extractor yielding the underlying [[CatalogTable]] when the given connector
   * [[Table]] is Hudi-backed — either a V1 fallback table identified as Hudi, or
   * Hudi's own internal V2 table.
   */
  def unapply(table: Table): Option[CatalogTable] = table match {
    case V1Table(v1CatalogTable) if sparkAdapter.isHoodieTable(v1CatalogTable) =>
      Some(v1CatalogTable)
    case hoodieV2Table: HoodieInternalV2Table =>
      hoodieV2Table.catalogTable
    case _ =>
      None
  }
}

View File

@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.catalog
import org.apache.hudi.exception.HoodieException
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.types.StructType
import java.util
/**
* Basic implementation that represents a table which is staged for being committed.
* @param ident table ident
* @param table table
* @param catalog table catalog
*/
/**
 * Basic implementation that represents a table which is staged for being committed.
 *
 * Writes are delegated to the wrapped [[table]] when it supports them; aborting the
 * staged changes drops the table from the catalog, and committing is a no-op since
 * the wrapped table is already materialized.
 *
 * @param ident table ident
 * @param table table
 * @param catalog table catalog
 */
case class BasicStagedTable(ident: Identifier,
                            table: Table,
                            catalog: TableCatalog) extends SupportsWrite with StagedTable {

  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
    // FIX: delegate based on the wrapped *table*'s capabilities. The previous code
    // matched on `info` (a LogicalWriteInfo, which never implements SupportsWrite),
    // so the delegation branch was unreachable and every write attempt threw the
    // "does not support writes" exception.
    table match {
      case supportsWrite: SupportsWrite => supportsWrite.newWriteBuilder(info)
      case _ => throw new HoodieException(s"Table `${ident.name}` does not support writes.")
    }
  }

  // Staged-create rollback: remove the (partially created) table from the catalog
  override def abortStagedChanges(): Unit = catalog.dropTable(ident)

  // Nothing to finalize: the wrapped table already exists in the catalog
  override def commitStagedChanges(): Unit = {}

  override def name(): String = ident.name()

  override def schema(): StructType = table.schema()

  override def partitioning(): Array[Transform] = table.partitioning()

  override def capabilities(): util.Set[TableCapability] = table.capabilities()

  override def properties(): util.Map[String, String] = table.properties()
}

View File

@@ -355,7 +355,7 @@ object HoodieCatalog {
identityCols += col
case BucketTransform(numBuckets, FieldReference(Seq(col))) =>
case BucketTransform(numBuckets, Seq(FieldReference(Seq(col))), _) =>
bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil))
case _ =>

View File

@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.catalog
import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HoodieCatalogTable}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, V1Table, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.hudi.ProvidesHoodieConfig
import org.apache.spark.sql.sources.{Filter, InsertableRelation}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import java.util
import scala.collection.JavaConverters.{mapAsJavaMapConverter, setAsJavaSetConverter}
/**
 * Hudi's DataSource V2 [[Table]] with a V1 write/read fallback: metadata is served
 * from [[HoodieCatalogTable]], resolved either from the supplied catalog table or
 * from the Hudi metadata on storage at `path`.
 */
case class HoodieInternalV2Table(spark: SparkSession,
                                 path: String,
                                 catalogTable: Option[CatalogTable] = None,
                                 tableIdentifier: Option[String] = None,
                                 options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty())
  extends Table with SupportsWrite with V2TableWithV1Fallback {

  lazy val hoodieCatalogTable: HoodieCatalogTable = catalogTable match {
    case Some(existingCatalogTable) =>
      HoodieCatalogTable(spark, existingCatalogTable)
    case None =>
      // No catalog entry provided: resolve the table name from the Hudi metadata on storage
      val metaClient: HoodieTableMetaClient = HoodieTableMetaClient.builder()
        .setBasePath(path)
        .setConf(SparkSession.active.sessionState.newHadoopConf)
        .build()
      HoodieCatalogTable(spark, TableIdentifier(metaClient.getTableConfig.getTableName))
  }

  private lazy val resolvedSchema: StructType = hoodieCatalogTable.tableSchema

  override def name(): String = hoodieCatalogTable.table.identifier.unquotedString

  override def schema(): StructType = resolvedSchema

  override def capabilities(): util.Set[TableCapability] =
    Set(BATCH_READ, V1_BATCH_WRITE, OVERWRITE_BY_FILTER, TRUNCATE, ACCEPT_ANY_SCHEMA).asJava

  override def properties(): util.Map[String, String] =
    hoodieCatalogTable.catalogProperties.asJava

  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
    new HoodieV1WriteBuilder(info.options, hoodieCatalogTable, spark)

  override def v1Table: CatalogTable = hoodieCatalogTable.table

  def v1TableWrapper: V1Table = V1Table(v1Table)

  override def partitioning(): Array[Transform] =
    hoodieCatalogTable.partitionFields
      .map(field => new IdentityTransform(new FieldReference(Seq(field))))
      .toArray
}
/**
 * V1-fallback write builder: produces an [[InsertableRelation]] that writes through
 * the Hudi "org.apache.hudi" datasource, choosing the save mode from whether an
 * overwrite/truncate was requested.
 */
private class HoodieV1WriteBuilder(writeOptions: CaseInsensitiveStringMap,
                                   hoodieCatalogTable: HoodieCatalogTable,
                                   spark: SparkSession)
  extends SupportsTruncate with SupportsOverwrite with ProvidesHoodieConfig {

  // Set once truncate()/overwrite() has been requested; drives the save mode below
  private var overwriteRequested = false

  override def truncate(): HoodieV1WriteBuilder = {
    overwriteRequested = true
    this
  }

  override def overwrite(filters: Array[Filter]): WriteBuilder = {
    overwriteRequested = true
    this
  }

  override def build(): V1Write = new V1Write {
    override def toInsertableRelation: InsertableRelation = {
      new InsertableRelation {
        override def insert(data: DataFrame, overwrite: Boolean): Unit = {
          // Overwriting a non-partitioned table maps to SaveMode.Overwrite; insert into
          // and insert-overwrite-partition both go through append mode.
          val saveMode =
            if (overwriteRequested && hoodieCatalogTable.partitionFields.isEmpty) {
              SaveMode.Overwrite
            } else {
              SaveMode.Append
            }
          alignOutputColumns(data).write.format("org.apache.hudi")
            .mode(saveMode)
            .options(buildHoodieConfig(hoodieCatalogTable) ++
              buildHoodieInsertConfig(hoodieCatalogTable, spark, overwriteRequested, Map.empty, Map.empty))
            .save()
        }
      }
    }
  }

  /** Re-applies the table schema so the incoming columns align with the target table. */
  private def alignOutputColumns(data: DataFrame): DataFrame =
    spark.createDataFrame(data.toJavaRDD, hoodieCatalogTable.tableSchema)
}

View File

@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.catalog
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hudi.DataSourceWriteOptions.RECORDKEY_FIELD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, SupportsWrite, TableCapability}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, V1Write, WriteBuilder}
import org.apache.spark.sql.types.StructType
import java.net.URI
import java.util
import scala.collection.JavaConverters.{mapAsScalaMapConverter, setAsJavaSetConverter}
/**
 * [[StagedTable]] used by [[HoodieCatalog]] for atomic table creation: table
 * properties, write options and the optional CTAS source query are buffered here
 * and only materialized into a Hudi table in [[commitStagedChanges]].
 *
 * @param ident              table identifier
 * @param locUriAndTableType table location URI paired with its catalog table type
 * @param catalog            owning Hudi catalog
 * @param schema             table schema
 * @param partitions         partitioning transforms
 * @param properties         table properties (may carry "option."-prefixed write options)
 * @param mode               creation mode (create / replace / staged variants)
 */
case class HoodieStagedTable(ident: Identifier,
                             locUriAndTableType: (URI, CatalogTableType),
                             catalog: HoodieCatalog,
                             override val schema: StructType,
                             partitions: Array[Transform],
                             override val properties: util.Map[String, String],
                             mode: TableCreationMode) extends StagedTable with SupportsWrite {

  // Populated by newWriteBuilder when this staged create is fed by a query (CTAS)
  private var sourceQuery: Option[DataFrame] = None
  private var writeOptions: Map[String, String] = Map.empty

  override def commitStagedChanges(): Unit = {
    val props = new util.HashMap[String, String]()
    // Keys of the form "option.<k>" are write options smuggled through table properties
    val optionsThroughProperties = properties.asScala.collect {
      case (k, _) if k.startsWith("option.") => k.stripPrefix("option.")
    }.toSet
    val sqlWriteOptions = new util.HashMap[String, String]()
    // Split properties into plain table props and write options
    properties.asScala.foreach { case (k, v) =>
      if (!k.startsWith("option.") && !optionsThroughProperties.contains(k)) {
        props.put(k, v)
      } else if (optionsThroughProperties.contains(k)) {
        sqlWriteOptions.put(k, v)
      }
    }
    // Write options from SQL properties are only used when none came through the builder
    if (writeOptions.isEmpty && !sqlWriteOptions.isEmpty) {
      writeOptions = sqlWriteOptions.asScala.toMap
    }
    // NOTE(review): this re-adds *all* entries of `properties` (including the
    // "option."-prefixed and shadowed keys filtered out above), which makes the
    // preceding split largely moot for `props` — confirm whether this is intentional.
    props.putAll(properties)
    props.put("hoodie.table.name", ident.name())
    // NOTE(review): `properties.get("primaryKey")` may be null when no primary key was
    // declared, which would store a null record-key field — verify callers always set it.
    props.put(RECORDKEY_FIELD.key, properties.get("primaryKey"))
    catalog.createHoodieTable(
      ident, schema, locUriAndTableType, partitions, props, writeOptions, sourceQuery, mode)
  }

  override def name(): String = ident.name()

  // Roll back a failed staged create by wiping the table path on storage
  override def abortStagedChanges(): Unit = {
    clearTablePath(locUriAndTableType._1.getPath, catalog.spark.sparkContext.hadoopConfiguration)
  }

  /** Recursively deletes the table base path. */
  private def clearTablePath(tablePath: String, conf: Configuration): Unit = {
    val path = new Path(tablePath)
    val fs = path.getFileSystem(conf)
    fs.delete(path, true)
  }

  override def capabilities(): util.Set[TableCapability] = Set(TableCapability.V1_BATCH_WRITE).asJava

  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
    writeOptions = info.options.asCaseSensitiveMap().asScala.toMap
    new HoodieV1WriteBuilder
  }

  /*
   * WriteBuilder for creating a Hoodie table.
   */
  private class HoodieV1WriteBuilder extends WriteBuilder {
    // V1Write SAM: capture the CTAS source DataFrame instead of writing immediately;
    // the actual write happens in commitStagedChanges via createHoodieTable
    override def build(): V1Write = () => {
      (data: DataFrame, overwrite: Boolean) => {
        sourceQuery = Option(data)
      }
    }
  }
}

View File

@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.catalog;
/**
 * Modes in which a Hudi table can be created through the catalog: direct creation,
 * create-or-replace, and their staged (atomic, commit-on-success) counterparts.
 */
public enum TableCreationMode {
  CREATE, CREATE_OR_REPLACE, STAGE_CREATE, STAGE_REPLACE
}

View File

@@ -0,0 +1,347 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import java.net.URI
import java.nio.charset.StandardCharsets
import java.util
import java.util.concurrent.atomic.AtomicInteger
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.client.utils.SparkInternalSchemaConverter
import org.apache.hudi.common.model.{HoodieCommitMetadata, WriteOperationType}
import org.apache.hudi.{DataSourceOptionsHelper, DataSourceUtils}
import org.apache.hudi.common.table.timeline.{HoodieActiveTimeline, HoodieInstant}
import org.apache.hudi.common.table.timeline.HoodieInstant.State
import org.apache.hudi.common.table.{HoodieTableMetaClient, TableSchemaResolver}
import org.apache.hudi.common.util.{CommitUtils, Option}
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.internal.schema.InternalSchema
import org.apache.hudi.internal.schema.action.TableChange.ColumnChangeID
import org.apache.hudi.internal.schema.action.TableChanges
import org.apache.hudi.internal.schema.convert.AvroInternalSchemaConverter
import org.apache.hudi.internal.schema.utils.{SchemaChangeUtils, SerDeHelper}
import org.apache.hudi.internal.schema.io.FileBasedInternalSchemaStorageManager
import org.apache.hudi.table.HoodieSparkTable
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
import org.apache.spark.sql.connector.catalog.{TableCatalog, TableChange}
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, DeleteColumn, RemoveProperty, SetProperty}
import org.apache.spark.sql.types.StructType
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
case class AlterTableCommand(table: CatalogTable, changes: Seq[TableChange], changeType: ColumnChangeID) extends HoodieLeafRunnableCommand with Logging {
override def run(sparkSession: SparkSession): Seq[Row] = {
changeType match {
case ColumnChangeID.ADD => applyAddAction(sparkSession)
case ColumnChangeID.DELETE => applyDeleteAction(sparkSession)
case ColumnChangeID.UPDATE => applyUpdateAction(sparkSession)
case ColumnChangeID.PROPERTY_CHANGE if (changes.filter(_.isInstanceOf[SetProperty]).size == changes.size) =>
applyPropertySet(sparkSession)
case ColumnChangeID.PROPERTY_CHANGE if (changes.filter(_.isInstanceOf[RemoveProperty]).size == changes.size) =>
applyPropertyUnset(sparkSession)
case ColumnChangeID.REPLACE => applyReplaceAction(sparkSession)
case other => throw new RuntimeException(s"find unsupported alter command type: ${other}")
}
Seq.empty[Row]
}
  /**
   * Handles REPLACE COLUMNS: modeled as deleting the dropped columns first and then
   * adding the new ones on top of the resulting schema, committed as a single
   * schema-evolution commit.
   */
  def applyReplaceAction(sparkSession: SparkSession): Unit = {
    // convert to delete first then add again
    val deleteChanges = changes.filter(p => p.isInstanceOf[DeleteColumn]).map(_.asInstanceOf[DeleteColumn])
    val addChanges = changes.filter(p => p.isInstanceOf[AddColumn]).map(_.asInstanceOf[AddColumn])
    val (oldSchema, historySchema) = getInternalSchemaAndHistorySchemaStr(sparkSession)
    val newSchema = applyAddAction2Schema(sparkSession, applyDeleteAction2Schema(sparkSession, oldSchema, deleteChanges), addChanges)
    // Seed the schema history from the current schema when none has been recorded yet
    val verifiedHistorySchema = if (historySchema == null || historySchema.isEmpty) {
      SerDeHelper.inheritSchemas(oldSchema, "")
    } else {
      historySchema
    }
    AlterTableCommand.commitWithSchema(newSchema, verifiedHistorySchema, table, sparkSession)
    logInfo("column replace finished")
  }
  /**
   * Applies a batch of ADD COLUMN changes to `oldSchema`, returning the evolved
   * internal schema. Each added column is registered under its parent (for nested
   * fields) together with any AFTER/FIRST position directive.
   */
  def applyAddAction2Schema(sparkSession: SparkSession, oldSchema: InternalSchema, addChanges: Seq[AddColumn]): InternalSchema = {
    val addChange = TableChanges.ColumnAddChange.get(oldSchema)
    addChanges.foreach { addColumn =>
      // fieldNames is the full dotted path; last segment is the new column name
      val names = addColumn.fieldNames()
      val parentName = AlterTableCommand.getParentName(names)
      // add col change
      val colType = SparkInternalSchemaConverter.buildTypeFromStructType(addColumn.dataType(), true, new AtomicInteger(0))
      addChange.addColumns(parentName, names.last, colType, addColumn.comment())
      // add position change
      addColumn.position() match {
        case after: TableChange.After =>
          // Anchor columns are addressed by their full dotted path as well
          addChange.addPositionChange(names.mkString("."),
            if (parentName.isEmpty) after.column() else parentName + "." + after.column(), "after")
        case _: TableChange.First =>
          addChange.addPositionChange(names.mkString("."), "", "first")
        case _ =>
        // No explicit position: the column is appended at the default location
      }
    }
    SchemaChangeUtils.applyTableChanges2Schema(oldSchema, addChange)
  }
def applyDeleteAction2Schema(sparkSession: SparkSession, oldSchema: InternalSchema, deleteChanges: Seq[DeleteColumn]): InternalSchema = {
val deleteChange = TableChanges.ColumnDeleteChange.get(oldSchema)
deleteChanges.foreach { c =>
val originalColName = c.fieldNames().mkString(".")
checkSchemaChange(Seq(originalColName), table)
deleteChange.deleteColumn(originalColName)
}
SchemaChangeUtils.applyTableChanges2Schema(oldSchema, deleteChange).setSchemaId(oldSchema.getMaxColumnId)
}
def applyAddAction(sparkSession: SparkSession): Unit = {
val (oldSchema, historySchema) = getInternalSchemaAndHistorySchemaStr(sparkSession)
val newSchema = applyAddAction2Schema(sparkSession, oldSchema, changes.map(_.asInstanceOf[AddColumn]))
val verifiedHistorySchema = if (historySchema == null || historySchema.isEmpty) {
SerDeHelper.inheritSchemas(oldSchema, "")
} else {
historySchema
}
AlterTableCommand.commitWithSchema(newSchema, verifiedHistorySchema, table, sparkSession)
logInfo("column add finished")
}
def applyDeleteAction(sparkSession: SparkSession): Unit = {
val (oldSchema, historySchema) = getInternalSchemaAndHistorySchemaStr(sparkSession)
val newSchema = applyDeleteAction2Schema(sparkSession, oldSchema, changes.map(_.asInstanceOf[DeleteColumn]))
// delete action should not change the getMaxColumnId field.
newSchema.setMaxColumnId(oldSchema.getMaxColumnId)
val verifiedHistorySchema = if (historySchema == null || historySchema.isEmpty) {
SerDeHelper.inheritSchemas(oldSchema, "")
} else {
historySchema
}
AlterTableCommand.commitWithSchema(newSchema, verifiedHistorySchema, table, sparkSession)
logInfo("column delete finished")
}
def applyUpdateAction(sparkSession: SparkSession): Unit = {
val (oldSchema, historySchema) = getInternalSchemaAndHistorySchemaStr(sparkSession)
val updateChange = TableChanges.ColumnUpdateChange.get(oldSchema)
changes.foreach { change =>
change match {
case updateType: TableChange.UpdateColumnType =>
val newType = SparkInternalSchemaConverter.buildTypeFromStructType(updateType.newDataType(), true, new AtomicInteger(0))
updateChange.updateColumnType(updateType.fieldNames().mkString("."), newType)
case updateComment: TableChange.UpdateColumnComment =>
updateChange.updateColumnComment(updateComment.fieldNames().mkString("."), updateComment.newComment())
case updateName: TableChange.RenameColumn =>
val originalColName = updateName.fieldNames().mkString(".")
checkSchemaChange(Seq(originalColName), table)
updateChange.renameColumn(originalColName, updateName.newName())
case updateNullAbility: TableChange.UpdateColumnNullability =>
updateChange.updateColumnNullability(updateNullAbility.fieldNames().mkString("."), updateNullAbility.nullable())
case updatePosition: TableChange.UpdateColumnPosition =>
val names = updatePosition.fieldNames()
val parentName = AlterTableCommand.getParentName(names)
updatePosition.position() match {
case after: TableChange.After =>
updateChange.addPositionChange(names.mkString("."),
if (parentName.isEmpty) after.column() else parentName + "." + after.column(), "after")
case _: TableChange.First =>
updateChange.addPositionChange(names.mkString("."), "", "first")
case _ =>
}
}
}
val newSchema = SchemaChangeUtils.applyTableChanges2Schema(oldSchema, updateChange)
val verifiedHistorySchema = if (historySchema == null || historySchema.isEmpty) {
SerDeHelper.inheritSchemas(oldSchema, "")
} else {
historySchema
}
AlterTableCommand.commitWithSchema(newSchema, verifiedHistorySchema, table, sparkSession)
logInfo("column update finished")
}
// to do support unset default value to columns, and apply them to internalSchema
def applyPropertyUnset(sparkSession: SparkSession): Unit = {
val catalog = sparkSession.sessionState.catalog
val propKeys = changes.map(_.asInstanceOf[RemoveProperty]).map(_.property())
// ignore NonExist unset
propKeys.foreach { k =>
if (!table.properties.contains(k) && k != TableCatalog.PROP_COMMENT) {
logWarning(s"find non exist unset property: ${k} , ignore it")
}
}
val tableComment = if (propKeys.contains(TableCatalog.PROP_COMMENT)) None else table.comment
val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) }
val newTable = table.copy(properties = newProperties, comment = tableComment)
catalog.alterTable(newTable)
logInfo("table properties change finished")
}
// to do support set default value to columns, and apply them to internalSchema
def applyPropertySet(sparkSession: SparkSession): Unit = {
val catalog = sparkSession.sessionState.catalog
val properties = changes.map(_.asInstanceOf[SetProperty]).map(f => f.property -> f.value).toMap
// This overrides old properties and update the comment parameter of CatalogTable
// with the newly added/modified comment since CatalogTable also holds comment as its
// direct property.
val newTable = table.copy(
properties = table.properties ++ properties,
comment = properties.get(TableCatalog.PROP_COMMENT).orElse(table.comment))
catalog.alterTable(newTable)
logInfo("table properties change finished")
}
def getInternalSchemaAndHistorySchemaStr(sparkSession: SparkSession): (InternalSchema, String) = {
val path = AlterTableCommand.getTableLocation(table, sparkSession)
val hadoopConf = sparkSession.sessionState.newHadoopConf()
val metaClient = HoodieTableMetaClient.builder().setBasePath(path)
.setConf(hadoopConf).build()
val schemaUtil = new TableSchemaResolver(metaClient)
val schema = schemaUtil.getTableInternalSchemaFromCommitMetadata().orElse {
AvroInternalSchemaConverter.convert(schemaUtil.getTableAvroSchema)
}
val historySchemaStr = schemaUtil.getTableHistorySchemaStrFromCommitMetadata.orElse("")
(schema, historySchemaStr)
}
def checkSchemaChange(colNames: Seq[String], catalogTable: CatalogTable): Unit = {
val primaryKeys = catalogTable.storage.properties.getOrElse("primaryKey", catalogTable.properties.getOrElse("primaryKey", "keyid")).split(",").map(_.trim)
val preCombineKey = Seq(catalogTable.storage.properties.getOrElse("preCombineField", catalogTable.properties.getOrElse("preCombineField", "ts"))).map(_.trim)
val partitionKey = catalogTable.partitionColumnNames.map(_.trim)
val checkNames = primaryKeys ++ preCombineKey ++ partitionKey
colNames.foreach { col =>
if (checkNames.contains(col)) {
throw new UnsupportedOperationException("cannot support apply changes for primaryKey/CombineKey/partitionKey")
}
}
}
}
object AlterTableCommand extends Logging {

  // Suffix Spark appends to staging/placeholder table locations; must be stripped
  // before using the path as the Hudi base path.
  private val PLACEHOLDER_SUFFIX = "-__PLACEHOLDER__"

  /**
   * Generate an commit with new schema to change the table's schema.
   *
   * @param internalSchema new schema after change
   * @param historySchemaStr history schemas
   * @param table The hoodie table.
   * @param sparkSession The spark session.
   */
  def commitWithSchema(internalSchema: InternalSchema, historySchemaStr: String, table: CatalogTable, sparkSession: SparkSession): Unit = {
    val schema = AvroInternalSchemaConverter.convert(internalSchema, table.identifier.table)
    val path = getTableLocation(table, sparkSession)
    val jsc = new JavaSparkContext(sparkSession.sparkContext)
    val client = DataSourceUtils.createHoodieClient(jsc, schema.toString,
      path, table.identifier.table, parametersWithWriteDefaults(table.storage.properties).asJava)
    val hadoopConf = sparkSession.sessionState.newHadoopConf()
    val metaClient = HoodieTableMetaClient.builder().setBasePath(path).setConf(hadoopConf).build()
    val commitActionType = CommitUtils.getCommitActionType(WriteOperationType.ALTER_SCHEMA, metaClient.getTableType)
    val instantTime = HoodieActiveTimeline.createNewInstantTime
    client.startCommitWithTime(instantTime, commitActionType)
    val hoodieTable = HoodieSparkTable.create(client.getConfig, client.getEngineContext)
    val timeLine = hoodieTable.getActiveTimeline
    val requested = new HoodieInstant(State.REQUESTED, commitActionType, instantTime)
    val metadata = new HoodieCommitMetadata
    metadata.setOperationType(WriteOperationType.ALTER_SCHEMA)
    timeLine.transitionRequestedToInflight(requested, Option.of(metadata.toJsonString.getBytes(StandardCharsets.UTF_8)))
    // Record the new schema (keyed by this instant) and persist the extended history chain.
    val extraMeta = new util.HashMap[String, String]()
    extraMeta.put(SerDeHelper.LATEST_SCHEMA, SerDeHelper.toJson(internalSchema.setSchemaId(instantTime.toLong)))
    val schemaManager = new FileBasedInternalSchemaStorageManager(metaClient)
    schemaManager.persistHistorySchemaStr(instantTime, SerDeHelper.inheritSchemas(internalSchema, historySchemaStr))
    client.commit(instantTime, jsc.emptyRDD, Option.of(extraMeta))
    val existRoTable = sparkSession.catalog.tableExists(table.identifier.unquotedString + "_ro")
    val existRtTable = sparkSession.catalog.tableExists(table.identifier.unquotedString + "_rt")
    try {
      sparkSession.catalog.refreshTable(table.identifier.unquotedString)
      // try to refresh ro/rt table
      if (existRoTable) sparkSession.catalog.refreshTable(table.identifier.unquotedString + "_ro")
      // BUGFIX: the _rt refresh was previously (incorrectly) guarded by existRoTable.
      if (existRtTable) sparkSession.catalog.refreshTable(table.identifier.unquotedString + "_rt")
    } catch {
      case NonFatal(e) =>
        log.error(s"Exception when attempting to refresh table ${table.identifier.quotedString}", e)
    }
    // try to sync to hive
    // drop partition field before call alter table
    val fullSparkSchema = SparkInternalSchemaConverter.constructSparkSchemaFromInternalSchema(internalSchema)
    val dataSparkSchema = new StructType(fullSparkSchema.fields.filter(p => !table.partitionColumnNames.exists(f => sparkSession.sessionState.conf.resolver(f, p.name))))
    alterTableDataSchema(sparkSession, table.identifier.database.getOrElse("default"), table.identifier.table, dataSparkSchema)
    if (existRoTable) alterTableDataSchema(sparkSession, table.identifier.database.getOrElse("default"), table.identifier.table + "_ro", dataSparkSchema)
    if (existRtTable) alterTableDataSchema(sparkSession, table.identifier.database.getOrElse("default"), table.identifier.table + "_rt", dataSparkSchema)
  }

  /** Pushes the (partition-stripped) data schema to the external catalog (e.g. Hive). */
  def alterTableDataSchema(sparkSession: SparkSession, db: String, tableName: String, dataSparkSchema: StructType): Unit = {
    sparkSession.sessionState.catalog
      .externalCatalog
      .alterTableDataSchema(db, tableName, dataSparkSchema)
  }

  /**
   * Resolves the table's base path: default warehouse path for managed tables,
   * otherwise the declared storage location.
   *
   * @throws IllegalArgumentException if no location can be determined
   */
  def getTableLocation(table: CatalogTable, sparkSession: SparkSession): String = {
    val uri = if (table.tableType == CatalogTableType.MANAGED) {
      Some(sparkSession.sessionState.catalog.defaultTablePath(table.identifier))
    } else {
      table.storage.locationUri
    }
    val conf = sparkSession.sessionState.newHadoopConf()
    uri.map(makePathQualified(_, conf))
      .map(removePlaceHolder)
      .getOrElse(throw new IllegalArgumentException(s"Missing location for ${table.identifier}"))
  }

  /** Strips Spark's placeholder suffix from a qualified path, if present. */
  private def removePlaceHolder(path: String): String = {
    if (path == null || path.length == 0) {
      path
    } else if (path.endsWith(PLACEHOLDER_SUFFIX)) {
      // BUGFIX: strip exactly the suffix length instead of a hard-coded magic
      // number that disagreed with the matched suffix.
      path.substring(0, path.length - PLACEHOLDER_SUFFIX.length)
    } else {
      path
    }
  }

  /** Qualifies a URI against the filesystem it belongs to (adds scheme/authority). */
  def makePathQualified(path: URI, hadoopConf: Configuration): String = {
    val hadoopPath = new Path(path)
    val fs = hadoopPath.getFileSystem(hadoopConf)
    fs.makeQualified(hadoopPath).toUri.toString
  }

  /** Returns the dotted parent path of a nested field name, or "" for top-level fields. */
  def getParentName(names: Array[String]): String = {
    if (names.size > 1) {
      names.dropRight(1).mkString(".")
    } else ""
  }

  /**
   * Builds the write config for the schema-evolution commit: sensible Hudi
   * defaults overlaid with the table's own (translated) storage properties.
   */
  def parametersWithWriteDefaults(parameters: Map[String, String]): Map[String, String] = {
    Map(OPERATION.key -> OPERATION.defaultValue,
      TABLE_TYPE.key -> TABLE_TYPE.defaultValue,
      PRECOMBINE_FIELD.key -> PRECOMBINE_FIELD.defaultValue,
      HoodieWriteConfig.WRITE_PAYLOAD_CLASS_NAME.key -> HoodieWriteConfig.DEFAULT_WRITE_PAYLOAD_CLASS,
      INSERT_DROP_DUPS.key -> INSERT_DROP_DUPS.defaultValue,
      ASYNC_COMPACT_ENABLE.key -> ASYNC_COMPACT_ENABLE.defaultValue,
      INLINE_CLUSTERING_ENABLE.key -> INLINE_CLUSTERING_ENABLE.defaultValue,
      ASYNC_CLUSTERING_ENABLE.key -> ASYNC_CLUSTERING_ENABLE.defaultValue
    ) ++ DataSourceOptionsHelper.translateConfigurations(parameters)
  }
}

View File

@@ -0,0 +1,199 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.parser
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
import org.antlr.v4.runtime.tree.TerminalNodeImpl
import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext}
import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.internal.VariableSubstitution
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, SparkSession}
import java.util.Locale
/**
 * SQL parser extension for Spark 3.3 that intercepts Hudi-specific statements
 * (time-travel syntax) and routes everything else to the `delegate` parser.
 */
class HoodieSpark3_3ExtendedSqlParser(session: SparkSession, delegate: ParserInterface)
  extends ParserInterface with Logging {

  private lazy val conf = session.sqlContext.conf
  private lazy val builder = new HoodieSpark3_3ExtendedSqlAstBuilder(conf, delegate)
  private val substitutor = new VariableSubstitution

  override def parsePlan(sqlText: String): LogicalPlan = {
    // Apply variable substitution first so the Hudi-command sniffing below sees
    // the fully-resolved statement text.
    val substitutionSql = substitutor.substitute(sqlText)
    if (isHoodieCommand(substitutionSql)) {
      parse(substitutionSql) { parser =>
        builder.visit(parser.singleStatement()) match {
          case plan: LogicalPlan => plan
          // Builder could not produce a plan: fall back to the delegate.
          // NOTE(review): the fallback passes the un-substituted sqlText while the
          // else-branch passes substitutionSql — confirm this asymmetry is intended.
          case _ => delegate.parsePlan(sqlText)
        }
      }
    } else {
      delegate.parsePlan(substitutionSql)
    }
  }

  // SPARK-37266 Added parseQuery to ParserInterface in Spark 3.3.0
  // Don't mark this as override for backward compatibility
  def parseQuery(sqlText: String): LogicalPlan = delegate.parseQuery(sqlText)

  override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)

  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)

  override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)

  override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)

  /**
   * Parses `command` with the Hudi ANTLR grammar. Mirrors Spark's own
   * AbstractSqlParser.parse: try the faster SLL prediction mode first and fall
   * back to full LL mode when SLL parsing is ambiguous/cancelled.
   */
  protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = {
    logDebug(s"Parsing command: $command")
    // The lexer sees an upper-cased view of the input; getText still returns original casing.
    val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
    lexer.removeErrorListeners()
    lexer.addErrorListener(ParseErrorListener)
    val tokenStream = new CommonTokenStream(lexer)
    val parser = new HoodieSqlBaseParser(tokenStream)
    parser.addParseListener(PostProcessor)
    parser.removeErrorListeners()
    parser.addErrorListener(ParseErrorListener)
    // parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced
    parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
    parser.SQL_standard_keyword_behavior = conf.ansiEnabled
    try {
      try {
        // first, try parsing with potentially faster SLL mode
        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
        toResult(parser)
      }
      catch {
        case e: ParseCancellationException =>
          // if we fail, parse with LL mode
          tokenStream.seek(0) // rewind input stream
          parser.reset()
          // Try Again.
          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
          toResult(parser)
      }
    }
    catch {
      case e: ParseException if e.command.isDefined =>
        throw e
      case e: ParseException =>
        // Attach the offending command text for a better error message.
        throw e.withCommand(command)
      case e: AnalysisException =>
        val position = Origin(e.line, e.startPosition)
        throw new ParseException(Option(command), e.message, position, position)
    }
  }

  override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
    delegate.parseMultipartIdentifier(sqlText)
  }

  /**
   * Cheap textual sniff for Hudi time-travel syntax (SYSTEM_TIME / TIMESTAMP /
   * SYSTEM_VERSION / VERSION ... AS OF) after whitespace normalization.
   */
  private def isHoodieCommand(sqlText: String): Boolean = {
    val normalized = sqlText.toLowerCase(Locale.ROOT).trim().replaceAll("\\s+", " ")
    normalized.contains("system_time as of") ||
      normalized.contains("timestamp as of") ||
      normalized.contains("system_version as of") ||
      normalized.contains("version as of")
  }
}
/**
* Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
*/
/**
 * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
 *
 * A [[CharStream]] decorator that upper-cases each character as the lexer
 * looks ahead ([[LA]]) while leaving the wrapped stream — and therefore
 * [[getText]] — with the original casing.
 */
class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
  override def consume(): Unit = wrapped.consume()
  override def getSourceName(): String = wrapped.getSourceName
  override def index(): Int = wrapped.index
  override def mark(): Int = wrapped.mark
  override def release(marker: Int): Unit = wrapped.release(marker)
  override def seek(where: Int): Unit = wrapped.seek(where)
  override def size(): Int = wrapped.size

  override def getText(interval: Interval): String = {
    // ANTLR 4.7's CodePointCharStream implementations have bugs when
    // getText() is called with an empty stream, or intervals where
    // the start > end. See
    // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
    // that is not yet in a released ANTLR artifact.
    val intervalIsSafe = size() > 0 && interval.b - interval.a >= 0
    if (intervalIsSafe) wrapped.getText(interval) else ""
  }

  // scalastyle:off
  override def LA(i: Int): Int = {
  // scalastyle:on
    wrapped.LA(i) match {
      // 0 and EOF are sentinel values, not characters — pass them through untouched.
      case sentinel if sentinel == 0 || sentinel == IntStream.EOF => sentinel
      case ch => Character.toUpperCase(ch)
    }
  }
}
/**
* Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
*/
/**
 * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
 *
 * Parse-tree listener that rewrites certain terminals into IDENTIFIER tokens
 * as the parser emits them.
 */
case object PostProcessor extends HoodieSqlBaseBaseListener {

  /** Remove the back ticks from an Identifier. */
  override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
    replaceTokenByIdentifier(ctx, 1) { token =>
      // Collapse the escaped double back ticks inside the quoted text.
      token.setText(token.getText.replace("``", "`"))
      token
    }
  }

  /** Treat non-reserved keywords as Identifiers. */
  override def exitNonReserved(ctx: NonReservedContext): Unit = {
    replaceTokenByIdentifier(ctx, 0)(identity)
  }

  /**
   * Swaps the last child of `ctx`'s parent for an IDENTIFIER terminal built from
   * the original token, trimming `stripMargins` characters off each end (used to
   * drop surrounding back ticks) and applying `f` to the new token.
   */
  private def replaceTokenByIdentifier(
      ctx: ParserRuleContext,
      stripMargins: Int)(
      f: CommonToken => CommonToken = identity): Unit = {
    val parentNode = ctx.getParent
    parentNode.removeLastChild()
    val sourceToken = ctx.getChild(0).getPayload.asInstanceOf[Token]
    val identifierToken = new CommonToken(
      new org.antlr.v4.runtime.misc.Pair(sourceToken.getTokenSource, sourceToken.getInputStream),
      HoodieSqlBaseParser.IDENTIFIER,
      sourceToken.getChannel,
      sourceToken.getStartIndex + stripMargins,
      sourceToken.getStopIndex - stripMargins)
    parentNode.addChild(new TerminalNodeImpl(f(identifierToken)))
  }
}

View File

@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi.spark3.internal;
import org.apache.hudi.common.testutils.HoodieTestDataGenerator;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.config.HoodieWriteConfig;
import org.apache.hudi.internal.HoodieBulkInsertInternalWriterTestBase;
import org.apache.hudi.table.HoodieSparkTable;
import org.apache.hudi.table.HoodieTable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.ENCODER;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.STRUCT_TYPE;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.getInternalRowWithError;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.getRandomRows;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.toInternalRows;
import static org.junit.jupiter.api.Assertions.fail;
/**
* Unit tests {@link HoodieBulkInsertDataInternalWriter}.
*/
public class TestHoodieBulkInsertDataInternalWriter extends
    HoodieBulkInsertInternalWriterTestBase {

  // Cartesian product of (sorted, populateMetaFields) for parameterized tests.
  private static Stream<Arguments> configParams() {
    Object[][] data = new Object[][] {
        {true, true},
        {true, false},
        {false, true},
        {false, false}
    };
    return Stream.of(data).map(Arguments::of);
  }

  // Single boolean param: populateMetaFields on/off.
  // NOTE(review): not referenced by any @MethodSource in this class — confirm
  // whether it is dead code or used by a subclass.
  private static Stream<Arguments> bulkInsertTypeParams() {
    Object[][] data = new Object[][] {
        {true},
        {false}
    };
    return Stream.of(data).map(Arguments::of);
  }

  @ParameterizedTest
  @MethodSource("configParams")
  public void testDataInternalWriter(boolean sorted, boolean populateMetaFields) throws Exception {
    // init config and table
    HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
    // execute N rounds
    for (int i = 0; i < 2; i++) {
      String instantTime = "00" + i;
      // init writer
      HoodieBulkInsertDataInternalWriter writer = new HoodieBulkInsertDataInternalWriter(table, cfg, instantTime, RANDOM.nextInt(100000),
          RANDOM.nextLong(), STRUCT_TYPE, populateMetaFields, sorted);

      int size = 10 + RANDOM.nextInt(1000);
      // write N rows to partition1, N rows to partition2 and N rows to partition3 ... Each batch should create a new RowCreateHandle and a new file
      int batches = 3;
      Dataset<Row> totalInputRows = null;

      for (int j = 0; j < batches; j++) {
        String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
        Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
        writeRows(inputRows, writer);
        if (totalInputRows == null) {
          totalInputRows = inputRows;
        } else {
          totalInputRows = totalInputRows.union(inputRows);
        }
      }

      HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
      Option<List<String>> fileAbsPaths = Option.of(new ArrayList<>());
      Option<List<String>> fileNames = Option.of(new ArrayList<>());
      // verify write statuses
      assertWriteStatuses(commitMetadata.getWriteStatuses(), batches, size, sorted, fileAbsPaths, fileNames, false);

      // verify rows by reading back the files the writer reported.
      Dataset<Row> result = sqlContext.read().parquet(fileAbsPaths.get().toArray(new String[0]));
      assertOutput(totalInputRows, result, instantTime, fileNames, populateMetaFields);
    }
  }

  /**
   * Issue some corrupted or wrong schematized InternalRow after few valid InternalRows so that global error is thrown. write batch 1 of valid records write batch2 of invalid records which is expected
   * to throw Global Error. Verify global error is set appropriately and only first batch of records are written to disk.
   */
  @Test
  public void testGlobalFailure() throws Exception {
    // init config and table
    HoodieWriteConfig cfg = getWriteConfig(true);
    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
    String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[0];

    String instantTime = "001";
    HoodieBulkInsertDataInternalWriter writer = new HoodieBulkInsertDataInternalWriter(table, cfg, instantTime, RANDOM.nextInt(100000),
        RANDOM.nextLong(), STRUCT_TYPE, true, false);

    int size = 10 + RANDOM.nextInt(100);
    int totalFailures = 5;
    // Generate first batch of valid rows
    Dataset<Row> inputRows = getRandomRows(sqlContext, size / 2, partitionPath, false);
    List<InternalRow> internalRows = toInternalRows(inputRows, ENCODER);
    // generate some failures rows
    for (int i = 0; i < totalFailures; i++) {
      internalRows.add(getInternalRowWithError(partitionPath));
    }
    // generate 2nd batch of valid rows
    Dataset<Row> inputRows2 = getRandomRows(sqlContext, size / 2, partitionPath, false);
    internalRows.addAll(toInternalRows(inputRows2, ENCODER));

    // issue writes; the corrupted rows must abort the whole write.
    try {
      for (InternalRow internalRow : internalRows) {
        writer.write(internalRow);
      }
      fail("Should have failed");
    } catch (Throwable e) {
      // expected: the bad rows trigger a global failure.
    }

    HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();

    Option<List<String>> fileAbsPaths = Option.of(new ArrayList<>());
    Option<List<String>> fileNames = Option.of(new ArrayList<>());
    // verify write statuses: only the first (valid) half should be on disk.
    assertWriteStatuses(commitMetadata.getWriteStatuses(), 1, size / 2, fileAbsPaths, fileNames);

    // verify rows
    Dataset<Row> result = sqlContext.read().parquet(fileAbsPaths.get().toArray(new String[0]));
    assertOutput(inputRows, result, instantTime, fileNames, true);
  }

  // Converts the Dataset to InternalRows and feeds them to the writer one by one.
  private void writeRows(Dataset<Row> inputRows, HoodieBulkInsertDataInternalWriter writer)
      throws Exception {
    List<InternalRow> internalRows = toInternalRows(inputRows, ENCODER);
    // issue writes
    for (InternalRow internalRow : internalRows) {
      writer.write(internalRow);
    }
  }
}

View File

@@ -0,0 +1,330 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi.spark3.internal;
import org.apache.hudi.DataSourceWriteOptions;
import org.apache.hudi.common.model.HoodieCommitMetadata;
import org.apache.hudi.common.testutils.HoodieTestDataGenerator;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.config.HoodieWriteConfig;
import org.apache.hudi.internal.HoodieBulkInsertInternalWriterTestBase;
import org.apache.hudi.table.HoodieSparkTable;
import org.apache.hudi.table.HoodieTable;
import org.apache.hudi.testutils.HoodieClientTestUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.ENCODER;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.STRUCT_TYPE;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.getRandomRows;
import static org.apache.hudi.testutils.SparkDatasetTestUtils.toInternalRows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* Unit tests {@link HoodieDataSourceInternalBatchWrite}.
*/
public class TestHoodieDataSourceInternalBatchWrite extends
HoodieBulkInsertInternalWriterTestBase {
private static Stream<Arguments> bulkInsertTypeParams() {
Object[][] data = new Object[][] {
{true},
{false}
};
return Stream.of(data).map(Arguments::of);
}
@ParameterizedTest
@MethodSource("bulkInsertTypeParams")
public void testDataSourceWriter(boolean populateMetaFields) throws Exception {
testDataSourceWriterInternal(Collections.EMPTY_MAP, Collections.EMPTY_MAP, populateMetaFields);
}
  // Core scenario: write `batches` batches across the default partitions through the
  // DSv2 batch writer, commit, then verify rows, write statuses, and that the commit's
  // extra metadata (minus the schema key) equals `expectedExtraMetadata`.
  private void testDataSourceWriterInternal(Map<String, String> extraMetadata, Map<String, String> expectedExtraMetadata, boolean populateMetaFields) throws Exception {
    // init config and table
    HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
    String instantTime = "001";
    // init writer
    HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite =
        new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, extraMetadata, populateMetaFields, false);
    DataWriter<InternalRow> writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(0, RANDOM.nextLong());

    String[] partitionPaths = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS;
    List<String> partitionPathsAbs = new ArrayList<>();
    for (String partitionPath : partitionPaths) {
      partitionPathsAbs.add(basePath + "/" + partitionPath + "/*");
    }

    int size = 10 + RANDOM.nextInt(1000);
    int batches = 5;
    Dataset<Row> totalInputRows = null;

    for (int j = 0; j < batches; j++) {
      String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
      Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
      writeRows(inputRows, writer);
      if (totalInputRows == null) {
        totalInputRows = inputRows;
      } else {
        totalInputRows = totalInputRows.union(inputRows);
      }
    }

    HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
    List<HoodieWriterCommitMessage> commitMessages = new ArrayList<>();
    commitMessages.add(commitMetadata);
    dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
    metaClient.reloadActiveTimeline();

    Dataset<Row> result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0]));
    // verify output
    assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields);
    assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty());

    // verify extra metadata, ignoring the schema entry Hudi adds itself.
    Option<HoodieCommitMetadata> commitMetadataOption = HoodieClientTestUtils.getCommitMetadataForLatestInstant(metaClient);
    assertTrue(commitMetadataOption.isPresent());
    Map<String, String> actualExtraMetadata = new HashMap<>();
    commitMetadataOption.get().getExtraMetadata().entrySet().stream().filter(entry ->
        !entry.getKey().equals(HoodieCommitMetadata.SCHEMA_KEY)).forEach(entry -> actualExtraMetadata.put(entry.getKey(), entry.getValue()));
    assertEquals(actualExtraMetadata, expectedExtraMetadata);
  }
@Test
public void testDataSourceWriterExtraCommitMetadata() throws Exception {
String commitExtraMetaPrefix = "commit_extra_meta_";
Map<String, String> extraMeta = new HashMap<>();
extraMeta.put(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key(), commitExtraMetaPrefix);
extraMeta.put(commitExtraMetaPrefix + "a", "valA");
extraMeta.put(commitExtraMetaPrefix + "b", "valB");
extraMeta.put("commit_extra_c", "valC"); // should not be part of commit extra metadata
Map<String, String> expectedMetadata = new HashMap<>();
expectedMetadata.putAll(extraMeta);
expectedMetadata.remove(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key());
expectedMetadata.remove("commit_extra_c");
testDataSourceWriterInternal(extraMeta, expectedMetadata, true);
}
@Test
public void testDataSourceWriterEmptyExtraCommitMetadata() throws Exception {
  // A metadata key prefix is configured, but none of the supplied keys start with it,
  // so the resulting commit extra metadata should be empty.
  String commitExtraMetaPrefix = "commit_extra_meta_";
  Map<String, String> extraMeta = new HashMap<>();
  extraMeta.put(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key(), commitExtraMetaPrefix);
  extraMeta.put("keyA", "valA");
  extraMeta.put("keyB", "valB");
  extraMeta.put("commit_extra_c", "valC");
  // Collections.emptyMap() replaces the raw Collections.EMPTY_MAP field, avoiding an
  // unchecked-assignment warning while producing the same immutable empty map.
  testDataSourceWriterInternal(extraMeta, Collections.emptyMap(), true);
}
@ParameterizedTest
@MethodSource("bulkInsertTypeParams")
public void testMultipleDataSourceWrites(boolean populateMetaFields) throws Exception {
  // init config and table
  HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
  // NOTE(review): 'table' is never read afterwards; kept because create() may initialize
  // on-disk table state as a side effect — TODO confirm and drop if not.
  HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
  int partitionCounter = 0;
  // Issue two sequential commits ("000", "001") and verify each commit's output independently.
  for (int i = 0; i < 2; i++) {
    String instantTime = "00" + i;
    // init writer for this commit; emptyMap() replaces the raw Collections.EMPTY_MAP field
    HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite =
        new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf,
            Collections.emptyMap(), populateMetaFields, false);
    List<HoodieWriterCommitMessage> commitMessages = new ArrayList<>();
    Dataset<Row> totalInputRows = null;
    DataWriter<InternalRow> writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null)
        .createWriter(partitionCounter++, RANDOM.nextLong());
    int size = 10 + RANDOM.nextInt(1000);
    int batches = 3; // one batch per partition
    for (int j = 0; j < batches; j++) {
      String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
      Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
      writeRows(inputRows, writer);
      // accumulate all written rows for the final output assertion
      totalInputRows = (totalInputRows == null) ? inputRows : totalInputRows.union(inputRows);
    }
    HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
    commitMessages.add(commitMetadata);
    dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
    metaClient.reloadActiveTimeline();
    Dataset<Row> result = HoodieClientTestUtils.readCommit(basePath, sqlContext, metaClient.getCommitTimeline(),
        instantTime, populateMetaFields);
    // verify output rows and write statuses for this commit only
    assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields);
    assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty());
  }
}
// Large writes are not required to be executed w/ regular CI jobs. Takes lot of running time.
@Disabled
@ParameterizedTest
@MethodSource("bulkInsertTypeParams")
public void testLargeWrites(boolean populateMetaFields) throws Exception {
  // init config and table
  HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
  // NOTE(review): 'table' is never read afterwards; kept because create() may initialize
  // on-disk table state as a side effect — TODO confirm and drop if not.
  HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
  int partitionCounter = 0;
  // Issue three sequential large commits ("000".."002") and verify each one independently.
  for (int i = 0; i < 3; i++) {
    String instantTime = "00" + i;
    // init writer for this commit; emptyMap() replaces the raw Collections.EMPTY_MAP field
    HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite =
        new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf,
            Collections.emptyMap(), populateMetaFields, false);
    List<HoodieWriterCommitMessage> commitMessages = new ArrayList<>();
    Dataset<Row> totalInputRows = null;
    DataWriter<InternalRow> writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null)
        .createWriter(partitionCounter++, RANDOM.nextLong());
    int size = 10000 + RANDOM.nextInt(10000);
    int batches = 3; // one batch per partition
    for (int j = 0; j < batches; j++) {
      String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
      Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
      writeRows(inputRows, writer);
      // accumulate all written rows for the final output assertion
      totalInputRows = (totalInputRows == null) ? inputRows : totalInputRows.union(inputRows);
    }
    HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
    commitMessages.add(commitMetadata);
    dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
    metaClient.reloadActiveTimeline();
    Dataset<Row> result = HoodieClientTestUtils.readCommit(basePath, sqlContext, metaClient.getCommitTimeline(),
        instantTime, populateMetaFields);
    // verify output rows and write statuses for this commit only
    assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields);
    assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty());
  }
}
/**
 * Tests that {@code DataSourceWriter.abort()} discards the records of the aborted batch:
 * <ol>
 *   <li>write and commit batch 1</li>
 *   <li>write and abort batch 2</li>
 *   <li>verify a read of the entire dataset returns only records from batch 1</li>
 * </ol>
 */
@ParameterizedTest
@MethodSource("bulkInsertTypeParams")
public void testAbort(boolean populateMetaFields) throws Exception {
  // init config and table
  HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
  // NOTE(review): 'table' is never read afterwards; kept because create() may initialize
  // on-disk table state as a side effect — TODO confirm and drop if not.
  HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
  String instantTime0 = "00" + 0;
  // init writer for the first (committed) batch; emptyMap() replaces the raw Collections.EMPTY_MAP field
  HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite =
      new HoodieDataSourceInternalBatchWrite(instantTime0, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf,
          Collections.emptyMap(), populateMetaFields, false);
  DataWriter<InternalRow> writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(0, RANDOM.nextLong());
  List<String> partitionPaths = Arrays.asList(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS);
  // absolute glob paths used to read back the whole dataset across all partitions
  List<String> partitionPathsAbs = new ArrayList<>();
  for (String partitionPath : partitionPaths) {
    partitionPathsAbs.add(basePath + "/" + partitionPath + "/*");
  }
  int size = 10 + RANDOM.nextInt(100);
  int batches = 1;
  Dataset<Row> totalInputRows = null;
  for (int j = 0; j < batches; j++) {
    String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
    Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
    writeRows(inputRows, writer);
    totalInputRows = (totalInputRows == null) ? inputRows : totalInputRows.union(inputRows);
  }
  HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
  List<HoodieWriterCommitMessage> commitMessages = new ArrayList<>();
  commitMessages.add(commitMetadata);
  // commit 1st batch
  dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
  metaClient.reloadActiveTimeline();
  Dataset<Row> result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0]));
  // verify rows from the committed batch are readable
  assertOutput(totalInputRows, result, instantTime0, Option.empty(), populateMetaFields);
  assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty());
  // 2nd batch: write the same number of rows, then abort instead of committing
  String instantTime1 = "00" + 1;
  dataSourceInternalBatchWrite =
      new HoodieDataSourceInternalBatchWrite(instantTime1, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf,
          Collections.emptyMap(), populateMetaFields, false);
  writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(1, RANDOM.nextLong());
  for (int j = 0; j < batches; j++) {
    String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
    Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
    writeRows(inputRows, writer);
  }
  commitMetadata = (HoodieWriterCommitMessage) writer.commit();
  commitMessages = new ArrayList<>();
  commitMessages.add(commitMetadata);
  // abort 2nd batch (original comment wrongly said "commit 1st batch")
  dataSourceInternalBatchWrite.abort(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
  metaClient.reloadActiveTimeline();
  result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0]));
  // only rows from the first (committed) batch should be present after the abort
  assertOutput(totalInputRows, result, instantTime0, Option.empty(), populateMetaFields);
}
/**
 * Converts the given rows to their internal representation and issues a write
 * for each record through the supplied {@link DataWriter}.
 */
private void writeRows(Dataset<Row> inputRows, DataWriter<InternalRow> writer) throws Exception {
  for (InternalRow record : toInternalRows(inputRows, ENCODER)) {
    writer.write(record);
  }
}
}

View File

@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi.spark3.internal;
import org.apache.hudi.testutils.HoodieClientTestBase;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoStatement;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
* Unit tests {@link ReflectUtil}.
*/
public class TestReflectUtil extends HoodieClientTestBase {
@Test
public void testDataSourceWriterExtraCommitMetadata() throws Exception {
SparkSession spark = sqlContext.sparkSession();
String insertIntoSql = "insert into test_reflect_util values (1, 'z3', 1, '2021')";
InsertIntoStatement statement = (InsertIntoStatement) spark.sessionState().sqlParser().parsePlan(insertIntoSql);
InsertIntoStatement newStatment = ReflectUtil.createInsertInto(
statement.table(),
statement.partitionSpec(),
scala.collection.immutable.List.empty(),
statement.query(),
statement.overwrite(),
statement.ifPartitionNotExists());
Assertions.assertTrue(
((UnresolvedRelation)newStatment.table()).multipartIdentifier().contains("test_reflect_util"));
}
}

View File

@@ -0,0 +1,30 @@
###
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###
log4j.rootLogger=WARN, CONSOLE
log4j.logger.org.apache.hudi=DEBUG
log4j.logger.org.apache.hadoop.hbase=ERROR
# CONSOLE is set to be a ConsoleAppender.
log4j.appender.CONSOLE=org.apache.log4j.ConsoleAppender
# CONSOLE uses PatternLayout.
log4j.appender.CONSOLE.layout=org.apache.log4j.PatternLayout
log4j.appender.CONSOLE.layout.ConversionPattern=[%-5p] %d %c %x - %m%n
log4j.appender.CONSOLE.filter.a=org.apache.log4j.varia.LevelRangeFilter
log4j.appender.CONSOLE.filter.a.AcceptOnMatch=true
log4j.appender.CONSOLE.filter.a.LevelMin=WARN
log4j.appender.CONSOLE.filter.a.LevelMax=FATAL

View File

@@ -0,0 +1,31 @@
###
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###
log4j.rootLogger=WARN, CONSOLE
log4j.logger.org.apache=INFO
log4j.logger.org.apache.hudi=DEBUG
log4j.logger.org.apache.hadoop.hbase=ERROR
# CONSOLE is set to be a ConsoleAppender.
log4j.appender.CONSOLE=org.apache.log4j.ConsoleAppender
# CONSOLE uses PatternLayout.
log4j.appender.CONSOLE.layout=org.apache.log4j.PatternLayout
log4j.appender.CONSOLE.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
log4j.appender.CONSOLE.filter.a=org.apache.log4j.varia.LevelRangeFilter
log4j.appender.CONSOLE.filter.a.AcceptOnMatch=true
log4j.appender.CONSOLE.filter.a.LevelMin=WARN
log4j.appender.CONSOLE.filter.a.LevelMax=FATAL