[HUDI-1879] Support Partition Prune For MergeOnRead Snapshot Table (#2926)
This commit is contained in:
@@ -0,0 +1,165 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hudi
|
||||
|
||||
import org.apache.hudi.HoodieSparkUtils.convertToCatalystExpressions
|
||||
import org.apache.hudi.HoodieSparkUtils.convertToCatalystExpression
|
||||
import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
|
||||
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType}
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
/**
 * Unit tests for converting Spark data source [[Filter]]s into Catalyst expressions via
 * `convertToCatalystExpression` and `convertToCatalystExpressions`.
 *
 * Every check compares the SQL text of the converted expression with an expected string;
 * a `null` expectation means the conversion is expected to yield no expression.
 */
class TestConvertFilterToCatalystExpression {

  /** Schema passed to the converters together with the filters under test. */
  private lazy val tableSchema = {
    val fields = new ArrayBuffer[StructField]()
    fields.append(StructField("id", LongType, nullable = false))
    fields.append(StructField("name", StringType, nullable = true))
    fields.append(StructField("price", DoubleType, nullable = true))
    fields.append(StructField("ts", IntegerType, nullable = false))
    StructType(fields)
  }

  /** Each supported filter type converts to the expected SQL text. */
  @Test
  def testBaseConvert(): Unit = {
    checkConvertFilter(eq("id", 1), "(`id` = 1)")
    checkConvertFilter(eqs("name", "a1"), "(`name` <=> 'a1')")
    checkConvertFilter(lt("price", 10), "(`price` < 10)")
    checkConvertFilter(lte("ts", 1), "(`ts` <= 1)")
    checkConvertFilter(gt("price", 10), "(`price` > 10)")
    checkConvertFilter(gte("price", 10), "(`price` >= 10)")
    checkConvertFilter(in("id", 1, 2, 3), "(`id` IN (1, 2, 3))")
    checkConvertFilter(isNull("id"), "(`id` IS NULL)")
    checkConvertFilter(isNotNull("name"), "(`name` IS NOT NULL)")
    checkConvertFilter(and(lt("ts", 10), gt("ts", 1)),
      "((`ts` < 10) AND (`ts` > 1))")
    checkConvertFilter(or(lte("ts", 10), gte("ts", 1)),
      "((`ts` <= 10) OR (`ts` >= 1))")
    checkConvertFilter(not(and(lt("ts", 10), gt("ts", 1))),
      "(NOT ((`ts` < 10) AND (`ts` > 1)))")
    checkConvertFilter(startWith("name", "ab"), "`name` LIKE 'ab%'")
    checkConvertFilter(endWith("name", "cd"), "`name` LIKE '%cd'")
    checkConvertFilter(contains("name", "e"), "`name` LIKE '%e%'")
  }

  /**
   * An empty filter array is expected to yield no expression, and several filters are
   * expected to be combined with AND.
   */
  @Test
  def testConvertFilters(): Unit = {
    checkConvertFilters(Array.empty[Filter], null)
    checkConvertFilters(Array(eq("id", 1)), "(`id` = 1)")
    checkConvertFilters(Array(lt("ts", 10), gt("ts", 1)),
      "((`ts` < 10) AND (`ts` > 1))")
  }

  /**
   * A filter type unknown to the converter, on its own or nested under and/or/not,
   * is expected to make the whole conversion yield no expression.
   */
  @Test
  def testUnSupportConvert(): Unit = {
    checkConvertFilters(Array(unsupport()), null)
    checkConvertFilters(Array(and(unsupport(), eq("id", 1))), null)
    checkConvertFilters(Array(or(unsupport(), eq("id", 1))), null)
    checkConvertFilters(Array(and(eq("id", 1), not(unsupport()))), null)
  }

  /**
   * Converts a single filter and verifies the outcome.
   *
   * @param filter           the filter to convert
   * @param expectExpression expected SQL text, or `null` when no expression is expected
   */
  private def checkConvertFilter(filter: Filter, expectExpression: String): Unit = {
    val exp = convertToCatalystExpression(filter, tableSchema)
    assertConvertedSql(expectExpression, exp.map(_.sql))
  }

  /**
   * Converts an array of filters and verifies the outcome.
   *
   * @param filters          the filters to convert
   * @param expectExpression expected SQL text, or `null` when no expression is expected
   */
  private def checkConvertFilters(filters: Array[Filter], expectExpression: String): Unit = {
    val exp = convertToCatalystExpressions(filters, tableSchema)
    assertConvertedSql(expectExpression, exp.map(_.sql))
  }

  /**
   * Shared assertion for both check methods.
   *
   * @param expectExpression expected SQL text, or `null` when no expression is expected
   * @param actualSql        SQL text of the converted expression, if there is one
   */
  private def assertConvertedSql(expectExpression: String, actualSql: Option[String]): Unit = {
    // Option(null) is None, so a null expectation only matches an empty conversion result.
    // The expected value goes first, as JUnit requires, so failure messages read correctly.
    assertEquals(Option(expectExpression), actualSql)
  }

  // Short factory helpers that keep the test cases above readable.

  private def eq(attribute: String, value: Any): Filter = EqualTo(attribute, value)

  private def eqs(attribute: String, value: Any): Filter = EqualNullSafe(attribute, value)

  private def gt(attribute: String, value: Any): Filter = GreaterThan(attribute, value)

  private def gte(attribute: String, value: Any): Filter = GreaterThanOrEqual(attribute, value)

  private def lt(attribute: String, value: Any): Filter = LessThan(attribute, value)

  private def lte(attribute: String, value: Any): Filter = LessThanOrEqual(attribute, value)

  private def in(attribute: String, values: Any*): Filter = In(attribute, values.toArray)

  private def isNull(attribute: String): Filter = IsNull(attribute)

  private def isNotNull(attribute: String): Filter = IsNotNull(attribute)

  private def and(left: Filter, right: Filter): Filter = And(left, right)

  private def or(left: Filter, right: Filter): Filter = Or(left, right)

  private def not(child: Filter): Filter = Not(child)

  private def startWith(attribute: String, value: String): Filter =
    StringStartsWith(attribute, value)

  private def endWith(attribute: String, value: String): Filter =
    StringEndsWith(attribute, value)

  private def contains(attribute: String, value: String): Filter =
    StringContains(attribute, value)

  private def unsupport(): Filter = UnSupportFilter("")

  /**
   * Filter type defined only in this test; it stands in for a filter the converter
   * is expected not to support. It references no columns.
   */
  case class UnSupportFilter(value: Any) extends Filter {
    override def references: Array[String] = Array.empty
  }
}
|
||||
@@ -33,7 +33,7 @@ import org.apache.spark.sql.functions._
|
||||
import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
|
||||
import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
|
||||
import org.junit.jupiter.params.ParameterizedTest
|
||||
import org.junit.jupiter.params.provider.ValueSource
|
||||
import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
@@ -614,4 +614,67 @@ class TestMORDataSource extends HoodieClientTestBase {
|
||||
.load(basePath)
|
||||
assertEquals(N + 1, hoodieIncViewDF1.count())
|
||||
}
|
||||
|
||||
/**
 * Writes one batch of inserts spread over five date partitions into a MERGE_ON_READ
 * table, then reads the table back with different predicates on the `partition` column
 * and checks that each query returns exactly the number of records written to the
 * matching partitions.
 *
 * @param partitionEncode    value for the URL-encode-partitioning write option
 * @param hiveStylePartition value for the hive-style-partitioning write option
 */
@ParameterizedTest
@CsvSource(Array("true, false", "false, true", "false, false", "true, true"))
def testMORPartitionPrune(partitionEncode: Boolean, hiveStylePartition: Boolean): Unit = {
  val partitions = Array("2021/03/01", "2021/03/02", "2021/03/03", "2021/03/04", "2021/03/05")
  val newDataGen = new HoodieTestDataGenerator(partitions)
  val records1 = newDataGen.generateInsertsContainsAllPartitions("000", 100)
  val inputDF1 = spark.read.json(spark.sparkContext.parallelize(recordsToStrings(records1), 2))

  // Number of generated records per partition path; the expected results are derived from it.
  val partitionCounts = partitions
    .map(path => path -> records1.count(record => record.getPartitionPath == path))
    .toMap

  inputDF1.write.format("hudi")
    .options(commonOpts)
    .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
    .option(DataSourceWriteOptions.TABLE_TYPE_OPT_KEY, DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL)
    .option(DataSourceWriteOptions.URL_ENCODE_PARTITIONING_OPT_KEY, partitionEncode)
    .option(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING_OPT_KEY, hiveStylePartition)
    .mode(SaveMode.Overwrite)
    .save(basePath)

  // Loads the table afresh and counts the rows matching the given SQL predicate.
  def countWhere(condition: String): Long =
    spark.read.format("hudi")
      .load(basePath)
      .filter(condition)
      .count()

  // Equality on a single partition.
  val equalCount = countWhere("partition = '2021/03/01'")
  assertEquals(partitionCounts("2021/03/01"), equalCount)

  // Open range that only leaves the partition in the middle.
  val rangeCount = countWhere("partition > '2021/03/01' and partition < '2021/03/03'")
  assertEquals(partitionCounts("2021/03/02"), rangeCount)

  // Inequality: everything except the first partition.
  val notEqualCount = countWhere("partition != '2021/03/01'")
  assertEquals(records1.size() - partitionCounts("2021/03/01"), notEqualCount)

  // Prefix match.
  val prefixCount = countWhere("partition like '2021/03/03%'")
  assertEquals(partitionCounts("2021/03/03"), prefixCount)

  // Substring match that covers every partition.
  val containsCount = countWhere("partition like '%2021/03/%'")
  assertEquals(records1.size(), containsCount)

  // Disjunction of two partitions.
  val eitherCount = countWhere("partition = '2021/03/01' or partition = '2021/03/05'")
  assertEquals(partitionCounts("2021/03/01") + partitionCounts("2021/03/05"), eitherCount)

  // Predicate wrapping the partition column in a function call (day part equals '03').
  val substrCount = countWhere("substr(partition, 9, 10) = '03'")
  assertEquals(partitionCounts("2021/03/03"), substrCount)
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user