diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/columnstats/ColumnStatsIndexHelper.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/columnstats/ColumnStatsIndexHelper.java
index b98893344..49e33f456 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/columnstats/ColumnStatsIndexHelper.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/columnstats/ColumnStatsIndexHelper.java
@@ -72,6 +72,7 @@ import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
+import static org.apache.hudi.common.util.ValidationUtils.checkState;
import static org.apache.hudi.util.DataTypeUtils.areCompatible;
public class ColumnStatsIndexHelper {
@@ -111,17 +112,17 @@ public class ColumnStatsIndexHelper {
* | another_base_file.parquet | -10 | 0 | 5 |
* +---------------------------+------------+------------+-------------+
*
- *
+ *
* NOTE: Currently {@link TimestampType} is not supported, since Parquet writer
* does not support statistics for it.
- *
+ *
* TODO leverage metadata table after RFC-27 lands
- * @VisibleForTesting
*
- * @param sparkSession encompassing Spark session
- * @param baseFilesPaths list of base-files paths to be sourced for column-stats index
+ * @param sparkSession encompassing Spark session
+ * @param baseFilesPaths list of base-files paths to be sourced for column-stats index
* @param orderedColumnSchemas target ordered columns
* @return Spark's {@link Dataset} holding an index table
+ * @VisibleForTesting
*/
@Nonnull
public static Dataset<Row> buildColumnStatsTableFor(
@@ -223,13 +224,13 @@ public class ColumnStatsIndexHelper {
* Cleans up any residual index tables, that weren't cleaned up before
*
*
- * @param sparkSession encompassing Spark session
+ * @param sparkSession encompassing Spark session
* @param sourceTableSchema instance of {@link StructType} bearing source table's writer's schema
- * @param sourceBaseFiles list of base-files to be indexed
- * @param orderedCols target ordered columns
- * @param indexFolderPath col-stats index folder path
- * @param commitTime current operation commit instant
- * @param completedCommits all previously completed commit instants
+ * @param sourceBaseFiles list of base-files to be indexed
+ * @param orderedCols target ordered columns
+ * @param indexFolderPath col-stats index folder path
+ * @param commitTime current operation commit instant
+ * @param completedCommits all previously completed commit instants
*/
public static void updateColumnStatsIndexFor(
@Nonnull SparkSession sparkSession,
@@ -424,57 +425,64 @@ public class ColumnStatsIndexHelper {
return String.format("%s_%s", col, statName);
}
- private static Pair<Object, Object>
- fetchMinMaxValues(
- @Nonnull DataType colType,
- @Nonnull HoodieColumnRangeMetadata<Comparable<?>> colMetadata) {
+ private static Pair<Object, Object> fetchMinMaxValues(@Nonnull DataType colType,
+ @Nonnull HoodieColumnRangeMetadata<Comparable<?>> colMetadata) {
+ Comparable<?> minValue = colMetadata.getMinValue();
+ Comparable<?> maxValue = colMetadata.getMaxValue();
+
+ checkState((minValue == null) == (maxValue == null), "Either both min/max values should be null or neither");
+
+ if (minValue == null || maxValue == null) {
+ return Pair.of(null, null);
+ }
+
if (colType instanceof IntegerType) {
return Pair.of(
- new Integer(colMetadata.getMinValue().toString()),
- new Integer(colMetadata.getMaxValue().toString())
+ new Integer(minValue.toString()),
+ new Integer(maxValue.toString())
);
} else if (colType instanceof DoubleType) {
return Pair.of(
- new Double(colMetadata.getMinValue().toString()),
- new Double(colMetadata.getMaxValue().toString())
+ new Double(minValue.toString()),
+ new Double(maxValue.toString())
);
} else if (colType instanceof StringType) {
return Pair.of(
- colMetadata.getMinValue().toString(),
- colMetadata.getMaxValue().toString());
+ minValue.toString(),
+ maxValue.toString());
} else if (colType instanceof DecimalType) {
return Pair.of(
- new BigDecimal(colMetadata.getMinValue().toString()),
- new BigDecimal(colMetadata.getMaxValue().toString()));
+ new BigDecimal(minValue.toString()),
+ new BigDecimal(maxValue.toString()));
} else if (colType instanceof DateType) {
return Pair.of(
- java.sql.Date.valueOf(colMetadata.getMinValue().toString()),
- java.sql.Date.valueOf(colMetadata.getMaxValue().toString()));
+ java.sql.Date.valueOf(minValue.toString()),
+ java.sql.Date.valueOf(maxValue.toString()));
} else if (colType instanceof LongType) {
return Pair.of(
- new Long(colMetadata.getMinValue().toString()),
- new Long(colMetadata.getMaxValue().toString()));
+ new Long(minValue.toString()),
+ new Long(maxValue.toString()));
} else if (colType instanceof ShortType) {
return Pair.of(
- new Short(colMetadata.getMinValue().toString()),
- new Short(colMetadata.getMaxValue().toString()));
+ new Short(minValue.toString()),
+ new Short(maxValue.toString()));
} else if (colType instanceof FloatType) {
return Pair.of(
- new Float(colMetadata.getMinValue().toString()),
- new Float(colMetadata.getMaxValue().toString()));
+ new Float(minValue.toString()),
+ new Float(maxValue.toString()));
} else if (colType instanceof BinaryType) {
return Pair.of(
- ((ByteBuffer) colMetadata.getMinValue()).array(),
- ((ByteBuffer) colMetadata.getMaxValue()).array());
+ ((ByteBuffer) minValue).array(),
+ ((ByteBuffer) maxValue).array());
} else if (colType instanceof BooleanType) {
return Pair.of(
- Boolean.valueOf(colMetadata.getMinValue().toString()),
- Boolean.valueOf(colMetadata.getMaxValue().toString()));
+ Boolean.valueOf(minValue.toString()),
+ Boolean.valueOf(maxValue.toString()));
} else if (colType instanceof ByteType) {
return Pair.of(
- Byte.valueOf(colMetadata.getMinValue().toString()),
- Byte.valueOf(colMetadata.getMaxValue().toString()));
- } else {
+ Byte.valueOf(minValue.toString()),
+ Byte.valueOf(maxValue.toString()));
+ } else {
throw new HoodieException(String.format("Not support type: %s", colType));
}
}
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala
index fb6a5813a..16d9253ad 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala
@@ -27,10 +27,10 @@ import org.apache.spark.sql.hudi.SparkAdapter
trait SparkAdapterSupport {
lazy val sparkAdapter: SparkAdapter = {
- val adapterClass = if (HoodieSparkUtils.gteqSpark3_2) {
+ val adapterClass = if (HoodieSparkUtils.isSpark3_2) {
"org.apache.spark.sql.adapter.Spark3_2Adapter"
} else if (HoodieSparkUtils.isSpark3_0 || HoodieSparkUtils.isSpark3_1) {
- "org.apache.spark.sql.adapter.Spark3Adapter"
+ "org.apache.spark.sql.adapter.Spark3_1Adapter"
} else {
"org.apache.spark.sql.adapter.Spark2Adapter"
}
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
new file mode 100644
index 000000000..fe30f61b9
--- /dev/null
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, SubqueryExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.types.StructType
+
+trait HoodieCatalystExpressionUtils {
+
+ /**
+ * Parses and resolves expression against the attributes of the given table schema.
+ *
+ * For example:
+ *
+ * ts > 1000 and ts <= 1500
+ *
+ * will be resolved as
+ *
+ * And(GreaterThan(ts#590L > 1000), LessThanOrEqual(ts#590L <= 1500))
+ *
+ *
+ * Where ts is a column of the provided [[tableSchema]]
+ *
+ * @param spark spark session
+ * @param exprString string representation of the expression to parse and resolve
+ * @param tableSchema table schema encompassing attributes to resolve against
+ * @return Resolved filter expression
+ */
+ def resolveExpr(spark: SparkSession, exprString: String, tableSchema: StructType): Expression = {
+ val expr = spark.sessionState.sqlParser.parseExpression(exprString)
+ resolveExpr(spark, expr, tableSchema)
+ }
+
+ /**
+ * Resolves provided expression (unless already resolved) against the attributes of the given table schema.
+ *
+ * For example:
+ *
+ * ts > 1000 and ts <= 1500
+ *
+ * will be resolved as
+ *
+ * And(GreaterThan(ts#590L > 1000), LessThanOrEqual(ts#590L <= 1500))
+ *
+ *
+ * Where ts is a column of the provided [[tableSchema]]
+ *
+ * @param spark spark session
+ * @param expr Catalyst expression to be resolved (if not yet)
+ * @param tableSchema table schema encompassing attributes to resolve against
+ * @return Resolved filter expression
+ */
+ def resolveExpr(spark: SparkSession, expr: Expression, tableSchema: StructType): Expression = {
+ val analyzer = spark.sessionState.analyzer
+ val schemaFields = tableSchema.fields
+
+ val resolvedExpr = {
+ val plan: LogicalPlan = Filter(expr, LocalRelation(schemaFields.head, schemaFields.drop(1): _*))
+ analyzer.execute(plan).asInstanceOf[Filter].condition
+ }
+
+ if (!hasUnresolvedRefs(resolvedExpr)) {
+ resolvedExpr
+ } else {
+ throw new IllegalStateException("unresolved attribute")
+ }
+ }
+
+ /**
+ * Split the given predicates into two sequence predicates:
+ * - predicates that references partition columns only(and involves no sub-query);
+ * - other predicates.
+ *
+ * @param sparkSession The spark session
+ * @param predicates The predicates to be split
+ * @param partitionColumns The partition columns
+ * @return (partitionFilters, dataFilters)
+ */
+ def splitPartitionAndDataPredicates(sparkSession: SparkSession,
+ predicates: Array[Expression],
+ partitionColumns: Array[String]): (Array[Expression], Array[Expression]) = {
+ // Validates that the provided names both resolve to the same entity
+ val resolvedNameEquals = sparkSession.sessionState.analyzer.resolver
+
+ predicates.partition(expr => {
+ // Checks whether given expression only references partition columns(and involves no sub-query)
+ expr.references.forall(r => partitionColumns.exists(resolvedNameEquals(r.name, _))) &&
+ !SubqueryExpression.hasSubquery(expr)
+ })
+ }
+
+ /**
+ * Matches an expression iff
+ *
+ *
+ * It references exactly one [[AttributeReference]]
+ * It contains only whitelisted transformations that preserve ordering of the source column [1]
+ *
+ *
+ * [1] Preserving ordering is defined as following: transformation T is defined as ordering preserving in case
+ * values of the source column A values being ordered as a1, a2, a3 ..., will map into column B = T(A) which
+ * will keep the same ordering b1, b2, b3, ... with b1 = T(a1), b2 = T(a2), ...
+ */
+ def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference]
+
+ private def hasUnresolvedRefs(resolvedExpr: Expression): Boolean =
+ resolvedExpr.collectFirst {
+ case _: UnresolvedAttribute | _: UnresolvedFunction => true
+ }.isDefined
+}
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala
index e41a9c1c8..354ed0ef2 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala
@@ -32,16 +32,21 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.execution.datasources.{FilePartition, LogicalRelation, PartitionedFile, SparkParsePartitionUtil}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{HoodieCatalystExpressionUtils, Row, SparkSession}
import java.util.Locale
/**
- * An interface to adapter the difference between spark2 and spark3
- * in some spark related class.
+ * Interface adapting discrepancies and incompatibilities between different Spark versions
*/
trait SparkAdapter extends Serializable {
+ /**
+ * Creates instance of [[HoodieCatalystExpressionUtils]] providing for common utils operating
+ * on Catalyst Expressions
+ */
+ def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils
+
/**
* Creates instance of [[HoodieAvroSerializer]] providing for ability to serialize
* Spark's [[InternalRow]] into Avro payloads
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
index 11778da63..98fc21887 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
@@ -20,7 +20,7 @@ package org.apache.hudi
import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path, PathFilter}
+import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.hbase.io.hfile.CacheConfig
import org.apache.hadoop.mapred.JobConf
import org.apache.hudi.HoodieBaseRelation.{getPartitionPath, isMetadataTable}
@@ -32,7 +32,6 @@ import org.apache.hudi.common.table.timeline.{HoodieInstant, HoodieTimeline}
import org.apache.hudi.common.table.view.HoodieTableFileSystemView
import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient, TableSchemaResolver}
import org.apache.hudi.common.util.StringUtils
-import org.apache.hudi.hadoop.HoodieROTablePathFilter
import org.apache.hudi.io.storage.HoodieHFileReader
import org.apache.hudi.metadata.{HoodieMetadataPayload, HoodieTableMetadata}
import org.apache.spark.execution.datasources.HoodieInMemoryFileIndex
@@ -41,7 +40,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
-import org.apache.spark.sql.execution.datasources.{FileStatusCache, PartitionDirectory, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{FileStatusCache, PartitionedFile}
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils
import org.apache.spark.sql.sources.{BaseRelation, Filter, PrunedFilteredScan}
import org.apache.spark.sql.types.StructType
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
index de863203d..ea3b2f061 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
@@ -32,11 +32,10 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{And, Expression, Literal}
import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusCache, NoopCache, PartitionDirectory}
import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.hudi.DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr
-import org.apache.spark.sql.hudi.HoodieSqlCommonUtils
+import org.apache.spark.sql.hudi.{DataSkippingUtils, HoodieSqlCommonUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StringType, StructType}
-import org.apache.spark.sql.{AnalysisException, Column, SparkSession}
+import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.unsafe.types.UTF8String
import java.text.SimpleDateFormat
@@ -244,21 +243,21 @@ case class HoodieFileIndex(spark: SparkSession,
// column references from the filtering expressions, and only transpose records corresponding to the
// columns referenced in those
val transposedColStatsDF =
- queryReferencedColumns.map(colName =>
- colStatsDF.filter(col(HoodieMetadataPayload.COLUMN_STATS_FIELD_COLUMN_NAME).equalTo(colName))
- .select(targetColStatsIndexColumns.map(col): _*)
- .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_NULL_COUNT, getNumNullsColumnNameFor(colName))
- .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_MIN_VALUE, getMinColumnNameFor(colName))
- .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_MAX_VALUE, getMaxColumnNameFor(colName))
- )
- .reduceLeft((left, right) =>
- left.join(right, usingColumn = HoodieMetadataPayload.COLUMN_STATS_FIELD_FILE_NAME))
+ queryReferencedColumns.map(colName =>
+ colStatsDF.filter(col(HoodieMetadataPayload.COLUMN_STATS_FIELD_COLUMN_NAME).equalTo(colName))
+ .select(targetColStatsIndexColumns.map(col): _*)
+ .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_NULL_COUNT, getNumNullsColumnNameFor(colName))
+ .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_MIN_VALUE, getMinColumnNameFor(colName))
+ .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_MAX_VALUE, getMaxColumnNameFor(colName))
+ )
+ .reduceLeft((left, right) =>
+ left.join(right, usingColumn = HoodieMetadataPayload.COLUMN_STATS_FIELD_FILE_NAME))
// Persist DF to avoid re-computing column statistics unraveling
withPersistence(transposedColStatsDF) {
val indexSchema = transposedColStatsDF.schema
val indexFilter =
- queryFilters.map(translateIntoColumnStatsIndexFilterExpr(_, indexSchema))
+ queryFilters.map(DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(_, indexSchema))
.reduce(And)
val allIndexedFileNames =
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/HoodieSparkTypeUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/HoodieSparkTypeUtils.scala
new file mode 100644
index 000000000..d5d95872a
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/HoodieSparkTypeUtils.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.sql.types.{DataType, NumericType, StringType}
+
+// TODO unify w/ DataTypeUtils
+object HoodieSparkTypeUtils {
+
+ /**
+ * Checks whether casting expression of [[from]] [[DataType]] to [[to]] [[DataType]] will
+ * preserve ordering of the elements
+ */
+ def isCastPreservingOrdering(from: DataType, to: DataType): Boolean =
+ (from, to) match {
+ // NOTE: In the casting rules defined by Spark, only casting from String to Numeric
+ // (and vice versa) are the only casts that might break the ordering of the elements after casting
+ case (StringType, _: NumericType) => false
+ case (_: NumericType, StringType) => false
+
+ case _ => true
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
deleted file mode 100644
index d640c0226..000000000
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}
-import org.apache.spark.sql.types.StructType
-
-object HoodieCatalystExpressionUtils {
-
- /**
- * Resolve filter expression from string expr with given table schema, for example:
- *
- * ts > 1000 and ts <= 1500
- *
- * will be resolved as
- *
- * And(GreaterThan(ts#590L > 1000), LessThanOrEqual(ts#590L <= 1500))
- *
- *
- * @param spark The spark session
- * @param exprString String to be resolved
- * @param tableSchema The table schema
- * @return Resolved filter expression
- */
- def resolveFilterExpr(spark: SparkSession, exprString: String, tableSchema: StructType): Expression = {
- val expr = spark.sessionState.sqlParser.parseExpression(exprString)
- resolveFilterExpr(spark, expr, tableSchema)
- }
-
- def resolveFilterExpr(spark: SparkSession, expr: Expression, tableSchema: StructType): Expression = {
- val schemaFields = tableSchema.fields
- val resolvedExpr = spark.sessionState.analyzer.ResolveReferences(
- Filter(expr,
- LocalRelation(schemaFields.head, schemaFields.drop(1): _*))
- )
- .asInstanceOf[Filter].condition
-
- checkForUnresolvedRefs(resolvedExpr)
- }
-
- private def checkForUnresolvedRefs(resolvedExpr: Expression): Expression =
- resolvedExpr match {
- case UnresolvedAttribute(_) => throw new IllegalStateException("unresolved attribute")
- case _ => resolvedExpr.mapChildren(e => checkForUnresolvedRefs(e))
- }
-
- /**
- * Split the given predicates into two sequence predicates:
- * - predicates that references partition columns only(and involves no sub-query);
- * - other predicates.
- *
- * @param sparkSession The spark session
- * @param predicates The predicates to be split
- * @param partitionColumns The partition columns
- * @return (partitionFilters, dataFilters)
- */
- def splitPartitionAndDataPredicates(sparkSession: SparkSession,
- predicates: Array[Expression],
- partitionColumns: Array[String]): (Array[Expression], Array[Expression]) = {
- // Validates that the provided names both resolve to the same entity
- val resolvedNameEquals = sparkSession.sessionState.analyzer.resolver
-
- predicates.partition(expr => {
- // Checks whether given expression only references partition columns(and involves no sub-query)
- expr.references.forall(r => partitionColumns.exists(resolvedNameEquals(r.name, _))) &&
- !SubqueryExpression.hasSubquery(expr)
- })
- }
-}
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala
index 06b92e204..b7ddd2828 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala
@@ -17,14 +17,17 @@
package org.apache.spark.sql.hudi
+import org.apache.hudi.SparkAdapterSupport
+import org.apache.hudi.common.util.ValidationUtils.checkState
import org.apache.hudi.index.columnstats.ColumnStatsIndexHelper.{getMaxColumnNameFor, getMinColumnNameFor, getNumNullsColumnNameFor}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.hudi.ColumnStatsExpressionUtils._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{AnalysisException, HoodieCatalystExpressionUtils}
import org.apache.spark.unsafe.types.UTF8String
object DataSkippingUtils extends Logging {
@@ -59,147 +62,205 @@ object DataSkippingUtils extends Logging {
}
private def tryComposeIndexFilterExpr(sourceExpr: Expression, indexSchema: StructType): Option[Expression] = {
- def minValue(colName: String) = col(getMinColumnNameFor(colName)).expr
- def maxValue(colName: String) = col(getMaxColumnNameFor(colName)).expr
- def numNulls(colName: String) = col(getNumNullsColumnNameFor(colName)).expr
-
- def colContainsValuesEqualToLiteral(colName: String, value: Literal): Expression =
- // Only case when column C contains value V is when min(C) <= V <= max(c)
- And(LessThanOrEqual(minValue(colName), value), GreaterThanOrEqual(maxValue(colName), value))
-
- def colContainsOnlyValuesEqualToLiteral(colName: String, value: Literal) =
- // Only case when column C contains _only_ value V is when min(C) = V AND max(c) = V
- And(EqualTo(minValue(colName), value), EqualTo(maxValue(colName), value))
-
+ //
+ // For translation of the Filter Expression for the Data Table into Filter Expression for Column Stats Index, we're
+ // assuming that
+ // - The column A is queried in the Data Table (hereafter referred to as "colA")
+ // - Filter Expression is a relational expression (ie "=", "<", "<=", ...) of the following form
+ //
+ // ```transform_expr(colA) = value_expr```
+ //
+ // Where
+ // - "transform_expr" is an expression of the _transformation_ which preserve ordering of the "colA"
+ // - "value_expr" is an "value"-expression (ie one NOT referring to other attributes/columns or containing sub-queries)
+ //
+ // We translate original Filter Expr into the one querying Column Stats Index like following: let's consider
+ // equality Filter Expr referred to above:
+ //
+ // ```transform_expr(colA) = value_expr```
+ //
+ // This expression will be translated into following Filter Expression for the Column Stats Index:
+ //
+ // ```(transform_expr(colA_minValue) <= value_expr) AND (value_expr <= transform_expr(colA_maxValue))```
+ //
+ // Which will enable us to match files with the range of values in column A containing the target ```value_expr```
+ //
+ // NOTE: That we can apply ```transform_expr``` transformation precisely b/c it preserves the ordering of the
+ // values of the source column, ie following holds true:
+ //
+ // colA_minValue = min(colA) => transform_expr(colA_minValue) = min(transform_expr(colA))
+ // colA_maxValue = max(colA) => transform_expr(colA_maxValue) = max(transform_expr(colA))
+ //
sourceExpr match {
- // Filter "colA = b"
- // Translates to "colA_minValue <= b AND colA_maxValue >= b" condition for index lookup
- case EqualTo(attribute: AttributeReference, value: Literal) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => colContainsValuesEqualToLiteral(colName, value))
+ // If Expression is not resolved, we can't perform the analysis accurately, bailing
+ case expr if !expr.resolved => None
- // Filter "b = colA"
- // Translates to "colA_minValue <= b AND colA_maxValue >= b" condition for index lookup
- case EqualTo(value: Literal, attribute: AttributeReference) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => colContainsValuesEqualToLiteral(colName, value))
+ // Filter "expr(colA) = B" and "B = expr(colA)"
+ // Translates to "(expr(colA_minValue) <= B) AND (B <= expr(colA_maxValue))" condition for index lookup
+ case EqualTo(sourceExpr @ AllowedTransformationExpression(attrRef), valueExpr: Expression) if isValueExpression(valueExpr) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ // NOTE: Since we're supporting (almost) arbitrary expressions of the form `f(colA) = B`, we have to
+ // appropriately translate such original expression targeted at Data Table, to corresponding
+ // expression targeted at Column Stats Index Table. For that, we take original expression holding
+ // [[AttributeReference]] referring to the Data Table, and swap it w/ expression referring to
+ // corresponding column in the Column Stats Index
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ genColumnValuesEqualToExpression(colName, valueExpr, targetExprBuilder)
+ }
- // Filter "colA != b"
- // Translates to "NOT(colA_minValue = b AND colA_maxValue = b)"
- // NOTE: This is NOT an inversion of `colA = b`
- case Not(EqualTo(attribute: AttributeReference, value: Literal)) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => Not(colContainsOnlyValuesEqualToLiteral(colName, value)))
+ case EqualTo(valueExpr: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(valueExpr) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ genColumnValuesEqualToExpression(colName, valueExpr, targetExprBuilder)
+ }
- // Filter "b != colA"
- // Translates to "NOT(colA_minValue = b AND colA_maxValue = b)"
- // NOTE: This is NOT an inversion of `colA = b`
- case Not(EqualTo(value: Literal, attribute: AttributeReference)) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => Not(colContainsOnlyValuesEqualToLiteral(colName, value)))
+ // Filter "expr(colA) != B" and "B != expr(colA)"
+ // Translates to "NOT(expr(colA_minValue) = B AND expr(colA_maxValue) = B)"
+ // NOTE: This is NOT an inversion of `colA = b`, instead this filter ONLY excludes files for which `colA = B`
+ // holds true
+ case Not(EqualTo(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression)) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ Not(genColumnOnlyValuesEqualToExpression(colName, value, targetExprBuilder))
+ }
+
+ case Not(EqualTo(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef))) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ Not(genColumnOnlyValuesEqualToExpression(colName, value, targetExprBuilder))
+ }
// Filter "colA = null"
// Translates to "colA_num_nulls = null" for index lookup
- case equalNullSafe @ EqualNullSafe(_: AttributeReference, _ @ Literal(null, _)) =>
- getTargetIndexedColName(equalNullSafe.left, indexSchema)
- .map(colName => EqualTo(numNulls(colName), equalNullSafe.right))
+ case EqualNullSafe(attrRef: AttributeReference, litNull @ Literal(null, _)) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map(colName => EqualTo(genColNumNullsExpr(colName), litNull))
- // Filter "colA < b"
- // Translates to "colA_minValue < b" for index lookup
- case LessThan(attribute: AttributeReference, value: Literal) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => LessThan(minValue(colName), value))
+ // Filter "expr(colA) < B" and "B > expr(colA)"
+ // Translates to "expr(colA_minValue) < B" for index lookup
+ case LessThan(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ LessThan(targetExprBuilder.apply(genColMinValueExpr(colName)), value)
+ }
- // Filter "b > colA"
- // Translates to "b > colA_minValue" for index lookup
- case GreaterThan(value: Literal, attribute: AttributeReference) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => LessThan(minValue(colName), value))
+ case GreaterThan(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ LessThan(targetExprBuilder.apply(genColMinValueExpr(colName)), value)
+ }
- // Filter "b < colA"
- // Translates to "b < colA_maxValue" for index lookup
- case LessThan(value: Literal, attribute: AttributeReference) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => GreaterThan(maxValue(colName), value))
+ // Filter "B < expr(colA)" and "expr(colA) > B"
+    // Translates to "B < expr(colA_maxValue)" for index lookup
+ case LessThan(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ GreaterThan(targetExprBuilder.apply(genColMaxValueExpr(colName)), value)
+ }
- // Filter "colA > b"
- // Translates to "colA_maxValue > b" for index lookup
- case GreaterThan(attribute: AttributeReference, value: Literal) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => GreaterThan(maxValue(colName), value))
+ case GreaterThan(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ GreaterThan(targetExprBuilder.apply(genColMaxValueExpr(colName)), value)
+ }
- // Filter "colA <= b"
- // Translates to "colA_minValue <= b" for index lookup
- case LessThanOrEqual(attribute: AttributeReference, value: Literal) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => LessThanOrEqual(minValue(colName), value))
+ // Filter "expr(colA) <= B" and "B >= expr(colA)"
+    // Translates to "expr(colA_minValue) <= B" for index lookup
+ case LessThanOrEqual(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ LessThanOrEqual(targetExprBuilder.apply(genColMinValueExpr(colName)), value)
+ }
- // Filter "b >= colA"
- // Translates to "b >= colA_minValue" for index lookup
- case GreaterThanOrEqual(value: Literal, attribute: AttributeReference) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => LessThanOrEqual(minValue(colName), value))
+ case GreaterThanOrEqual(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ LessThanOrEqual(targetExprBuilder.apply(genColMinValueExpr(colName)), value)
+ }
- // Filter "b <= colA"
- // Translates to "b <= colA_maxValue" for index lookup
- case LessThanOrEqual(value: Literal, attribute: AttributeReference) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => GreaterThanOrEqual(maxValue(colName), value))
+ // Filter "B <= expr(colA)" and "expr(colA) >= B"
+    // Translates to "B <= expr(colA_maxValue)" for index lookup
+ case LessThanOrEqual(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ GreaterThanOrEqual(targetExprBuilder.apply(genColMaxValueExpr(colName)), value)
+ }
- // Filter "colA >= b"
- // Translates to "colA_maxValue >= b" for index lookup
- case GreaterThanOrEqual(attribute: AttributeReference, right: Literal) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => GreaterThanOrEqual(maxValue(colName), right))
+ case GreaterThanOrEqual(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression) if isValueExpression(value) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ GreaterThanOrEqual(targetExprBuilder.apply(genColMaxValueExpr(colName)), value)
+ }
// Filter "colA is null"
// Translates to "colA_num_nulls > 0" for index lookup
case IsNull(attribute: AttributeReference) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => GreaterThan(numNulls(colName), Literal(0)))
+ getTargetIndexedColumnName(attribute, indexSchema)
+ .map(colName => GreaterThan(genColNumNullsExpr(colName), Literal(0)))
// Filter "colA is not null"
// Translates to "colA_num_nulls = 0" for index lookup
case IsNotNull(attribute: AttributeReference) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => EqualTo(numNulls(colName), Literal(0)))
+ getTargetIndexedColumnName(attribute, indexSchema)
+ .map(colName => EqualTo(genColNumNullsExpr(colName), Literal(0)))
- // Filter "colA in (a, b, ...)"
- // Translates to "(colA_minValue <= a AND colA_maxValue >= a) OR (colA_minValue <= b AND colA_maxValue >= b)" for index lookup
- // NOTE: This is equivalent to "colA = a OR colA = b OR ..."
- case In(attribute: AttributeReference, list: Seq[Literal]) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName =>
- list.map { lit => colContainsValuesEqualToLiteral(colName, lit) }.reduce(Or)
- )
+ // Filter "expr(colA) in (B1, B2, ...)"
+ // Translates to "(colA_minValue <= B1 AND colA_maxValue >= B1) OR (colA_minValue <= B2 AND colA_maxValue >= B2) ... "
+ // for index lookup
+ // NOTE: This is equivalent to "colA = B1 OR colA = B2 OR ..."
+ case In(sourceExpr @ AllowedTransformationExpression(attrRef), list: Seq[Expression]) if list.forall(isValueExpression) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ list.map(lit => genColumnValuesEqualToExpression(colName, lit, targetExprBuilder)).reduce(Or)
+ }
- // Filter "colA not in (a, b, ...)"
- // Translates to "NOT((colA_minValue = a AND colA_maxValue = a) OR (colA_minValue = b AND colA_maxValue = b))" for index lookup
- // NOTE: This is NOT an inversion of `in (a, b, ...)` expr, this is equivalent to "colA != a AND colA != b AND ..."
- case Not(In(attribute: AttributeReference, list: Seq[Literal])) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName =>
- Not(
- list.map { lit => colContainsOnlyValuesEqualToLiteral(colName, lit) }.reduce(Or)
- )
- )
+ // Filter "expr(colA) not in (B1, B2, ...)"
+ // Translates to "NOT((colA_minValue = B1 AND colA_maxValue = B1) OR (colA_minValue = B2 AND colA_maxValue = B2))" for index lookup
+ // NOTE: This is NOT an inversion of `in (B1, B2, ...)` expr, this is equivalent to "colA != B1 AND colA != B2 AND ..."
+ case Not(In(sourceExpr @ AllowedTransformationExpression(attrRef), list: Seq[Expression])) if list.forall(_.foldable) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ Not(list.map(lit => genColumnOnlyValuesEqualToExpression(colName, lit, targetExprBuilder)).reduce(Or))
+ }
// Filter "colA like 'xxx%'"
- // Translates to "colA_minValue <= xxx AND colA_maxValue >= xxx" for index lookup
- // NOTE: That this operator only matches string prefixes, and this is
- // essentially equivalent to "colA = b" expression
- case StartsWith(attribute, v @ Literal(_: UTF8String, _)) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName => colContainsValuesEqualToLiteral(colName, v))
+ // Translates to "colA_minValue <= xxx AND xxx <= colA_maxValue" for index lookup
+ //
+ // NOTE: Since a) this operator matches strings by prefix and b) given that this column is going to be ordered
+ // lexicographically, we essentially need to check that provided literal falls w/in min/max bounds of the
+ // given column
+ case StartsWith(sourceExpr @ AllowedTransformationExpression(attrRef), v @ Literal(_: UTF8String, _)) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ genColumnValuesEqualToExpression(colName, v, targetExprBuilder)
+ }
- // Filter "colA not like 'xxx%'"
- // Translates to "NOT(colA_minValue like 'xxx%' AND colA_maxValue like 'xxx%')" for index lookup
+ // Filter "expr(colA) not like 'xxx%'"
+ // Translates to "NOT(expr(colA_minValue) like 'xxx%' AND expr(colA_maxValue) like 'xxx%')" for index lookup
// NOTE: This is NOT an inversion of "colA like xxx"
- case Not(StartsWith(attribute, value @ Literal(_: UTF8String, _))) =>
- getTargetIndexedColName(attribute, indexSchema)
- .map(colName =>
- Not(And(StartsWith(minValue(colName), value), StartsWith(maxValue(colName), value)))
- )
+ case Not(StartsWith(sourceExpr @ AllowedTransformationExpression(attrRef), value @ Literal(_: UTF8String, _))) =>
+ getTargetIndexedColumnName(attrRef, indexSchema)
+ .map { colName =>
+ val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+ val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName))
+ val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName))
+ Not(And(StartsWith(minValueExpr, value), StartsWith(maxValueExpr, value)))
+ }
case or: Or =>
val resLeft = createColumnStatsIndexFilterExprInternal(or.left, indexSchema)
@@ -238,7 +299,7 @@ object DataSkippingUtils extends Logging {
.forall(stat => indexSchema.exists(_.name == stat))
}
- private def getTargetIndexedColName(resolvedExpr: Expression, indexSchema: StructType): Option[String] = {
+ private def getTargetIndexedColumnName(resolvedExpr: AttributeReference, indexSchema: StructType): Option[String] = {
val colName = UnresolvedAttribute(getTargetColNameParts(resolvedExpr)).name
// Verify that the column is indexed
@@ -261,3 +322,91 @@ object DataSkippingUtils extends Logging {
}
}
}
+
+private object ColumnStatsExpressionUtils {
+
+ def genColMinValueExpr(colName: String): Expression =
+ col(getMinColumnNameFor(colName)).expr
+ def genColMaxValueExpr(colName: String): Expression =
+ col(getMaxColumnNameFor(colName)).expr
+ def genColNumNullsExpr(colName: String): Expression =
+ col(getNumNullsColumnNameFor(colName)).expr
+
+ def genColumnValuesEqualToExpression(colName: String,
+ value: Expression,
+ targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
+ // TODO clean up
+ checkState(isValueExpression(value))
+
+ val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName))
+ val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName))
+    // Only case when column C contains value V is when min(C) <= V <= max(C)
+ And(LessThanOrEqual(minValueExpr, value), GreaterThanOrEqual(maxValueExpr, value))
+ }
+
+ def genColumnOnlyValuesEqualToExpression(colName: String,
+ value: Expression,
+ targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
+ // TODO clean up
+ checkState(isValueExpression(value))
+
+ val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName))
+ val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName))
+    // Only case when column C contains _only_ value V is when min(C) = V AND max(C) = V
+ And(EqualTo(minValueExpr, value), EqualTo(maxValueExpr, value))
+ }
+
+ def swapAttributeRefInExpr(sourceExpr: Expression, from: AttributeReference, to: Expression): Expression = {
+ checkState(sourceExpr.references.size == 1)
+ sourceExpr.transformDown {
+ case attrRef: AttributeReference if attrRef.sameRef(from) => to
+ }
+ }
+
+ /**
+   * This check is used to validate that the expression that the target column is compared against
+ *
+   * a) Has no references to other attributes (e.g., other columns)
+ * b) Does not contain sub-queries
+ *
+ *
+ * This in turn allows us to be certain that Spark will be able to evaluate such expression
+ * against Column Stats Index as well
+ */
+ def isValueExpression(expr: Expression): Boolean =
+ expr.references.isEmpty && !SubqueryExpression.hasSubquery(expr)
+
+ /**
+ * This utility pattern-matches an expression iff
+ *
+ *
+ * It references *exactly* 1 attribute (column)
+ * It does NOT contain sub-queries
+ * It contains only whitelisted transformations that preserve ordering of the source column [1]
+ *
+ *
+ * [1] This is required to make sure that we can correspondingly map Column Stats Index values as well. Applying
+ * transformations that do not preserve the ordering might lead to incorrect results being returned by Data
+ * Skipping flow.
+ *
+   * Returns the only [[AttributeReference]] contained as a sub-expression
+ */
+ object AllowedTransformationExpression extends SparkAdapterSupport {
+ val exprUtils: HoodieCatalystExpressionUtils = sparkAdapter.createCatalystExpressionUtils()
+
+ def unapply(expr: Expression): Option[AttributeReference] = {
+ // First step, we check that expression
+ // - Does NOT contain sub-queries
+ // - Does contain exactly 1 attribute
+ if (SubqueryExpression.hasSubquery(expr) || expr.references.size != 1) {
+ None
+ } else {
+      // Second step, we validate that the enclosing expression is actually a permitted
+      // transformation
+      // NOTE: Composition of permitted transformations is itself permitted
+ exprUtils.tryMatchAttributeOrderingPreservingTransformation(expr)
+ }
+ }
+ }
+}
+
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RunClusteringProcedure.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RunClusteringProcedure.scala
index 442ee0441..231d0939c 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RunClusteringProcedure.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RunClusteringProcedure.scala
@@ -24,9 +24,9 @@ import org.apache.hudi.common.util.ValidationUtils.checkArgument
import org.apache.hudi.common.util.{ClusteringUtils, Option => HOption}
import org.apache.hudi.config.HoodieClusteringConfig
import org.apache.hudi.exception.HoodieClusteringException
-import org.apache.hudi.{AvroConversionUtils, HoodieCLIUtils, HoodieFileIndex}
+import org.apache.hudi.{AvroConversionUtils, HoodieCLIUtils, HoodieFileIndex, SparkAdapterSupport}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{HoodieCatalystExpressionUtils, Row}
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.execution.datasources.FileStatusCache
import org.apache.spark.sql.types._
@@ -34,7 +34,14 @@ import org.apache.spark.sql.types._
import java.util.function.Supplier
import scala.collection.JavaConverters._
-class RunClusteringProcedure extends BaseProcedure with ProcedureBuilder with PredicateHelper with Logging {
+class RunClusteringProcedure extends BaseProcedure
+ with ProcedureBuilder
+ with PredicateHelper
+ with Logging
+ with SparkAdapterSupport {
+
+ private val exprUtils = sparkAdapter.createCatalystExpressionUtils()
+
/**
* OPTIMIZE table_name|table_path [WHERE predicate]
* [ORDER BY (col_name1 [, ...] ) ]
@@ -120,9 +127,9 @@ class RunClusteringProcedure extends BaseProcedure with ProcedureBuilder with Pr
// Resolve partition predicates
val schemaResolver = new TableSchemaResolver(metaClient)
val tableSchema = AvroConversionUtils.convertAvroSchemaToStructType(schemaResolver.getTableAvroSchema)
- val condition = HoodieCatalystExpressionUtils.resolveFilterExpr(sparkSession, predicate, tableSchema)
+ val condition = exprUtils.resolveExpr(sparkSession, predicate, tableSchema)
val partitionColumns = metaClient.getTableConfig.getPartitionFields.orElse(Array[String]())
- val (partitionPredicates, dataPredicates) = HoodieCatalystExpressionUtils.splitPartitionAndDataPredicates(
+ val (partitionPredicates, dataPredicates) = exprUtils.splitPartitionAndDataPredicates(
sparkSession, splitConjunctivePredicates(condition).toArray, partitionColumns)
checkArgument(dataPredicates.isEmpty, "Only partition predicates are allowed")
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala
index 6b96472d4..43f070e6c 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala
@@ -20,30 +20,44 @@ package org.apache.hudi
import org.apache.hudi.index.columnstats.ColumnStatsIndexHelper
import org.apache.hudi.testutils.HoodieClientTestBase
import org.apache.spark.sql.catalyst.expressions.{Expression, Not}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, lower}
import org.apache.spark.sql.hudi.DataSkippingUtils
-import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType, VarcharType}
-import org.apache.spark.sql.{Column, HoodieCatalystExpressionUtils, SparkSession}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Column, HoodieCatalystExpressionUtils, Row, SparkSession}
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments.arguments
import org.junit.jupiter.params.provider.{Arguments, MethodSource}
+import java.sql.Timestamp
import scala.collection.JavaConverters._
// NOTE: Only A, B columns are indexed
case class IndexRow(
file: String,
- A_minValue: Long,
- A_maxValue: Long,
- A_num_nulls: Long,
+
+ // Corresponding A column is LongType
+ A_minValue: Long = -1,
+ A_maxValue: Long = -1,
+ A_num_nulls: Long = -1,
+
+ // Corresponding B column is StringType
B_minValue: String = null,
B_maxValue: String = null,
- B_num_nulls: Long = -1
-)
+ B_num_nulls: Long = -1,
-class TestDataSkippingUtils extends HoodieClientTestBase {
+  // Corresponding C column is TimestampType
+ C_minValue: Timestamp = null,
+ C_maxValue: Timestamp = null,
+ C_num_nulls: Long = -1
+) {
+ def toRow: Row = Row(productIterator.toSeq: _*)
+}
+
+class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSupport {
+
+ val exprUtils: HoodieCatalystExpressionUtils = sparkAdapter.createCatalystExpressionUtils()
var spark: SparkSession = _
@@ -53,17 +67,18 @@ class TestDataSkippingUtils extends HoodieClientTestBase {
spark = sqlContext.sparkSession
}
- val indexedCols = Seq("A", "B")
- val sourceTableSchema =
+ val indexedCols: Seq[String] = Seq("A", "B", "C")
+ val sourceTableSchema: StructType =
StructType(
Seq(
StructField("A", LongType),
StructField("B", StringType),
- StructField("C", VarcharType(32))
+ StructField("C", TimestampType),
+ StructField("D", VarcharType(32))
)
)
- val indexSchema =
+ val indexSchema: StructType =
ColumnStatsIndexHelper.composeIndexSchema(
sourceTableSchema.fields.toSeq
.filter(f => indexedCols.contains(f.name))
@@ -71,15 +86,17 @@ class TestDataSkippingUtils extends HoodieClientTestBase {
)
@ParameterizedTest
- @MethodSource(Array("testBaseLookupFilterExpressionsSource", "testAdvancedLookupFilterExpressionsSource"))
+ @MethodSource(
+ Array(
+ "testBasicLookupFilterExpressionsSource",
+ "testAdvancedLookupFilterExpressionsSource",
+ "testCompositeFilterExpressionsSource"
+ ))
def testLookupFilterExpressions(sourceExpr: String, input: Seq[IndexRow], output: Seq[String]): Unit = {
- val resolvedExpr: Expression = HoodieCatalystExpressionUtils.resolveFilterExpr(spark, sourceExpr, sourceTableSchema)
+ val resolvedExpr: Expression = exprUtils.resolveExpr(spark, sourceExpr, sourceTableSchema)
val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
- val spark2 = spark
- import spark2.implicits._
-
- val indexDf = spark.createDataset(input)
+ val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)
val rows = indexDf.where(new Column(lookupFilter))
.select("file")
@@ -93,7 +110,7 @@ class TestDataSkippingUtils extends HoodieClientTestBase {
@ParameterizedTest
@MethodSource(Array("testStringsLookupFilterExpressionsSource"))
def testStringsLookupFilterExpressions(sourceExpr: Expression, input: Seq[IndexRow], output: Seq[String]): Unit = {
- val resolvedExpr = HoodieCatalystExpressionUtils.resolveFilterExpr(spark, sourceExpr, sourceTableSchema)
+ val resolvedExpr = exprUtils.resolveExpr(spark, sourceExpr, sourceTableSchema)
val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
val spark2 = spark
@@ -130,11 +147,21 @@ object TestDataSkippingUtils {
IndexRow("file_3", 0, 0, 0, "aaa", "aba", 0),
IndexRow("file_4", 0, 0, 0, "abc123", "abc345", 0) // all strings start w/ "abc"
),
+ Seq("file_1", "file_2", "file_3")),
+ arguments(
+ // Composite expression
+ Not(lower(col("B")).startsWith("abc").expr),
+ Seq(
+        IndexRow("file_1", 0, 0, 0, "ABA", "ADF", 1), // may contain strings starting w/ "ABC" (i.e. "abc" after lower-casing)
+ IndexRow("file_2", 0, 0, 0, "ADF", "AZY", 0),
+ IndexRow("file_3", 0, 0, 0, "AAA", "ABA", 0),
+        IndexRow("file_4", 0, 0, 0, "ABC123", "ABC345", 0) // all strings start w/ "ABC" (i.e. "abc" after lower-casing)
+ ),
Seq("file_1", "file_2", "file_3"))
)
}
- def testBaseLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
+ def testBasicLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
java.util.stream.Stream.of(
// TODO cases
// A = null
@@ -263,6 +290,23 @@ object TestDataSkippingUtils {
IndexRow("file_4", 0, 0, 0), // only contains 0
IndexRow("file_5", 1, 1, 0) // only contains 1
),
+ Seq("file_1", "file_2", "file_3")),
+ arguments(
+ // Value expression containing expression, which isn't a literal
+ "A = int('0')",
+ Seq(
+ IndexRow("file_1", 1, 2, 0),
+ IndexRow("file_2", -1, 1, 0)
+ ),
+ Seq("file_2")),
+ arguments(
+ // Value expression containing reference to the other attribute (column), fallback
+ "A = D",
+ Seq(
+ IndexRow("file_1", 1, 2, 0),
+ IndexRow("file_2", -1, 1, 0),
+ IndexRow("file_3", -2, -1, 0)
+ ),
Seq("file_1", "file_2", "file_3"))
)
}
@@ -316,8 +360,8 @@ object TestDataSkippingUtils {
Seq("file_1", "file_2", "file_3")),
arguments(
- // Queries contains expression involving non-indexed column C
- "A = 0 AND B = 'abc' AND C = '...'",
+ // Queries contains expression involving non-indexed column D
+ "A = 0 AND B = 'abc' AND D IS NULL",
Seq(
IndexRow("file_1", 1, 2, 0),
IndexRow("file_2", -1, 1, 0),
@@ -327,8 +371,8 @@ object TestDataSkippingUtils {
Seq("file_4")),
arguments(
- // Queries contains expression involving non-indexed column C
- "A = 0 OR B = 'abc' OR C = '...'",
+ // Queries contains expression involving non-indexed column D
+ "A = 0 OR B = 'abc' OR D IS NULL",
Seq(
IndexRow("file_1", 1, 2, 0),
IndexRow("file_2", -1, 1, 0),
@@ -338,4 +382,206 @@ object TestDataSkippingUtils {
Seq("file_1", "file_2", "file_3", "file_4"))
)
}
+
+ def testCompositeFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
+ java.util.stream.Stream.of(
+ arguments(
+ "date_format(C, 'MM/dd/yyyy') = '03/06/2022'",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_2")),
+ arguments(
+ "'03/06/2022' = date_format(C, 'MM/dd/yyyy')",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_2")),
+ arguments(
+ "'03/06/2022' != date_format(C, 'MM/dd/yyyy')",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_1")),
+ arguments(
+ "date_format(C, 'MM/dd/yyyy') != '03/06/2022'",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_1")),
+ arguments(
+ "date_format(C, 'MM/dd/yyyy') < '03/07/2022'",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_2")),
+ arguments(
+ "'03/07/2022' > date_format(C, 'MM/dd/yyyy')",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_2")),
+ arguments(
+ "'03/07/2022' < date_format(C, 'MM/dd/yyyy')",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_1")),
+ arguments(
+ "date_format(C, 'MM/dd/yyyy') > '03/07/2022'",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_1")),
+ arguments(
+ "date_format(C, 'MM/dd/yyyy') <= '03/06/2022'",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_2")),
+ arguments(
+ "'03/06/2022' >= date_format(C, 'MM/dd/yyyy')",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_2")),
+ arguments(
+ "'03/08/2022' <= date_format(C, 'MM/dd/yyyy')",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_1")),
+ arguments(
+ "date_format(C, 'MM/dd/yyyy') >= '03/08/2022'",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_1")),
+ arguments(
+ "date_format(C, 'MM/dd/yyyy') IN ('03/08/2022')",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_1")),
+ arguments(
+ "date_format(C, 'MM/dd/yyyy') NOT IN ('03/06/2022')",
+ Seq(
+ IndexRow("file_1",
+ C_minValue = new Timestamp(1646711448000L), // 03/07/2022
+ C_maxValue = new Timestamp(1646797848000L), // 03/08/2022
+ C_num_nulls = 0),
+ IndexRow("file_2",
+ C_minValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_maxValue = new Timestamp(1646625048000L), // 03/06/2022
+ C_num_nulls = 0)
+ ),
+ Seq("file_1")),
+ arguments(
+ // Should be identical to the one above
+ "date_format(to_timestamp(B, 'yyyy-MM-dd'), 'MM/dd/yyyy') NOT IN ('03/06/2022')",
+ Seq(
+ IndexRow("file_1",
+ B_minValue = "2022-03-07", // 03/07/2022
+ B_maxValue = "2022-03-08", // 03/08/2022
+ B_num_nulls = 0),
+ IndexRow("file_2",
+ B_minValue = "2022-03-06", // 03/06/2022
+ B_maxValue = "2022-03-06", // 03/06/2022
+ B_num_nulls = 0)
+ ),
+ Seq("file_1"))
+
+ )
+ }
}
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala
new file mode 100644
index 000000000..3e233352c
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.HoodieSparkTypeUtils.isCastPreservingOrdering
+import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+
+object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
+
+ override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = {
+ expr match {
+ case OrderPreservingTransformation(attrRef) => Some(attrRef)
+ case _ => None
+ }
+ }
+
+ private object OrderPreservingTransformation {
+ def unapply(expr: Expression): Option[AttributeReference] = {
+ expr match {
+ // Date/Time Expressions
+ case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case ParseToDate(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+
+ // String Expressions
+ case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+
+ // Math Expressions
+ // Binary
+ case Add(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case Add(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Multiply(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case Multiply(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Divide(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ // Unary
+ case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log1p(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+
+ // Other
+ case cast @ Cast(OrderPreservingTransformation(attrRef), _, _)
+ if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef)
+
+ // Identity transformation
+ case attrRef: AttributeReference => Some(attrRef)
+ // No match
+ case _ => None
+ }
+ }
+ }
+
+}
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala
index 54c8b912a..42ad66598 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.hudi.SparkAdapter
import org.apache.spark.sql.hudi.parser.HoodieSpark2ExtendedSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark2CatalystExpressionUtils, Row, SparkSession}
import scala.collection.mutable.ArrayBuffer
@@ -42,10 +42,12 @@ import scala.collection.mutable.ArrayBuffer
*/
class Spark2Adapter extends SparkAdapter {
- def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
+ override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark2CatalystExpressionUtils
+
+ override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
new HoodieSparkAvroSerializer(rootCatalystType, rootAvroType, nullable)
- def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
+ override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
new HoodieSpark2AvroDeserializer(rootAvroType, rootCatalystType)
override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala
similarity index 93%
rename from hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala
rename to hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala
index ad338323e..33aae23df 100644
--- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala
+++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala
@@ -39,14 +39,14 @@ import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SparkSession}
/**
- * The adapter for spark3.
+ * Base implementation of [[SparkAdapter]] for Spark 3.x branch
*/
-class Spark3Adapter extends SparkAdapter {
+abstract class BaseSpark3Adapter extends SparkAdapter {
- def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
+ override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer =
new HoodieSparkAvroSerializer(rootCatalystType, rootAvroType, nullable)
- def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
+ override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer =
new HoodieSpark3AvroDeserializer(rootAvroType, rootCatalystType)
override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark3_1CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark3_1CatalystExpressionUtils.scala
new file mode 100644
index 000000000..cb9c31f08
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark3_1CatalystExpressionUtils.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.HoodieSparkTypeUtils.isCastPreservingOrdering
+import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+
+object HoodieSpark3_1CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
+
+ override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = {
+ expr match {
+ case OrderPreservingTransformation(attrRef) => Some(attrRef)
+ case _ => None
+ }
+ }
+
+ private object OrderPreservingTransformation {
+ def unapply(expr: Expression): Option[AttributeReference] = {
+ expr match {
+ // Date/Time Expressions
+ case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case ParseToDate(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef)
+ case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+
+ // String Expressions
+ case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+
+ // Math Expressions
+ // Binary
+ case Add(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case Add(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case Multiply(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case Multiply(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case Divide(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ // Unary
+ case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log1p(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+
+ // Other
+ case cast @ Cast(OrderPreservingTransformation(attrRef), _, _)
+ if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef)
+
+ // Identity transformation
+ case attrRef: AttributeReference => Some(attrRef)
+ // No match
+ case _ => None
+ }
+ }
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala
new file mode 100644
index 000000000..106939cbb
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.adapter
+
+import org.apache.spark.sql.hudi.SparkAdapter
+import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark3_1CatalystExpressionUtils}
+
+/**
+ * Implementation of [[SparkAdapter]] for Spark 3.1.x branch
+ */
+class Spark3_1Adapter extends BaseSpark3Adapter {
+
+ override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark3_1CatalystExpressionUtils
+
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/HoodieSpark3_2CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/HoodieSpark3_2CatalystExpressionUtils.scala
new file mode 100644
index 000000000..8e056c033
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/HoodieSpark3_2CatalystExpressionUtils.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.HoodieSparkTypeUtils.isCastPreservingOrdering
+import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper}
+
+object HoodieSpark3_2CatalystExpressionUtils extends HoodieCatalystExpressionUtils {
+
+ override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = {
+ expr match {
+ case OrderPreservingTransformation(attrRef) => Some(attrRef)
+ case _ => None
+ }
+ }
+
+ private object OrderPreservingTransformation {
+ def unapply(expr: Expression): Option[AttributeReference] = {
+ expr match {
+ // Date/Time Expressions
+ case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case ParseToDate(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef)
+ case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef)
+ case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+
+ // String Expressions
+ case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+
+ // Math Expressions
+ // Binary
+ case Add(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case Add(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case Multiply(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case Multiply(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case Divide(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef)
+ case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ // Unary
+ case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log1p(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef)
+ case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+ case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef)
+
+ // Other
+ case cast @ Cast(OrderPreservingTransformation(attrRef), _, _, _)
+ if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef)
+
+ // Identity transformation
+ case attrRef: AttributeReference => Some(attrRef)
+ // No match
+ case _ => None
+ }
+ }
+ }
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala
index 699623f8b..1256344c3 100644
--- a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala
@@ -17,16 +17,19 @@
package org.apache.spark.sql.adapter
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark3_2CatalystExpressionUtils, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.parser.HoodieSpark3_2ExtendedSqlParser
/**
- * The adapter for spark3.2.
+ * Implementation of [[SparkAdapter]] for Spark 3.2.x branch
*/
-class Spark3_2Adapter extends Spark3Adapter {
+class Spark3_2Adapter extends BaseSpark3Adapter {
+
+ override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark3_2CatalystExpressionUtils
+
/**
* if the logical plan is a TimeTravelRelation LogicalPlan.
*/