1
0

[HUDI-3172] Refactor hudi existing modules to make more code reuse in V2 Implementation (#4514)

* Introduce hudi-spark3-common and hudi-spark2-common modules to place classes that can be reused across different Spark versions; also introduce hudi-spark3.1.x to support Spark 3.1.x.
* Introduce the hudi format under the hudi-spark2, hudi-spark3, and hudi-spark3.1.x modules, and change the hudi format in the original hudi-spark module to the hudi_v1 format.
* Manually tested on Spark 3.1.2 and Spark 3.2.0 SQL.
* Added a README.md file under hudi-spark-datasource module.
This commit is contained in:
leesf
2022-01-14 13:42:35 +08:00
committed by GitHub
parent 195dac90fa
commit 5ce45c440b
90 changed files with 1249 additions and 430 deletions

View File

@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
/**
 * Spark 3 implementation of [[SparkRowSerDe]] that converts between external
 * [[Row]]s and Catalyst [[InternalRow]]s using the supplied encoder.
 *
 * @param encoder the expression encoder used to derive both converters
 */
class Spark3RowSerDe(val encoder: ExpressionEncoder[Row]) extends SparkRowSerDe {
  // Encoder-derived converters: created once up front and reused for every row.
  private val deserializer: ExpressionEncoder.Deserializer[Row] = encoder.createDeserializer()
  private val serializer: ExpressionEncoder.Serializer[Row] = encoder.createSerializer()

  /** Converts a Catalyst [[InternalRow]] back into an external [[Row]]. */
  // Fix: `override` was missing here while present on serializeRow; both
  // methods implement the SparkRowSerDe trait, and the modifier lets the
  // compiler verify the trait contract.
  override def deserializeRow(internalRow: InternalRow): Row = {
    deserializer.apply(internalRow)
  }

  /** Converts an external [[Row]] into a Catalyst [[InternalRow]]. */
  override def serializeRow(row: Row): InternalRow = {
    serializer.apply(row)
  }
}

View File

@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.adapter
import org.apache.hudi.Spark3RowSerDe
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.hudi.spark3.internal.ReflectUtil
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, JoinHint, LogicalPlan}
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.{LogicalRelation, Spark3ParsePartitionUtil, SparkParsePartitionUtil}
import org.apache.spark.sql.hudi.SparkAdapter
import org.apache.spark.sql.internal.SQLConf
/**
 * The [[SparkAdapter]] implementation for Spark 3, bridging version-specific
 * Catalyst/analyzer APIs behind a version-neutral interface.
 */
class Spark3Adapter extends SparkAdapter {

  /** Builds the Spark 3 row serializer/deserializer around the given encoder. */
  override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe =
    new Spark3RowSerDe(encoder)

  /**
   * Converts an alias identifier to a TableIdentifier, keeping only the
   * innermost qualifier (the database) when one is present.
   */
  override def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier = aliasId match {
    case AliasIdentifier(name, Seq(database))    => TableIdentifier(name, Some(database))
    case AliasIdentifier(name, Seq(_, database)) => TableIdentifier(name, Some(database))
    case AliasIdentifier(name, Seq())            => TableIdentifier(name, None)
    case other => throw new IllegalArgumentException(s"Cannot cast $other to TableIdentifier")
  }

  /** Resolves an unresolved relation's multi-part name to a TableIdentifier. */
  override def toTableIdentifier(relation: UnresolvedRelation): TableIdentifier =
    relation.multipartIdentifier.asTableIdentifier

  /** Creates a Join node with no join condition and no join hint. */
  override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join =
    Join(left, right, joinType, None, JoinHint.NONE)

  /** In Spark 3, an INSERT INTO is represented by InsertIntoStatement. */
  override def isInsertInto(plan: LogicalPlan): Boolean = plan.isInstanceOf[InsertIntoStatement]

  /**
   * Destructures an InsertIntoStatement into (table, partitionSpec, query,
   * overwrite, ifPartitionNotExists); None for any other plan.
   */
  override def getInsertIntoChildren(plan: LogicalPlan):
    Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] =
    plan match {
      case stmt: InsertIntoStatement =>
        Some((stmt.table, stmt.partitionSpec, stmt.query, stmt.overwrite, stmt.ifPartitionNotExists))
      case _ => None
    }

  /**
   * Builds an InsertIntoStatement through ReflectUtil — presumably because the
   * constructor signature varies between Spark 3 minor versions; confirm
   * against ReflectUtil.
   */
  override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]],
      query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan =
    ReflectUtil.createInsertInto(table, partition, Seq.empty[String], query, overwrite, ifPartitionNotExists)

  /** Returns the Spark 3 partition-path parsing helper. */
  override def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil =
    new Spark3ParsePartitionUtil(conf)

  /** Builds a Catalyst LIKE expression over the two operands. */
  override def createLike(left: Expression, right: Expression): Expression =
    new Like(left, right)

  /** Parses the SQL text into a multi-part identifier using the given parser. */
  override def parseMultipartIdentifier(parser: ParserInterface, sqlText: String): Seq[String] =
    parser.parseMultipartIdentifier(sqlText)
}

View File

@@ -0,0 +1,274 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources
import java.lang.{Double => JDouble, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.time.ZoneId
import java.util.{Locale, TimeZone}
import org.apache.hadoop.fs.Path
import org.apache.hudi.common.util.PartitionPathEncodeUtils.DEFAULT_PARTITION_PATH
import org.apache.hudi.spark3.internal.ReflectUtil
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter}
import org.apache.spark.sql.execution.datasources.PartitioningUtils.timestampPartitionPattern
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
import scala.util.control.NonFatal
/**
 * Spark 3 implementation of [[SparkParsePartitionUtil]].
 *
 * Much of the logic below mirrors Spark 3.2's `PartitioningUtils`, copied here
 * so that the same code works on both Spark 3.1 and 3.2 (SPARK-34314 changed
 * the definition of `PartitionValues` in 3.2).
 */
class Spark3ParsePartitionUtil(conf: SQLConf) extends SparkParsePartitionUtil {

  /**
   * The definition of PartitionValues has been changed by SPARK-34314 in Spark3.2.
   * To solve the compatibility between 3.1 and 3.2, copy some codes from PartitioningUtils in Spark3.2 here.
   * And this method will generate and return `InternalRow` directly instead of `PartitionValues`.
   *
   * @param path                   partition directory path to parse (e.g. `/table/a=1/b=2`)
   * @param typeInference          whether to infer each column's data type from its raw value
   * @param basePaths              paths at which the upward walk stops (table roots)
   * @param userSpecifiedDataTypes user-provided column types, taking precedence over inference
   * @param timeZone               time zone used for date/timestamp parsing and casting
   * @return an InternalRow of the casted partition values, or `InternalRow.empty`
   *         when no partition columns were found
   */
  override def parsePartition(
      path: Path,
      typeInference: Boolean,
      basePaths: Set[Path],
      userSpecifiedDataTypes: Map[String, DataType],
      timeZone: TimeZone): InternalRow = {
    // DateFormatter construction is obtained via ReflectUtil — presumably
    // because its API differs across Spark 3 minor versions; confirm there.
    val dateFormatter = ReflectUtil.getDateFormatter(timeZone.toZoneId)
    val timestampFormatter = TimestampFormatter(timestampPartitionPattern,
      timeZone.toZoneId, isParsing = true)

    val (partitionValues, _) = parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes,
      conf.validatePartitionColumns, timeZone.toZoneId, dateFormatter, timestampFormatter)

    // Cast each raw string value to its resolved type. A failing cast either
    // aborts (when partition-column validation is enabled) or yields null.
    partitionValues.map {
      case PartitionValues(columnNames: Seq[String], typedValues: Seq[TypedPartValue]) =>
        val rowValues = columnNames.zip(typedValues).map { case (columnName, typedValue) =>
          try {
            castPartValueToDesiredType(typedValue.dataType, typedValue.value, timeZone.toZoneId)
          } catch {
            case NonFatal(_) =>
              if (conf.validatePartitionColumns) {
                throw new RuntimeException(s"Failed to cast value `${typedValue.value}` to " +
                  s"`${typedValue.dataType}` for partition column `$columnName`")
              } else null
          }
        }
        InternalRow.fromSeq(rowValues)
    }.getOrElse(InternalRow.empty)
  }

  // One parsed partition value: the raw (still path-escaped) string plus the
  // data type it resolved to.
  case class TypedPartValue(value: String, dataType: DataType)

  // Partition column names paired position-wise with their typed values.
  case class PartitionValues(columnNames: Seq[String], typedValues: Seq[TypedPartValue])
  {
    require(columnNames.size == typedValues.size)
  }

  /**
   * Walks upward from `path`, collecting `name=value` directory components
   * until a base path or the filesystem root is reached.
   *
   * @return the collected partition values plus the path where the walk
   *         stopped, or (None, None) when a `_temporary` directory is hit
   */
  private def parsePartition(
      path: Path,
      typeInference: Boolean,
      basePaths: Set[Path],
      userSpecifiedDataTypes: Map[String, DataType],
      validatePartitionColumns: Boolean,
      zoneId: ZoneId,
      dateFormatter: DateFormatter,
      timestampFormatter: TimestampFormatter): (Option[PartitionValues], Option[Path]) = {
    val columns = ArrayBuffer.empty[(String, TypedPartValue)]
    // Old Hadoop versions don't have `Path.isRoot`
    var finished = path.getParent == null
    // currentPath is the current path that we will use to parse partition column value.
    var currentPath: Path = path

    while (!finished) {
      // Sometimes (e.g., when speculative task is enabled), temporary directories may be left
      // uncleaned. Here we simply ignore them.
      if (currentPath.getName.toLowerCase(Locale.ROOT) == "_temporary") {
        // scalastyle:off return
        return (None, None)
        // scalastyle:on return
      }

      if (basePaths.contains(currentPath)) {
        // If the currentPath is one of base paths. We should stop.
        finished = true
      } else {
        // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
        // Once we get the string, we try to parse it and find the partition column and value.
        val maybeColumn =
          parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes,
            validatePartitionColumns, zoneId, dateFormatter, timestampFormatter)
        maybeColumn.foreach(columns += _)

        // Now, we determine if we should stop.
        // When we hit any of the following cases, we will stop:
        //  - In this iteration, we could not parse the value of partition column and value,
        //    i.e. maybeColumn is None, and columns is not empty. At here we check if columns is
        //    empty to handle cases like /table/a=1/_temporary/something (we need to find a=1 in
        //    this case).
        //  - After we get the new currentPath, this new currentPath represent the top level dir
        //    i.e. currentPath.getParent == null. For the example of "/table/a=1/",
        //    the top level dir is "/table".
        finished =
          (maybeColumn.isEmpty && !columns.isEmpty) || currentPath.getParent == null

        if (!finished) {
          // For the above example, currentPath will be "/table/".
          currentPath = currentPath.getParent
        }
      }
    }

    if (columns.isEmpty) {
      (None, Some(path))
    } else {
      // Directories were visited leaf-first, so reverse to restore the
      // top-down column order before packaging the result.
      val (columnNames, values) = columns.reverse.unzip
      (Some(PartitionValues(columnNames.toSeq, values.toSeq)), Some(currentPath))
    }
  }

  /**
   * Parses a single `name=value` directory component into a column name and
   * its typed raw value. Returns None when the component contains no '='.
   */
  private def parsePartitionColumn(
      columnSpec: String,
      typeInference: Boolean,
      userSpecifiedDataTypes: Map[String, DataType],
      validatePartitionColumns: Boolean,
      zoneId: ZoneId,
      dateFormatter: DateFormatter,
      timestampFormatter: TimestampFormatter): Option[(String, TypedPartValue)] = {
    val equalSignIndex = columnSpec.indexOf('=')
    if (equalSignIndex == -1) {
      None
    } else {
      val columnName = unescapePathName(columnSpec.take(equalSignIndex))
      assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'")

      val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
      assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")

      val dataType = if (userSpecifiedDataTypes.contains(columnName)) {
        // SPARK-26188: if user provides corresponding column schema, get the column value without
        // inference, and then cast it as user specified data type.
        userSpecifiedDataTypes(columnName)
      } else {
        inferPartitionColumnValue(
          rawColumnValue,
          typeInference,
          zoneId,
          dateFormatter,
          timestampFormatter)
      }
      Some(columnName -> TypedPartValue(rawColumnValue, dataType))
    }
  }

  /**
   * Infers the data type of `raw`, trying in order: Integer, Long, Decimal
   * (scale <= 0 only), Double, Timestamp, Date, and finally String (or
   * NullType for the default "null" partition marker). With inference
   * disabled, only NullType/StringType are returned.
   */
  private def inferPartitionColumnValue(
      raw: String,
      typeInference: Boolean,
      zoneId: ZoneId,
      dateFormatter: DateFormatter,
      timestampFormatter: TimestampFormatter): DataType = {
    val decimalTry = Try {
      // `BigDecimal` conversion can fail when the `field` is not a form of number.
      val bigDecimal = new JBigDecimal(raw)
      // It reduces the cases for decimals by disallowing values having scale (e.g. `1.1`).
      require(bigDecimal.scale <= 0)
      // `DecimalType` conversion can fail when
      //   1. The precision is bigger than 38.
      //   2. scale is bigger than precision.
      fromDecimal(Decimal(bigDecimal))
    }

    val dateTry = Try {
      // try and parse the date, if no exception occurs this is a candidate to be resolved as
      // DateType
      dateFormatter.parse(raw)
      // SPARK-23436: Casting the string to date may still return null if a bad Date is provided.
      // This can happen since DateFormat.parse may not use the entire text of the given string:
      // so if there are extra-characters after the date, it returns correctly.
      // We need to check that we can cast the raw string since we later can use Cast to get
      // the partition values with the right DataType (see
      // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning)
      val dateValue = Cast(Literal(raw), DateType, Some(zoneId.getId)).eval()
      // Disallow DateType if the cast returned null
      require(dateValue != null)
      DateType
    }

    val timestampTry = Try {
      val unescapedRaw = unescapePathName(raw)
      // the inferred data type is consistent with the default timestamp type
      val timestampType = TimestampType
      // try and parse the date, if no exception occurs this is a candidate to be resolved as TimestampType
      timestampFormatter.parse(unescapedRaw)
      // SPARK-23436: see comment for date
      val timestampValue = Cast(Literal(unescapedRaw), timestampType, Some(zoneId.getId)).eval()
      // Disallow TimestampType if the cast returned null
      require(timestampValue != null)
      timestampType
    }

    if (typeInference) {
      // First tries integral types
      Try({ Integer.parseInt(raw); IntegerType })
        .orElse(Try { JLong.parseLong(raw); LongType })
        .orElse(decimalTry)
        // Then falls back to fractional types
        .orElse(Try { JDouble.parseDouble(raw); DoubleType })
        // Then falls back to date/timestamp types
        .orElse(timestampTry)
        .orElse(dateTry)
        // Then falls back to string
        .getOrElse {
          if (raw == DEFAULT_PARTITION_PATH) NullType else StringType
        }
    } else {
      if (raw == DEFAULT_PARTITION_PATH) NullType else StringType
    }
  }

  /**
   * Casts the raw string `value` to `desiredType`. The default "null"
   * partition marker and NullType both map to null; unsupported types raise
   * IllegalArgumentException.
   */
  def castPartValueToDesiredType(
      desiredType: DataType,
      value: String,
      zoneId: ZoneId): Any = desiredType match {
    case _ if value == DEFAULT_PARTITION_PATH => null
    case NullType => null
    case StringType => UTF8String.fromString(unescapePathName(value))
    case IntegerType => Integer.parseInt(value)
    case LongType => JLong.parseLong(value)
    case DoubleType => JDouble.parseDouble(value)
    case _: DecimalType => Literal(new JBigDecimal(value)).value
    case DateType =>
      Cast(Literal(value), DateType, Some(zoneId.getId)).eval()
    // Timestamp types
    case dt: TimestampType =>
      // First try casting the unescaped string directly; on failure fall back
      // to casting through DateType (handles date-only strings).
      Try {
        Cast(Literal(unescapePathName(value)), dt, Some(zoneId.getId)).eval()
      }.getOrElse {
        Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), dt).eval()
      }
    case dt => throw new IllegalArgumentException(s"Unexpected type $dt")
  }

  // Smallest DecimalType that exactly represents the given Decimal.
  private def fromDecimal(d: Decimal): DecimalType = DecimalType(d.precision, d.scale)
}