From 8b38ddedc2fd1ef048a756faa3531afc06b2c26b Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Thu, 24 Mar 2022 22:27:15 -0700 Subject: [PATCH] [HUDI-3594] Supporting Composite Expressions over Data Table Columns in Data Skipping flow (#4996) --- .../columnstats/ColumnStatsIndexHelper.java | 84 ++-- .../org/apache/hudi/SparkAdapterSupport.scala | 4 +- .../sql/HoodieCatalystExpressionUtils.scala | 127 ++++++ .../apache/spark/sql/hudi/SparkAdapter.scala | 11 +- .../org/apache/hudi/HoodieBaseRelation.scala | 5 +- .../org/apache/hudi/HoodieFileIndex.scala | 25 +- .../apache/spark/HoodieSparkTypeUtils.scala | 38 ++ .../sql/HoodieCatalystExpressionUtils.scala | 88 ---- .../spark/sql/hudi/DataSkippingUtils.scala | 379 ++++++++++++------ .../procedures/RunClusteringProcedure.scala | 17 +- .../apache/hudi/TestDataSkippingUtils.scala | 296 ++++++++++++-- .../HoodieSpark2CatalystExpressionUtils.scala | 84 ++++ .../spark/sql/adapter/Spark2Adapter.scala | 8 +- ...3Adapter.scala => BaseSpark3Adapter.scala} | 8 +- ...oodieSpark3_1CatalystExpressionUtils.scala | 84 ++++ .../spark/sql/adapter/Spark3_1Adapter.scala | 31 ++ ...oodieSpark3_2CatalystExpressionUtils.scala | 83 ++++ .../spark/sql/adapter/Spark3_2Adapter.scala | 9 +- 18 files changed, 1079 insertions(+), 302 deletions(-) create mode 100644 hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala create mode 100644 hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/HoodieSparkTypeUtils.scala delete mode 100644 hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala create mode 100644 hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala rename hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/{Spark3Adapter.scala => BaseSpark3Adapter.scala} (93%) create mode 100644 
hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark3_1CatalystExpressionUtils.scala create mode 100644 hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala create mode 100644 hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/HoodieSpark3_2CatalystExpressionUtils.scala diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/columnstats/ColumnStatsIndexHelper.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/columnstats/ColumnStatsIndexHelper.java index b98893344..49e33f456 100644 --- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/columnstats/ColumnStatsIndexHelper.java +++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/columnstats/ColumnStatsIndexHelper.java @@ -72,6 +72,7 @@ import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import static org.apache.hudi.common.util.ValidationUtils.checkState; import static org.apache.hudi.util.DataTypeUtils.areCompatible; public class ColumnStatsIndexHelper { @@ -111,17 +112,17 @@ public class ColumnStatsIndexHelper { * | another_base_file.parquet | -10 | 0 | 5 | * +---------------------------+------------+------------+-------------+ * - * + *

* NOTE: Currently {@link TimestampType} is not supported, since Parquet writer * does not support statistics for it. - * + *

* TODO leverage metadata table after RFC-27 lands - * @VisibleForTesting * - * @param sparkSession encompassing Spark session - * @param baseFilesPaths list of base-files paths to be sourced for column-stats index + * @param sparkSession encompassing Spark session + * @param baseFilesPaths list of base-files paths to be sourced for column-stats index * @param orderedColumnSchemas target ordered columns * @return Spark's {@link Dataset} holding an index table + * @VisibleForTesting */ @Nonnull public static Dataset buildColumnStatsTableFor( @@ -223,13 +224,13 @@ public class ColumnStatsIndexHelper { *

  • Cleans up any residual index tables, that weren't cleaned up before
  • * * - * @param sparkSession encompassing Spark session + * @param sparkSession encompassing Spark session * @param sourceTableSchema instance of {@link StructType} bearing source table's writer's schema - * @param sourceBaseFiles list of base-files to be indexed - * @param orderedCols target ordered columns - * @param indexFolderPath col-stats index folder path - * @param commitTime current operation commit instant - * @param completedCommits all previously completed commit instants + * @param sourceBaseFiles list of base-files to be indexed + * @param orderedCols target ordered columns + * @param indexFolderPath col-stats index folder path + * @param commitTime current operation commit instant + * @param completedCommits all previously completed commit instants */ public static void updateColumnStatsIndexFor( @Nonnull SparkSession sparkSession, @@ -424,57 +425,64 @@ public class ColumnStatsIndexHelper { return String.format("%s_%s", col, statName); } - private static Pair - fetchMinMaxValues( - @Nonnull DataType colType, - @Nonnull HoodieColumnRangeMetadata colMetadata) { + private static Pair fetchMinMaxValues(@Nonnull DataType colType, + @Nonnull HoodieColumnRangeMetadata colMetadata) { + Comparable minValue = colMetadata.getMinValue(); + Comparable maxValue = colMetadata.getMaxValue(); + + checkState((minValue == null) == (maxValue == null), "Either both min/max values should be null or neither"); + + if (minValue == null || maxValue == null) { + return Pair.of(null, null); + } + if (colType instanceof IntegerType) { return Pair.of( - new Integer(colMetadata.getMinValue().toString()), - new Integer(colMetadata.getMaxValue().toString()) + new Integer(minValue.toString()), + new Integer(maxValue.toString()) ); } else if (colType instanceof DoubleType) { return Pair.of( - new Double(colMetadata.getMinValue().toString()), - new Double(colMetadata.getMaxValue().toString()) + new Double(minValue.toString()), + new Double(maxValue.toString()) ); } else if (colType 
instanceof StringType) { return Pair.of( - colMetadata.getMinValue().toString(), - colMetadata.getMaxValue().toString()); + minValue.toString(), + maxValue.toString()); } else if (colType instanceof DecimalType) { return Pair.of( - new BigDecimal(colMetadata.getMinValue().toString()), - new BigDecimal(colMetadata.getMaxValue().toString())); + new BigDecimal(minValue.toString()), + new BigDecimal(maxValue.toString())); } else if (colType instanceof DateType) { return Pair.of( - java.sql.Date.valueOf(colMetadata.getMinValue().toString()), - java.sql.Date.valueOf(colMetadata.getMaxValue().toString())); + java.sql.Date.valueOf(minValue.toString()), + java.sql.Date.valueOf(maxValue.toString())); } else if (colType instanceof LongType) { return Pair.of( - new Long(colMetadata.getMinValue().toString()), - new Long(colMetadata.getMaxValue().toString())); + new Long(minValue.toString()), + new Long(maxValue.toString())); } else if (colType instanceof ShortType) { return Pair.of( - new Short(colMetadata.getMinValue().toString()), - new Short(colMetadata.getMaxValue().toString())); + new Short(minValue.toString()), + new Short(maxValue.toString())); } else if (colType instanceof FloatType) { return Pair.of( - new Float(colMetadata.getMinValue().toString()), - new Float(colMetadata.getMaxValue().toString())); + new Float(minValue.toString()), + new Float(maxValue.toString())); } else if (colType instanceof BinaryType) { return Pair.of( - ((ByteBuffer) colMetadata.getMinValue()).array(), - ((ByteBuffer) colMetadata.getMaxValue()).array()); + ((ByteBuffer) minValue).array(), + ((ByteBuffer) maxValue).array()); } else if (colType instanceof BooleanType) { return Pair.of( - Boolean.valueOf(colMetadata.getMinValue().toString()), - Boolean.valueOf(colMetadata.getMaxValue().toString())); + Boolean.valueOf(minValue.toString()), + Boolean.valueOf(maxValue.toString())); } else if (colType instanceof ByteType) { return Pair.of( - Byte.valueOf(colMetadata.getMinValue().toString()), - 
Byte.valueOf(colMetadata.getMaxValue().toString())); - } else { + Byte.valueOf(minValue.toString()), + Byte.valueOf(maxValue.toString())); + } else { throw new HoodieException(String.format("Not support type: %s", colType)); } } diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala index fb6a5813a..16d9253ad 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala @@ -27,10 +27,10 @@ import org.apache.spark.sql.hudi.SparkAdapter trait SparkAdapterSupport { lazy val sparkAdapter: SparkAdapter = { - val adapterClass = if (HoodieSparkUtils.gteqSpark3_2) { + val adapterClass = if (HoodieSparkUtils.isSpark3_2) { "org.apache.spark.sql.adapter.Spark3_2Adapter" } else if (HoodieSparkUtils.isSpark3_0 || HoodieSparkUtils.isSpark3_1) { - "org.apache.spark.sql.adapter.Spark3Adapter" + "org.apache.spark.sql.adapter.Spark3_1Adapter" } else { "org.apache.spark.sql.adapter.Spark2Adapter" } diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala new file mode 100644 index 000000000..fe30f61b9 --- /dev/null +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, SubqueryExpression} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.types.StructType + +trait HoodieCatalystExpressionUtils { + + /** + * Parses and resolves expression against the attributes of the given table schema. + * + * For example: + *
    +   * ts > 1000 and ts <= 1500
    +   * 
    + * will be resolved as + *
    +   * And(GreaterThan(ts#590L > 1000), LessThanOrEqual(ts#590L <= 1500))
    +   * 
    + * + * Where
    ts
    is a column of the provided [[tableSchema]] + * + * @param spark spark session + * @param exprString string representation of the expression to parse and resolve + * @param tableSchema table schema encompassing attributes to resolve against + * @return Resolved filter expression + */ + def resolveExpr(spark: SparkSession, exprString: String, tableSchema: StructType): Expression = { + val expr = spark.sessionState.sqlParser.parseExpression(exprString) + resolveExpr(spark, expr, tableSchema) + } + + /** + * Resolves provided expression (unless already resolved) against the attributes of the given table schema. + * + * For example: + *
    +   * ts > 1000 and ts <= 1500
    +   * 
    + * will be resolved as + *
    +   * And(GreaterThan(ts#590L > 1000), LessThanOrEqual(ts#590L <= 1500))
    +   * 
    + * + * Where
    ts
    is a column of the provided [[tableSchema]] + * + * @param spark spark session + * @param expr Catalyst expression to be resolved (if not yet) + * @param tableSchema table schema encompassing attributes to resolve against + * @return Resolved filter expression + */ + def resolveExpr(spark: SparkSession, expr: Expression, tableSchema: StructType): Expression = { + val analyzer = spark.sessionState.analyzer + val schemaFields = tableSchema.fields + + val resolvedExpr = { + val plan: LogicalPlan = Filter(expr, LocalRelation(schemaFields.head, schemaFields.drop(1): _*)) + analyzer.execute(plan).asInstanceOf[Filter].condition + } + + if (!hasUnresolvedRefs(resolvedExpr)) { + resolvedExpr + } else { + throw new IllegalStateException("unresolved attribute") + } + } + + /** + * Split the given predicates into two sequence predicates: + * - predicates that references partition columns only(and involves no sub-query); + * - other predicates. + * + * @param sparkSession The spark session + * @param predicates The predicates to be split + * @param partitionColumns The partition columns + * @return (partitionFilters, dataFilters) + */ + def splitPartitionAndDataPredicates(sparkSession: SparkSession, + predicates: Array[Expression], + partitionColumns: Array[String]): (Array[Expression], Array[Expression]) = { + // Validates that the provided names both resolve to the same entity + val resolvedNameEquals = sparkSession.sessionState.analyzer.resolver + + predicates.partition(expr => { + // Checks whether given expression only references partition columns(and involves no sub-query) + expr.references.forall(r => partitionColumns.exists(resolvedNameEquals(r.name, _))) && + !SubqueryExpression.hasSubquery(expr) + }) + } + + /** + * Matches an expression iff + * + *
      + *
    1. It references exactly one [[AttributeReference]]
    2. + *
    3. It contains only whitelisted transformations that preserve ordering of the source column [1]
    4. + *
    + * + * [1] Preserving ordering is defined as following: transformation T is defined as ordering preserving in case + * values of the source column A values being ordered as a1, a2, a3 ..., will map into column B = T(A) which + * will keep the same ordering b1, b2, b3, ... with b1 = T(a1), b2 = T(a2), ... + */ + def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] + + private def hasUnresolvedRefs(resolvedExpr: Expression): Boolean = + resolvedExpr.collectFirst { + case _: UnresolvedAttribute | _: UnresolvedFunction => true + }.isDefined +} diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala index e41a9c1c8..354ed0ef2 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala @@ -32,16 +32,21 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.execution.datasources.{FilePartition, LogicalRelation, PartitionedFile, SparkParsePartitionUtil} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{HoodieCatalystExpressionUtils, Row, SparkSession} import java.util.Locale /** - * An interface to adapter the difference between spark2 and spark3 - * in some spark related class. 
+ * Interface adapting discrepancies and incompatibilities between different Spark versions */ trait SparkAdapter extends Serializable { + /** + * Creates instance of [[HoodieCatalystExpressionUtils]] providing for common utils operating + * on Catalyst Expressions + */ + def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils + /** * Creates instance of [[HoodieAvroSerializer]] providing for ability to serialize * Spark's [[InternalRow]] into Avro payloads diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala index 11778da63..98fc21887 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala @@ -20,7 +20,7 @@ package org.apache.hudi import org.apache.avro.Schema import org.apache.avro.generic.GenericRecord import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path, PathFilter} +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hbase.io.hfile.CacheConfig import org.apache.hadoop.mapred.JobConf import org.apache.hudi.HoodieBaseRelation.{getPartitionPath, isMetadataTable} @@ -32,7 +32,6 @@ import org.apache.hudi.common.table.timeline.{HoodieInstant, HoodieTimeline} import org.apache.hudi.common.table.view.HoodieTableFileSystemView import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient, TableSchemaResolver} import org.apache.hudi.common.util.StringUtils -import org.apache.hudi.hadoop.HoodieROTablePathFilter import org.apache.hudi.io.storage.HoodieHFileReader import org.apache.hudi.metadata.{HoodieMetadataPayload, HoodieTableMetadata} import org.apache.spark.execution.datasources.HoodieInMemoryFileIndex @@ -41,7 +40,7 @@ import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.avro.SchemaConverters import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} -import org.apache.spark.sql.execution.datasources.{FileStatusCache, PartitionDirectory, PartitionedFile} +import org.apache.spark.sql.execution.datasources.{FileStatusCache, PartitionedFile} import org.apache.spark.sql.hudi.HoodieSqlCommonUtils import org.apache.spark.sql.sources.{BaseRelation, Filter, PrunedFilteredScan} import org.apache.spark.sql.types.StructType diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala index de863203d..ea3b2f061 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala @@ -32,11 +32,10 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{And, Expression, Literal} import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusCache, NoopCache, PartitionDirectory} import org.apache.spark.sql.functions.col -import org.apache.spark.sql.hudi.DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr -import org.apache.spark.sql.hudi.HoodieSqlCommonUtils +import org.apache.spark.sql.hudi.{DataSkippingUtils, HoodieSqlCommonUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{AnalysisException, Column, SparkSession} +import org.apache.spark.sql.{Column, SparkSession} import org.apache.spark.unsafe.types.UTF8String import java.text.SimpleDateFormat @@ -244,21 +243,21 @@ case class HoodieFileIndex(spark: SparkSession, // column references from the filtering expressions, and only transpose records corresponding to the // columns referenced 
in those val transposedColStatsDF = - queryReferencedColumns.map(colName => - colStatsDF.filter(col(HoodieMetadataPayload.COLUMN_STATS_FIELD_COLUMN_NAME).equalTo(colName)) - .select(targetColStatsIndexColumns.map(col): _*) - .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_NULL_COUNT, getNumNullsColumnNameFor(colName)) - .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_MIN_VALUE, getMinColumnNameFor(colName)) - .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_MAX_VALUE, getMaxColumnNameFor(colName)) - ) - .reduceLeft((left, right) => - left.join(right, usingColumn = HoodieMetadataPayload.COLUMN_STATS_FIELD_FILE_NAME)) + queryReferencedColumns.map(colName => + colStatsDF.filter(col(HoodieMetadataPayload.COLUMN_STATS_FIELD_COLUMN_NAME).equalTo(colName)) + .select(targetColStatsIndexColumns.map(col): _*) + .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_NULL_COUNT, getNumNullsColumnNameFor(colName)) + .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_MIN_VALUE, getMinColumnNameFor(colName)) + .withColumnRenamed(HoodieMetadataPayload.COLUMN_STATS_FIELD_MAX_VALUE, getMaxColumnNameFor(colName)) + ) + .reduceLeft((left, right) => + left.join(right, usingColumn = HoodieMetadataPayload.COLUMN_STATS_FIELD_FILE_NAME)) // Persist DF to avoid re-computing column statistics unraveling withPersistence(transposedColStatsDF) { val indexSchema = transposedColStatsDF.schema val indexFilter = - queryFilters.map(translateIntoColumnStatsIndexFilterExpr(_, indexSchema)) + queryFilters.map(DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(_, indexSchema)) .reduce(And) val allIndexedFileNames = diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/HoodieSparkTypeUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/HoodieSparkTypeUtils.scala new file mode 100644 index 000000000..d5d95872a --- /dev/null +++ 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/HoodieSparkTypeUtils.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.sql.types.{DataType, NumericType, StringType} + +// TODO unify w/ DataTypeUtils +object HoodieSparkTypeUtils { + + /** + * Checks whether casting expression of [[from]] [[DataType]] to [[to]] [[DataType]] will + * preserve ordering of the elements + */ + def isCastPreservingOrdering(from: DataType, to: DataType): Boolean = + (from, to) match { + // NOTE: In the casting rules defined by Spark, only casting from String to Numeric + // (and vice versa) are the only casts that might break the ordering of the elements after casting + case (StringType, _: NumericType) => false + case (_: NumericType, StringType) => false + + case _ => true + } +} diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala deleted file mode 100644 index d640c0226..000000000 --- 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation} -import org.apache.spark.sql.types.StructType - -object HoodieCatalystExpressionUtils { - - /** - * Resolve filter expression from string expr with given table schema, for example: - *
    -   *   ts > 1000 and ts <= 1500
    -   * 
    - * will be resolved as - *
    -   *   And(GreaterThan(ts#590L > 1000), LessThanOrEqual(ts#590L <= 1500))
    -   * 
    - * - * @param spark The spark session - * @param exprString String to be resolved - * @param tableSchema The table schema - * @return Resolved filter expression - */ - def resolveFilterExpr(spark: SparkSession, exprString: String, tableSchema: StructType): Expression = { - val expr = spark.sessionState.sqlParser.parseExpression(exprString) - resolveFilterExpr(spark, expr, tableSchema) - } - - def resolveFilterExpr(spark: SparkSession, expr: Expression, tableSchema: StructType): Expression = { - val schemaFields = tableSchema.fields - val resolvedExpr = spark.sessionState.analyzer.ResolveReferences( - Filter(expr, - LocalRelation(schemaFields.head, schemaFields.drop(1): _*)) - ) - .asInstanceOf[Filter].condition - - checkForUnresolvedRefs(resolvedExpr) - } - - private def checkForUnresolvedRefs(resolvedExpr: Expression): Expression = - resolvedExpr match { - case UnresolvedAttribute(_) => throw new IllegalStateException("unresolved attribute") - case _ => resolvedExpr.mapChildren(e => checkForUnresolvedRefs(e)) - } - - /** - * Split the given predicates into two sequence predicates: - * - predicates that references partition columns only(and involves no sub-query); - * - other predicates. 
- * - * @param sparkSession The spark session - * @param predicates The predicates to be split - * @param partitionColumns The partition columns - * @return (partitionFilters, dataFilters) - */ - def splitPartitionAndDataPredicates(sparkSession: SparkSession, - predicates: Array[Expression], - partitionColumns: Array[String]): (Array[Expression], Array[Expression]) = { - // Validates that the provided names both resolve to the same entity - val resolvedNameEquals = sparkSession.sessionState.analyzer.resolver - - predicates.partition(expr => { - // Checks whether given expression only references partition columns(and involves no sub-query) - expr.references.forall(r => partitionColumns.exists(resolvedNameEquals(r.name, _))) && - !SubqueryExpression.hasSubquery(expr) - }) - } -} diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala index 06b92e204..b7ddd2828 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql.hudi +import org.apache.hudi.SparkAdapterSupport +import org.apache.hudi.common.util.ValidationUtils.checkState import org.apache.hudi.index.columnstats.ColumnStatsIndexHelper.{getMaxColumnNameFor, getMinColumnNameFor, getNumNullsColumnNameFor} import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, IsNotNull, 
IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression} import org.apache.spark.sql.functions.col +import org.apache.spark.sql.hudi.ColumnStatsExpressionUtils._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{AnalysisException, HoodieCatalystExpressionUtils} import org.apache.spark.unsafe.types.UTF8String object DataSkippingUtils extends Logging { @@ -59,147 +62,205 @@ object DataSkippingUtils extends Logging { } private def tryComposeIndexFilterExpr(sourceExpr: Expression, indexSchema: StructType): Option[Expression] = { - def minValue(colName: String) = col(getMinColumnNameFor(colName)).expr - def maxValue(colName: String) = col(getMaxColumnNameFor(colName)).expr - def numNulls(colName: String) = col(getNumNullsColumnNameFor(colName)).expr - - def colContainsValuesEqualToLiteral(colName: String, value: Literal): Expression = - // Only case when column C contains value V is when min(C) <= V <= max(c) - And(LessThanOrEqual(minValue(colName), value), GreaterThanOrEqual(maxValue(colName), value)) - - def colContainsOnlyValuesEqualToLiteral(colName: String, value: Literal) = - // Only case when column C contains _only_ value V is when min(C) = V AND max(c) = V - And(EqualTo(minValue(colName), value), EqualTo(maxValue(colName), value)) - + // + // For translation of the Filter Expression for the Data Table into Filter Expression for Column Stats Index, we're + // assuming that + // - The column A is queried in the Data Table (hereafter referred to as "colA") + // - Filter Expression is a relational expression (ie "=", "<", "<=", ...) 
of the following form + // + // ```transform_expr(colA) = value_expr``` + // + // Where + // - "transform_expr" is an expression of the _transformation_ which preserve ordering of the "colA" + // - "value_expr" is an "value"-expression (ie one NOT referring to other attributes/columns or containing sub-queries) + // + // We translate original Filter Expr into the one querying Column Stats Index like following: let's consider + // equality Filter Expr referred to above: + // + // ```transform_expr(colA) = value_expr``` + // + // This expression will be translated into following Filter Expression for the Column Stats Index: + // + // ```(transform_expr(colA_minValue) <= value_expr) AND (value_expr <= transform_expr(colA_maxValue))``` + // + // Which will enable us to match files with the range of values in column A containing the target ```value_expr``` + // + // NOTE: That we can apply ```transform_expr``` transformation precisely b/c it preserves the ordering of the + // values of the source column, ie following holds true: + // + // colA_minValue = min(colA) => transform_expr(colA_minValue) = min(transform_expr(colA)) + // colA_maxValue = max(colA) => transform_expr(colA_maxValue) = max(transform_expr(colA)) + // sourceExpr match { - // Filter "colA = b" - // Translates to "colA_minValue <= b AND colA_maxValue >= b" condition for index lookup - case EqualTo(attribute: AttributeReference, value: Literal) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => colContainsValuesEqualToLiteral(colName, value)) + // If Expression is not resolved, we can't perform the analysis accurately, bailing + case expr if !expr.resolved => None - // Filter "b = colA" - // Translates to "colA_minValue <= b AND colA_maxValue >= b" condition for index lookup - case EqualTo(value: Literal, attribute: AttributeReference) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => colContainsValuesEqualToLiteral(colName, value)) + // Filter "expr(colA) = B" and 
"B = expr(colA)" + // Translates to "(expr(colA_minValue) <= B) AND (B <= expr(colA_maxValue))" condition for index lookup + case EqualTo(sourceExpr @ AllowedTransformationExpression(attrRef), valueExpr: Expression) if isValueExpression(valueExpr) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + // NOTE: Since we're supporting (almost) arbitrary expressions of the form `f(colA) = B`, we have to + // appropriately translate such original expression targeted at Data Table, to corresponding + // expression targeted at Column Stats Index Table. For that, we take original expression holding + // [[AttributeReference]] referring to the Data Table, and swap it w/ expression referring to + // corresponding column in the Column Stats Index + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + genColumnValuesEqualToExpression(colName, valueExpr, targetExprBuilder) + } - // Filter "colA != b" - // Translates to "NOT(colA_minValue = b AND colA_maxValue = b)" - // NOTE: This is NOT an inversion of `colA = b` - case Not(EqualTo(attribute: AttributeReference, value: Literal)) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => Not(colContainsOnlyValuesEqualToLiteral(colName, value))) + case EqualTo(valueExpr: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(valueExpr) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + genColumnValuesEqualToExpression(colName, valueExpr, targetExprBuilder) + } - // Filter "b != colA" - // Translates to "NOT(colA_minValue = b AND colA_maxValue = b)" - // NOTE: This is NOT an inversion of `colA = b` - case Not(EqualTo(value: Literal, attribute: AttributeReference)) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => Not(colContainsOnlyValuesEqualToLiteral(colName, value))) + // 
Filter "expr(colA) != B" and "B != expr(colA)" + // Translates to "NOT(expr(colA_minValue) = B AND expr(colA_maxValue) = B)" + // NOTE: This is NOT an inversion of `colA = b`, instead this filter ONLY excludes files for which `colA = B` + // holds true + case Not(EqualTo(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression)) if isValueExpression(value) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + Not(genColumnOnlyValuesEqualToExpression(colName, value, targetExprBuilder)) + } + + case Not(EqualTo(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef))) if isValueExpression(value) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + Not(genColumnOnlyValuesEqualToExpression(colName, value, targetExprBuilder)) + } // Filter "colA = null" // Translates to "colA_num_nulls = null" for index lookup - case equalNullSafe @ EqualNullSafe(_: AttributeReference, _ @ Literal(null, _)) => - getTargetIndexedColName(equalNullSafe.left, indexSchema) - .map(colName => EqualTo(numNulls(colName), equalNullSafe.right)) + case EqualNullSafe(attrRef: AttributeReference, litNull @ Literal(null, _)) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map(colName => EqualTo(genColNumNullsExpr(colName), litNull)) - // Filter "colA < b" - // Translates to "colA_minValue < b" for index lookup - case LessThan(attribute: AttributeReference, value: Literal) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => LessThan(minValue(colName), value)) + // Filter "expr(colA) < B" and "B > expr(colA)" + // Translates to "expr(colA_minValue) < B" for index lookup + case LessThan(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression) if isValueExpression(value) => + 
getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + LessThan(targetExprBuilder.apply(genColMinValueExpr(colName)), value) + } - // Filter "b > colA" - // Translates to "b > colA_minValue" for index lookup - case GreaterThan(value: Literal, attribute: AttributeReference) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => LessThan(minValue(colName), value)) + case GreaterThan(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(value) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + LessThan(targetExprBuilder.apply(genColMinValueExpr(colName)), value) + } - // Filter "b < colA" - // Translates to "b < colA_maxValue" for index lookup - case LessThan(value: Literal, attribute: AttributeReference) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => GreaterThan(maxValue(colName), value)) + // Filter "B < expr(colA)" and "expr(colA) > B" + // Translates to "B < colA_maxValue" for index lookup + case LessThan(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(value) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + GreaterThan(targetExprBuilder.apply(genColMaxValueExpr(colName)), value) + } - // Filter "colA > b" - // Translates to "colA_maxValue > b" for index lookup - case GreaterThan(attribute: AttributeReference, value: Literal) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => GreaterThan(maxValue(colName), value)) + case GreaterThan(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression) if isValueExpression(value) => + 
getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + GreaterThan(targetExprBuilder.apply(genColMaxValueExpr(colName)), value) + } - // Filter "colA <= b" - // Translates to "colA_minValue <= b" for index lookup - case LessThanOrEqual(attribute: AttributeReference, value: Literal) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => LessThanOrEqual(minValue(colName), value)) + // Filter "expr(colA) <= B" and "B >= expr(colA)" + // Translates to "colA_minValue <= B" for index lookup + case LessThanOrEqual(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression) if isValueExpression(value) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + LessThanOrEqual(targetExprBuilder.apply(genColMinValueExpr(colName)), value) + } - // Filter "b >= colA" - // Translates to "b >= colA_minValue" for index lookup - case GreaterThanOrEqual(value: Literal, attribute: AttributeReference) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => LessThanOrEqual(minValue(colName), value)) + case GreaterThanOrEqual(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(value) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + LessThanOrEqual(targetExprBuilder.apply(genColMinValueExpr(colName)), value) + } - // Filter "b <= colA" - // Translates to "b <= colA_maxValue" for index lookup - case LessThanOrEqual(value: Literal, attribute: AttributeReference) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => GreaterThanOrEqual(maxValue(colName), value)) + // Filter "B <= expr(colA)" and "expr(colA) >= B" + // Translates to "B <= 
colA_maxValue" for index lookup + case LessThanOrEqual(value: Expression, sourceExpr @ AllowedTransformationExpression(attrRef)) if isValueExpression(value) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + GreaterThanOrEqual(targetExprBuilder.apply(genColMaxValueExpr(colName)), value) + } - // Filter "colA >= b" - // Translates to "colA_maxValue >= b" for index lookup - case GreaterThanOrEqual(attribute: AttributeReference, right: Literal) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => GreaterThanOrEqual(maxValue(colName), right)) + case GreaterThanOrEqual(sourceExpr @ AllowedTransformationExpression(attrRef), value: Expression) if isValueExpression(value) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + GreaterThanOrEqual(targetExprBuilder.apply(genColMaxValueExpr(colName)), value) + } // Filter "colA is null" // Translates to "colA_num_nulls > 0" for index lookup case IsNull(attribute: AttributeReference) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => GreaterThan(numNulls(colName), Literal(0))) + getTargetIndexedColumnName(attribute, indexSchema) + .map(colName => GreaterThan(genColNumNullsExpr(colName), Literal(0))) // Filter "colA is not null" // Translates to "colA_num_nulls = 0" for index lookup case IsNotNull(attribute: AttributeReference) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => EqualTo(numNulls(colName), Literal(0))) + getTargetIndexedColumnName(attribute, indexSchema) + .map(colName => EqualTo(genColNumNullsExpr(colName), Literal(0))) - // Filter "colA in (a, b, ...)" - // Translates to "(colA_minValue <= a AND colA_maxValue >= a) OR (colA_minValue <= b AND colA_maxValue >= b)" for index lookup - // NOTE: This is equivalent to 
"colA = a OR colA = b OR ..." - case In(attribute: AttributeReference, list: Seq[Literal]) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => - list.map { lit => colContainsValuesEqualToLiteral(colName, lit) }.reduce(Or) - ) + // Filter "expr(colA) in (B1, B2, ...)" + // Translates to "(colA_minValue <= B1 AND colA_maxValue >= B1) OR (colA_minValue <= B2 AND colA_maxValue >= B2) ... " + // for index lookup + // NOTE: This is equivalent to "colA = B1 OR colA = B2 OR ..." + case In(sourceExpr @ AllowedTransformationExpression(attrRef), list: Seq[Expression]) if list.forall(isValueExpression) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + list.map(lit => genColumnValuesEqualToExpression(colName, lit, targetExprBuilder)).reduce(Or) + } - // Filter "colA not in (a, b, ...)" - // Translates to "NOT((colA_minValue = a AND colA_maxValue = a) OR (colA_minValue = b AND colA_maxValue = b))" for index lookup - // NOTE: This is NOT an inversion of `in (a, b, ...)` expr, this is equivalent to "colA != a AND colA != b AND ..." - case Not(In(attribute: AttributeReference, list: Seq[Literal])) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => - Not( - list.map { lit => colContainsOnlyValuesEqualToLiteral(colName, lit) }.reduce(Or) - ) - ) + // Filter "expr(colA) not in (B1, B2, ...)" + // Translates to "NOT((colA_minValue = B1 AND colA_maxValue = B1) OR (colA_minValue = B2 AND colA_maxValue = B2))" for index lookup + // NOTE: This is NOT an inversion of `in (B1, B2, ...)` expr, this is equivalent to "colA != B1 AND colA != B2 AND ..." 
+ case Not(In(sourceExpr @ AllowedTransformationExpression(attrRef), list: Seq[Expression])) if list.forall(_.foldable) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + Not(list.map(lit => genColumnOnlyValuesEqualToExpression(colName, lit, targetExprBuilder)).reduce(Or)) + } // Filter "colA like 'xxx%'" - // Translates to "colA_minValue <= xxx AND colA_maxValue >= xxx" for index lookup - // NOTE: That this operator only matches string prefixes, and this is - // essentially equivalent to "colA = b" expression - case StartsWith(attribute, v @ Literal(_: UTF8String, _)) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => colContainsValuesEqualToLiteral(colName, v)) + // Translates to "colA_minValue <= xxx AND xxx <= colA_maxValue" for index lookup + // + // NOTE: Since a) this operator matches strings by prefix and b) given that this column is going to be ordered + // lexicographically, we essentially need to check that provided literal falls w/in min/max bounds of the + // given column + case StartsWith(sourceExpr @ AllowedTransformationExpression(attrRef), v @ Literal(_: UTF8String, _)) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + genColumnValuesEqualToExpression(colName, v, targetExprBuilder) + } - // Filter "colA not like 'xxx%'" - // Translates to "NOT(colA_minValue like 'xxx%' AND colA_maxValue like 'xxx%')" for index lookup + // Filter "expr(colA) not like 'xxx%'" + // Translates to "NOT(expr(colA_minValue) like 'xxx%' AND expr(colA_maxValue) like 'xxx%')" for index lookup // NOTE: This is NOT an inversion of "colA like xxx" - case Not(StartsWith(attribute, value @ Literal(_: UTF8String, _))) => - getTargetIndexedColName(attribute, indexSchema) - .map(colName => - 
Not(And(StartsWith(minValue(colName), value), StartsWith(maxValue(colName), value))) - ) + case Not(StartsWith(sourceExpr @ AllowedTransformationExpression(attrRef), value @ Literal(_: UTF8String, _))) => + getTargetIndexedColumnName(attrRef, indexSchema) + .map { colName => + val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _) + val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName)) + val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName)) + Not(And(StartsWith(minValueExpr, value), StartsWith(maxValueExpr, value))) + } case or: Or => val resLeft = createColumnStatsIndexFilterExprInternal(or.left, indexSchema) @@ -238,7 +299,7 @@ object DataSkippingUtils extends Logging { .forall(stat => indexSchema.exists(_.name == stat)) } - private def getTargetIndexedColName(resolvedExpr: Expression, indexSchema: StructType): Option[String] = { + private def getTargetIndexedColumnName(resolvedExpr: AttributeReference, indexSchema: StructType): Option[String] = { val colName = UnresolvedAttribute(getTargetColNameParts(resolvedExpr)).name // Verify that the column is indexed @@ -261,3 +322,91 @@ object DataSkippingUtils extends Logging { } } } + +private object ColumnStatsExpressionUtils { + + def genColMinValueExpr(colName: String): Expression = + col(getMinColumnNameFor(colName)).expr + def genColMaxValueExpr(colName: String): Expression = + col(getMaxColumnNameFor(colName)).expr + def genColNumNullsExpr(colName: String): Expression = + col(getNumNullsColumnNameFor(colName)).expr + + def genColumnValuesEqualToExpression(colName: String, + value: Expression, + targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = { + // TODO clean up + checkState(isValueExpression(value)) + + val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName)) + val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName)) + // Only case when column C contains value V is when 
min(C) <= V <= max(C) + And(LessThanOrEqual(minValueExpr, value), GreaterThanOrEqual(maxValueExpr, value)) + } + + def genColumnOnlyValuesEqualToExpression(colName: String, + value: Expression, + targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = { + // TODO clean up + checkState(isValueExpression(value)) + + val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName)) + val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName)) + // Only case when column C contains _only_ value V is when min(C) = V AND max(C) = V + And(EqualTo(minValueExpr, value), EqualTo(maxValueExpr, value)) + } + + def swapAttributeRefInExpr(sourceExpr: Expression, from: AttributeReference, to: Expression): Expression = { + checkState(sourceExpr.references.size == 1) + sourceExpr.transformDown { + case attrRef: AttributeReference if attrRef.sameRef(from) => to + } + } + + /** + * This check is used to validate that the expression that the target column is compared against + *
    +   *    a) Has no references to other attributes (for example, columns)
    +   *    b) Does not contain sub-queries
    +   * 
    + * + * This in turn allows us to be certain that Spark will be able to evaluate such expression + * against Column Stats Index as well + */ + def isValueExpression(expr: Expression): Boolean = + expr.references.isEmpty && !SubqueryExpression.hasSubquery(expr) + + /** + * This utility pattern-matches an expression iff + * + *
      + *
    1. It references *exactly* 1 attribute (column)
    2. + *
    3. It does NOT contain sub-queries
    4. + *
    5. It contains only whitelisted transformations that preserve ordering of the source column [1]
    6. + *
    + * + * [1] This is required to make sure that we can correspondingly map Column Stats Index values as well. Applying + * transformations that do not preserve the ordering might lead to incorrect results being returned by Data + * Skipping flow. + * + * Returns only [[AttributeReference]] contained as a sub-expression + */ + object AllowedTransformationExpression extends SparkAdapterSupport { + val exprUtils: HoodieCatalystExpressionUtils = sparkAdapter.createCatalystExpressionUtils() + + def unapply(expr: Expression): Option[AttributeReference] = { + // First step, we check that expression + // - Does NOT contain sub-queries + // - Does contain exactly 1 attribute + if (SubqueryExpression.hasSubquery(expr) || expr.references.size != 1) { + None + } else { + // Second step, we validate that holding expression is an actually permitted + // transformation + // NOTE: That transformation composition is permitted + exprUtils.tryMatchAttributeOrderingPreservingTransformation(expr) + } + } + } +} + diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RunClusteringProcedure.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RunClusteringProcedure.scala index 442ee0441..231d0939c 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RunClusteringProcedure.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RunClusteringProcedure.scala @@ -24,9 +24,9 @@ import org.apache.hudi.common.util.ValidationUtils.checkArgument import org.apache.hudi.common.util.{ClusteringUtils, Option => HOption} import org.apache.hudi.config.HoodieClusteringConfig import org.apache.hudi.exception.HoodieClusteringException -import org.apache.hudi.{AvroConversionUtils, HoodieCLIUtils, HoodieFileIndex} +import org.apache.hudi.{AvroConversionUtils, HoodieCLIUtils, HoodieFileIndex, 
SparkAdapterSupport} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{HoodieCatalystExpressionUtils, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.PredicateHelper import org.apache.spark.sql.execution.datasources.FileStatusCache import org.apache.spark.sql.types._ @@ -34,7 +34,14 @@ import org.apache.spark.sql.types._ import java.util.function.Supplier import scala.collection.JavaConverters._ -class RunClusteringProcedure extends BaseProcedure with ProcedureBuilder with PredicateHelper with Logging { +class RunClusteringProcedure extends BaseProcedure + with ProcedureBuilder + with PredicateHelper + with Logging + with SparkAdapterSupport { + + private val exprUtils = sparkAdapter.createCatalystExpressionUtils() + /** * OPTIMIZE table_name|table_path [WHERE predicate] * [ORDER BY (col_name1 [, ...] ) ] @@ -120,9 +127,9 @@ class RunClusteringProcedure extends BaseProcedure with ProcedureBuilder with Pr // Resolve partition predicates val schemaResolver = new TableSchemaResolver(metaClient) val tableSchema = AvroConversionUtils.convertAvroSchemaToStructType(schemaResolver.getTableAvroSchema) - val condition = HoodieCatalystExpressionUtils.resolveFilterExpr(sparkSession, predicate, tableSchema) + val condition = exprUtils.resolveExpr(sparkSession, predicate, tableSchema) val partitionColumns = metaClient.getTableConfig.getPartitionFields.orElse(Array[String]()) - val (partitionPredicates, dataPredicates) = HoodieCatalystExpressionUtils.splitPartitionAndDataPredicates( + val (partitionPredicates, dataPredicates) = exprUtils.splitPartitionAndDataPredicates( sparkSession, splitConjunctivePredicates(condition).toArray, partitionColumns) checkArgument(dataPredicates.isEmpty, "Only partition predicates are allowed") diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala 
index 6b96472d4..43f070e6c 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala @@ -20,30 +20,44 @@ package org.apache.hudi import org.apache.hudi.index.columnstats.ColumnStatsIndexHelper import org.apache.hudi.testutils.HoodieClientTestBase import org.apache.spark.sql.catalyst.expressions.{Expression, Not} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lower} import org.apache.spark.sql.hudi.DataSkippingUtils -import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType, VarcharType} -import org.apache.spark.sql.{Column, HoodieCatalystExpressionUtils, SparkSession} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, HoodieCatalystExpressionUtils, Row, SparkSession} import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments.arguments import org.junit.jupiter.params.provider.{Arguments, MethodSource} +import java.sql.Timestamp import scala.collection.JavaConverters._ // NOTE: Only A, B columns are indexed case class IndexRow( file: String, - A_minValue: Long, - A_maxValue: Long, - A_num_nulls: Long, + + // Corresponding A column is LongType + A_minValue: Long = -1, + A_maxValue: Long = -1, + A_num_nulls: Long = -1, + + // Corresponding B column is StringType B_minValue: String = null, B_maxValue: String = null, - B_num_nulls: Long = -1 -) + B_num_nulls: Long = -1, -class TestDataSkippingUtils extends HoodieClientTestBase { + // Corresponding C column is TimestampType + C_minValue: Timestamp = null, + C_maxValue: Timestamp = null, + C_num_nulls: Long = -1 +) { + def toRow: Row = Row(productIterator.toSeq: _*) +} + +class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSupport { + + val 
exprUtils: HoodieCatalystExpressionUtils = sparkAdapter.createCatalystExpressionUtils() var spark: SparkSession = _ @@ -53,17 +67,18 @@ class TestDataSkippingUtils extends HoodieClientTestBase { spark = sqlContext.sparkSession } - val indexedCols = Seq("A", "B") - val sourceTableSchema = + val indexedCols: Seq[String] = Seq("A", "B", "C") + val sourceTableSchema: StructType = StructType( Seq( StructField("A", LongType), StructField("B", StringType), - StructField("C", VarcharType(32)) + StructField("C", TimestampType), + StructField("D", VarcharType(32)) ) ) - val indexSchema = + val indexSchema: StructType = ColumnStatsIndexHelper.composeIndexSchema( sourceTableSchema.fields.toSeq .filter(f => indexedCols.contains(f.name)) @@ -71,15 +86,17 @@ class TestDataSkippingUtils extends HoodieClientTestBase { ) @ParameterizedTest - @MethodSource(Array("testBaseLookupFilterExpressionsSource", "testAdvancedLookupFilterExpressionsSource")) + @MethodSource( + Array( + "testBasicLookupFilterExpressionsSource", + "testAdvancedLookupFilterExpressionsSource", + "testCompositeFilterExpressionsSource" + )) def testLookupFilterExpressions(sourceExpr: String, input: Seq[IndexRow], output: Seq[String]): Unit = { - val resolvedExpr: Expression = HoodieCatalystExpressionUtils.resolveFilterExpr(spark, sourceExpr, sourceTableSchema) + val resolvedExpr: Expression = exprUtils.resolveExpr(spark, sourceExpr, sourceTableSchema) val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema) - val spark2 = spark - import spark2.implicits._ - - val indexDf = spark.createDataset(input) + val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema) val rows = indexDf.where(new Column(lookupFilter)) .select("file") @@ -93,7 +110,7 @@ class TestDataSkippingUtils extends HoodieClientTestBase { @ParameterizedTest @MethodSource(Array("testStringsLookupFilterExpressionsSource")) def testStringsLookupFilterExpressions(sourceExpr: Expression, input: 
Seq[IndexRow], output: Seq[String]): Unit = { - val resolvedExpr = HoodieCatalystExpressionUtils.resolveFilterExpr(spark, sourceExpr, sourceTableSchema) + val resolvedExpr = exprUtils.resolveExpr(spark, sourceExpr, sourceTableSchema) val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema) val spark2 = spark @@ -130,11 +147,21 @@ object TestDataSkippingUtils { IndexRow("file_3", 0, 0, 0, "aaa", "aba", 0), IndexRow("file_4", 0, 0, 0, "abc123", "abc345", 0) // all strings start w/ "abc" ), + Seq("file_1", "file_2", "file_3")), + arguments( + // Composite expression + Not(lower(col("B")).startsWith("abc").expr), + Seq( + IndexRow("file_1", 0, 0, 0, "ABA", "ADF", 1), // may contain strings starting w/ "ABC" (after upper) + IndexRow("file_2", 0, 0, 0, "ADF", "AZY", 0), + IndexRow("file_3", 0, 0, 0, "AAA", "ABA", 0), + IndexRow("file_4", 0, 0, 0, "ABC123", "ABC345", 0) // all strings start w/ "ABC" (after upper) + ), Seq("file_1", "file_2", "file_3")) ) } - def testBaseLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = { + def testBasicLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = { java.util.stream.Stream.of( // TODO cases // A = null @@ -263,6 +290,23 @@ object TestDataSkippingUtils { IndexRow("file_4", 0, 0, 0), // only contains 0 IndexRow("file_5", 1, 1, 0) // only contains 1 ), + Seq("file_1", "file_2", "file_3")), + arguments( + // Value expression containing expression, which isn't a literal + "A = int('0')", + Seq( + IndexRow("file_1", 1, 2, 0), + IndexRow("file_2", -1, 1, 0) + ), + Seq("file_2")), + arguments( + // Value expression containing reference to the other attribute (column), fallback + "A = D", + Seq( + IndexRow("file_1", 1, 2, 0), + IndexRow("file_2", -1, 1, 0), + IndexRow("file_3", -2, -1, 0) + ), Seq("file_1", "file_2", "file_3")) ) } @@ -316,8 +360,8 @@ object TestDataSkippingUtils { Seq("file_1", "file_2", "file_3")), arguments( - // Queries contains 
expression involving non-indexed column C - "A = 0 AND B = 'abc' AND C = '...'", + // Queries contains expression involving non-indexed column D + "A = 0 AND B = 'abc' AND D IS NULL", Seq( IndexRow("file_1", 1, 2, 0), IndexRow("file_2", -1, 1, 0), @@ -327,8 +371,8 @@ object TestDataSkippingUtils { Seq("file_4")), arguments( - // Queries contains expression involving non-indexed column C - "A = 0 OR B = 'abc' OR C = '...'", + // Queries contains expression involving non-indexed column D + "A = 0 OR B = 'abc' OR D IS NULL", Seq( IndexRow("file_1", 1, 2, 0), IndexRow("file_2", -1, 1, 0), @@ -338,4 +382,206 @@ object TestDataSkippingUtils { Seq("file_1", "file_2", "file_3", "file_4")) ) } + + def testCompositeFilterExpressionsSource(): java.util.stream.Stream[Arguments] = { + java.util.stream.Stream.of( + arguments( + "date_format(C, 'MM/dd/yyyy') = '03/06/2022'", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_2")), + arguments( + "'03/06/2022' = date_format(C, 'MM/dd/yyyy')", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_2")), + arguments( + "'03/06/2022' != date_format(C, 'MM/dd/yyyy')", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646625048000L), // 
03/06/2022 + C_num_nulls = 0) + ), + Seq("file_1")), + arguments( + "date_format(C, 'MM/dd/yyyy') != '03/06/2022'", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646625048000L), // 03/06/2022 + C_num_nulls = 0) + ), + Seq("file_1")), + arguments( + "date_format(C, 'MM/dd/yyyy') < '03/07/2022'", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_2")), + arguments( + "'03/07/2022' > date_format(C, 'MM/dd/yyyy')", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_2")), + arguments( + "'03/07/2022' < date_format(C, 'MM/dd/yyyy')", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_1")), + arguments( + "date_format(C, 'MM/dd/yyyy') > '03/07/2022'", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new 
Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_1")), + arguments( + "date_format(C, 'MM/dd/yyyy') <= '03/06/2022'", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_2")), + arguments( + "'03/06/2022' >= date_format(C, 'MM/dd/yyyy')", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_2")), + arguments( + "'03/08/2022' <= date_format(C, 'MM/dd/yyyy')", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_1")), + arguments( + "date_format(C, 'MM/dd/yyyy') >= '03/08/2022'", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_1")), + arguments( + "date_format(C, 'MM/dd/yyyy') IN ('03/08/2022')", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), 
// 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646711448000L), // 03/07/2022 + C_num_nulls = 0) + ), + Seq("file_1")), + arguments( + "date_format(C, 'MM/dd/yyyy') NOT IN ('03/06/2022')", + Seq( + IndexRow("file_1", + C_minValue = new Timestamp(1646711448000L), // 03/07/2022 + C_maxValue = new Timestamp(1646797848000L), // 03/08/2022 + C_num_nulls = 0), + IndexRow("file_2", + C_minValue = new Timestamp(1646625048000L), // 03/06/2022 + C_maxValue = new Timestamp(1646625048000L), // 03/06/2022 + C_num_nulls = 0) + ), + Seq("file_1")), + arguments( + // Should be identical to the one above + "date_format(to_timestamp(B, 'yyyy-MM-dd'), 'MM/dd/yyyy') NOT IN ('03/06/2022')", + Seq( + IndexRow("file_1", + B_minValue = "2022-03-07", // 03/07/2022 + B_maxValue = "2022-03-08", // 03/08/2022 + B_num_nulls = 0), + IndexRow("file_2", + B_minValue = "2022-03-06", // 03/06/2022 + B_maxValue = "2022-03-06", // 03/06/2022 + B_num_nulls = 0) + ), + Seq("file_1")) + + ) + } } diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala new file mode 100644 index 000000000..3e233352c --- /dev/null +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} + +object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils { + + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { + expr match { + case OrderPreservingTransformation(attrRef) => Some(attrRef) + case _ => None + } + } + + private object OrderPreservingTransformation { + def unapply(expr: Expression): Option[AttributeReference] = { + expr match { + // Date/Time Expressions + case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ParseToDate(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case 
ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // String Expressions + case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + + // Math Expressions + // Binary + case Add(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Add(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Multiply(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Multiply(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Divide(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + // Unary + case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log1p(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // Other + case cast @ Cast(OrderPreservingTransformation(attrRef), _, _) + if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef) + + // Identity transformation + case attrRef: AttributeReference => Some(attrRef) + // No match + case _ => None + } + } + } + +} diff --git 
a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala index 54c8b912a..42ad66598 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.hudi.SparkAdapter import org.apache.spark.sql.hudi.parser.HoodieSpark2ExtendedSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark2CatalystExpressionUtils, Row, SparkSession} import scala.collection.mutable.ArrayBuffer @@ -42,10 +42,12 @@ import scala.collection.mutable.ArrayBuffer */ class Spark2Adapter extends SparkAdapter { - def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer = + override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark2CatalystExpressionUtils + + override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer = new HoodieSparkAvroSerializer(rootCatalystType, rootAvroType, nullable) - def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer = + override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer = new HoodieSpark2AvroDeserializer(rootAvroType, rootCatalystType) override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = { diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala 
b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala similarity index 93% rename from hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala rename to hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala index ad338323e..33aae23df 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala @@ -39,14 +39,14 @@ import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SparkSession} /** - * The adapter for spark3. + * Base implementation of [[SparkAdapter]] for Spark 3.x branch */ -class Spark3Adapter extends SparkAdapter { +abstract class BaseSpark3Adapter extends SparkAdapter { - def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer = + override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer = new HoodieSparkAvroSerializer(rootCatalystType, rootAvroType, nullable) - def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer = + override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer = new HoodieSpark3AvroDeserializer(rootAvroType, rootCatalystType) override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = { diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark3_1CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark3_1CatalystExpressionUtils.scala new file mode 100644 index 000000000..cb9c31f08 --- /dev/null +++ 
b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark3_1CatalystExpressionUtils.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} + +object HoodieSpark3_1CatalystExpressionUtils extends HoodieCatalystExpressionUtils { + + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { + expr match { + case OrderPreservingTransformation(attrRef) => Some(attrRef) + case _ => None + } + } + + private object OrderPreservingTransformation { + def unapply(expr: Expression): Option[AttributeReference] = { + expr match { + // Date/Time Expressions + case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case 
DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ParseToDate(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef) + case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // String Expressions + case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + + // Math Expressions + // Binary + case Add(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case Add(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Multiply(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case Multiply(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Divide(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + // Unary + case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log1p(OrderPreservingTransformation(attrRef)) 
=> Some(attrRef) + case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // Other + case cast @ Cast(OrderPreservingTransformation(attrRef), _, _) + if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef) + + // Identity transformation + case attrRef: AttributeReference => Some(attrRef) + // No match + case _ => None + } + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala new file mode 100644 index 000000000..106939cbb --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.adapter + +import org.apache.spark.sql.hudi.SparkAdapter +import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark3_1CatalystExpressionUtils} + +/** + * Implementation of [[SparkAdapter]] for Spark 3.1.x + */ +class Spark3_1Adapter extends BaseSpark3Adapter { + + override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark3_1CatalystExpressionUtils + +} diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/HoodieSpark3_2CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/HoodieSpark3_2CatalystExpressionUtils.scala new file mode 100644 index 000000000..8e056c033 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/HoodieSpark3_2CatalystExpressionUtils.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} + +object HoodieSpark3_2CatalystExpressionUtils extends HoodieCatalystExpressionUtils { + + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { + expr match { + case OrderPreservingTransformation(attrRef) => Some(attrRef) + case _ => None + } + } + + private object OrderPreservingTransformation { + def unapply(expr: Expression): Option[AttributeReference] = { + expr match { + // Date/Time Expressions + case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ParseToDate(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef) + case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef) + case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // String Expressions + case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case 
org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + + // Math Expressions + // Binary + case Add(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case Add(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Multiply(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case Multiply(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Divide(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + // Unary + case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log1p(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // Other + case cast @ Cast(OrderPreservingTransformation(attrRef), _, _, _) + if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef) + + // Identity transformation + case attrRef: AttributeReference => Some(attrRef) + // No match + case _ => None + } + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala index 699623f8b..1256344c3 100644 --- a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala +++ 
b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala @@ -17,16 +17,19 @@ package org.apache.spark.sql.adapter -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieSpark3_2CatalystExpressionUtils, SparkSession} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.parser.HoodieSpark3_2ExtendedSqlParser /** - * The adapter for spark3.2. + * Implementation of [[SparkAdapter]] for Spark 3.2.x branch */ -class Spark3_2Adapter extends Spark3Adapter { +class Spark3_2Adapter extends BaseSpark3Adapter { + + override def createCatalystExpressionUtils(): HoodieCatalystExpressionUtils = HoodieSpark3_2CatalystExpressionUtils + /** * if the logical plan is a TimeTravelRelation LogicalPlan. */