1
0

[HUDI-4309] fix spark32 repartition error (#6033)

This commit is contained in:
KnightChess
2022-07-08 09:38:09 +08:00
committed by GitHub
parent e74ad324c3
commit 5673819736
3 changed files with 66 additions and 10 deletions

View File

@@ -250,4 +250,48 @@ class TestTimeTravelTable extends HoodieSparkSqlTestBase {
}
}
}
// Regression test for HUDI-4309: on Spark 3.2+ a `DISTRIBUTE BY` clause in a
// query against a Hudi table raised a ParseException. Verifies repartitioning
// works both on the latest snapshot and combined with TIMESTAMP AS OF time travel.
test("Test Select Record with time travel and Repartition") {
// DISTRIBUTE BY support was added for the Spark 3.2 extended parser only,
// so the test is a no-op on older Spark versions.
if (HoodieSparkUtils.gteqSpark3_2) {
withTempDir { tmp =>
val tableName = generateTableName
// Copy-on-write table keyed by `id`, with `ts` as the pre-combine
// (de-duplication) field, stored under the temp directory.
spark.sql(
s"""
|create table $tableName (
| id int,
| name string,
| price double,
| ts long
|) using hudi
| tblproperties (
| type = 'cow',
| primaryKey = 'id',
| preCombineField = 'ts'
| )
| location '${tmp.getCanonicalPath}/$tableName'
""".stripMargin)
// First commit: the record we expect the time-travel query to see.
spark.sql(s"insert into $tableName values(1, 'a1', 10, 1000)")
val metaClient = HoodieTableMetaClient.builder()
.setBasePath(s"${tmp.getCanonicalPath}/$tableName")
.setConf(spark.sessionState.newHadoopConf())
.build()
// Capture the timestamp of the first commit before writing the second,
// so `TIMESTAMP AS OF` below targets the pre-update state.
val instant = metaClient.getActiveTimeline.getAllCommitsTimeline
.lastInstant().get().getTimestamp
// Second commit: same key, so the snapshot read must return 'a2'.
spark.sql(s"insert into $tableName values(1, 'a2', 20, 2000)")
// Snapshot read with DISTRIBUTE BY must parse and return the latest record.
checkAnswer(s"select id, name, price, ts from $tableName distribute by cast(rand() * 2 as int)")(
Seq(1, "a2", 20.0, 2000)
)
// time travel from instant
// Time-travel read with DISTRIBUTE BY must return the first commit's record.
checkAnswer(
s"select id, name, price, ts from $tableName TIMESTAMP AS OF '$instant' distribute by cast(rand() * 2 as int)")(
Seq(1, "a1", 10.0, 1000)
)
}
}
}
}

View File

@@ -623,7 +623,7 @@ class HoodieSpark3_2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterfa
ctx: QueryOrganizationContext,
expressions: Seq[Expression],
query: LogicalPlan): LogicalPlan = {
throw new ParseException("DISTRIBUTE BY is not supported", ctx)
RepartitionByExpression(expressions, query, None)
}
override def visitTransformQuerySpecification(

View File

@@ -29,26 +29,30 @@ import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException,
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.internal.VariableSubstitution
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, SparkSession}
import scala.util.control.NonFatal
import java.util.Locale
class HoodieSpark3_2ExtendedSqlParser(session: SparkSession, delegate: ParserInterface)
extends ParserInterface with Logging {
private lazy val conf = session.sqlContext.conf
private lazy val builder = new HoodieSpark3_2ExtendedSqlAstBuilder(conf, delegate)
private val substitutor = new VariableSubstitution
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
try {
builder.visit(parser.singleStatement()) match {
case plan: LogicalPlan => plan
case _=> delegate.parsePlan(sqlText)
override def parsePlan(sqlText: String): LogicalPlan = {
val substitutionSql = substitutor.substitute(sqlText)
if (isHoodieCommand(substitutionSql)) {
parse(substitutionSql) { parser =>
builder.visit(parser.singleStatement()) match {
case plan: LogicalPlan => plan
case _ => delegate.parsePlan(sqlText)
}
}
} catch {
case NonFatal(_) =>
delegate.parsePlan(sqlText)
} else {
delegate.parsePlan(substitutionSql)
}
}
@@ -111,6 +115,14 @@ class HoodieSpark3_2ExtendedSqlParser(session: SparkSession, delegate: ParserInt
override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
delegate.parseMultipartIdentifier(sqlText)
}
/**
 * Returns true when the given SQL text contains a Hudi time-travel clause
 * (`TIMESTAMP AS OF`, `VERSION AS OF`, or their `SYSTEM_`-prefixed forms)
 * and should therefore be handled by the Hoodie extended parser rather
 * than delegated to Spark's default parser.
 *
 * Matching is case-insensitive and tolerant of arbitrary whitespace
 * (including newlines) between the keywords.
 */
private def isHoodieCommand(sqlText: String): Boolean = {
  // Lower-case and collapse all whitespace runs to single spaces so the
  // keyword phrases match regardless of the query's formatting.
  val flattened = sqlText.toLowerCase(Locale.ROOT).trim().replaceAll("\\s+", " ")
  val timeTravelClauses = Seq(
    "system_time as of",
    "timestamp as of",
    "system_version as of",
    "version as of"
  )
  timeTravelClauses.exists(flattened.contains)
}
}
/**