[HUDI-1659] Basic Implement Of Spark Sql Support For Hoodie (#2645)
Main functions: Support CREATE TABLE for hoodie. Support CTAS. Support INSERT for hoodie, including dynamic-partition and static-partition insert. Support MERGE INTO for hoodie. Support DELETE. Support UPDATE. Both spark2 & spark3 are supported, based on DataSourceV1. Main changes: Add a sql parser for spark2. Add HoodieAnalysis for sql resolution and logical plan rewriting. Add command implementations for CREATE TABLE, INSERT, MERGE INTO & CTAS. In order to push down the update & insert logic to the HoodieRecordPayload for MERGE INTO, I made some changes to the HoodieWriteHandler and other related classes: 1. Add the inputSchema for parsing the incoming record. This is needed because the inputSchema for MERGE INTO differs from the writeSchema, as there are some transforms in the update & insert expressions. 2. Add WRITE_SCHEMA to HoodieWriteConfig to pass the write schema for MERGE INTO. 3. Pass properties to HoodieRecordPayload#getInsertValue to convey the insert expression and table schema. Verify this pull request: Add TestCreateTable to test creating hoodie tables and CTAS. Add TestInsertTable to test inserting into hoodie tables. Add TestMergeIntoTable to test merging into hoodie tables. Add TestUpdateTable to test updating hoodie tables. Add TestDeleteTable to test deleting from hoodie tables. Add TestSqlStatement to test the currently supported ddl/dml.
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.adapter
|
||||
|
||||
import org.apache.hudi.Spark2RowSerDe
|
||||
import org.apache.hudi.client.utils.SparkRowSerDe
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
|
||||
import org.apache.spark.sql.catalyst.parser.ParserInterface
|
||||
import org.apache.spark.sql.catalyst.plans.JoinType
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Join, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
|
||||
import org.apache.spark.sql.execution.datasources.{Spark2ParsePartitionUtil, SparkParsePartitionUtil}
|
||||
import org.apache.spark.sql.hudi.SparkAdapter
|
||||
import org.apache.spark.sql.hudi.parser.HoodieSqlParser
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
/**
|
||||
* A sql adapter for spark2.
|
||||
*/
|
||||
class Spark2Adapter extends SparkAdapter {
|
||||
|
||||
override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
|
||||
new Spark2RowSerDe(encoder)
|
||||
}
|
||||
|
||||
override def toTableIdentify(aliasId: AliasIdentifier): TableIdentifier = {
|
||||
TableIdentifier(aliasId.identifier, aliasId.database)
|
||||
}
|
||||
|
||||
override def toTableIdentify(relation: UnresolvedRelation): TableIdentifier = {
|
||||
relation.tableIdentifier
|
||||
}
|
||||
|
||||
override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = {
|
||||
Join(left, right, joinType, None)
|
||||
}
|
||||
|
||||
override def isInsertInto(plan: LogicalPlan): Boolean = {
|
||||
plan.isInstanceOf[InsertIntoTable]
|
||||
}
|
||||
|
||||
override def getInsertIntoChildren(plan: LogicalPlan):
|
||||
Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = {
|
||||
plan match {
|
||||
case InsertIntoTable(table, partition, query, overwrite, ifPartitionNotExists) =>
|
||||
Some((table, partition, query, overwrite, ifPartitionNotExists))
|
||||
case _=> None
|
||||
}
|
||||
}
|
||||
|
||||
override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]],
|
||||
query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = {
|
||||
InsertIntoTable(table, partition, query, overwrite, ifPartitionNotExists)
|
||||
}
|
||||
|
||||
override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = {
|
||||
Some(
|
||||
(spark: SparkSession, delegate: ParserInterface) => new HoodieSqlParser(spark, delegate)
|
||||
)
|
||||
}
|
||||
|
||||
override def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil = new Spark2ParsePartitionUtil
|
||||
|
||||
override def createLike(left: Expression, right: Expression): Expression = {
|
||||
Like(left, right)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
// This code is just copy from v2Commands.scala in spark 3.0

/**
 * Logical plan of a DELETE FROM command.
 *
 * @param table     the target relation to delete from
 * @param condition optional WHERE predicate; None means delete all rows
 */
case class DeleteFromTable(
    table: LogicalPlan,
    condition: Option[Expression]) extends Command {
  override def children: Seq[LogicalPlan] = table :: Nil
}
|
||||
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
// This code is just copy from v2Commands.scala in spark 3.0

/**
 * The logical plan of the MERGE INTO command that works for v2 tables.
 *
 * @param targetTable       the table being merged into
 * @param sourceTable       the table or subquery providing new rows
 * @param mergeCondition    the ON join predicate
 * @param matchedActions    WHEN MATCHED actions (update/delete)
 * @param notMatchedActions WHEN NOT MATCHED actions (insert)
 */
case class MergeIntoTable(
    targetTable: LogicalPlan,
    sourceTable: LogicalPlan,
    mergeCondition: Expression,
    matchedActions: Seq[MergeAction],
    notMatchedActions: Seq[MergeAction]) extends Command {
  // Target first, then source — mirroring the SQL clause order.
  override def children: Seq[LogicalPlan] = targetTable :: sourceTable :: Nil
}
|
||||
|
||||
|
||||
/**
 * Base type of the WHEN [NOT] MATCHED actions of a MERGE INTO statement.
 * Actions are modelled as unevaluable expressions so they can be carried
 * through analysis inside the plan tree; they are never executed directly.
 */
sealed abstract class MergeAction extends Expression with Unevaluable {
  /** Optional extra AND predicate guarding this action. */
  def condition: Option[Expression]
  override def foldable: Boolean = false
  override def nullable: Boolean = false
  // A merge action produces no value of its own; requesting its dataType is an
  // analysis-time error. (Fixed the exception message: it previously said
  // "nullable", misidentifying which member is unresolved.)
  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
  override def children: Seq[Expression] = condition.toSeq
}
|
||||
|
||||
/** WHEN MATCHED ... THEN DELETE action; `condition` is the optional AND guard. */
case class DeleteAction(condition: Option[Expression]) extends MergeAction
|
||||
|
||||
/**
 * WHEN MATCHED ... THEN UPDATE action.
 *
 * @param condition   optional AND guard for this clause
 * @param assignments SET column = value pairs; empty means UPDATE SET * (all columns)
 */
case class UpdateAction(
    condition: Option[Expression],
    assignments: Seq[Assignment]) extends MergeAction {
  // Expose the guard (when present) followed by every assignment so that
  // analysis visits all embedded expressions.
  override def children: Seq[Expression] = condition.toList ++ assignments
}
|
||||
|
||||
/**
 * WHEN NOT MATCHED ... THEN INSERT action.
 *
 * @param condition   optional AND guard for this clause
 * @param assignments target column = source value pairs; empty means INSERT * (all columns)
 */
case class InsertAction(
    condition: Option[Expression],
    assignments: Seq[Assignment]) extends MergeAction {
  // Expose the guard (when present) followed by every assignment so that
  // analysis visits all embedded expressions.
  override def children: Seq[Expression] = condition.toList ++ assignments
}
|
||||
|
||||
/**
 * A single `key = value` assignment of an UPDATE/INSERT merge action.
 * Unevaluable: it only carries the pair through analysis and is never executed.
 */
case class Assignment(key: Expression, value: Expression) extends Expression with Unevaluable {
  override def foldable: Boolean = false
  override def nullable: Boolean = false
  // An assignment has no value type of its own; requesting its dataType is an
  // analysis-time error. (Fixed the exception message: it previously said
  // "nullable", misidentifying which member is unresolved.)
  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
  override def children: Seq[Expression] = key :: value :: Nil
}
|
||||
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
// This code is just copy from v2Commands.scala in spark 3.0

/**
 * Logical plan of an UPDATE command.
 *
 * @param table       the target relation to update
 * @param assignments SET column = value pairs
 * @param condition   optional WHERE predicate; None means update all rows
 */
case class UpdateTable(
    table: LogicalPlan,
    assignments: Seq[Assignment],
    condition: Option[Expression]) extends Command {
  override def children: Seq[LogicalPlan] = table :: Nil
}
|
||||
@@ -0,0 +1,230 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.spark.sql.hudi.parser
|
||||
|
||||
import org.antlr.v4.runtime.tree.ParseTree
|
||||
import org.apache.hudi.spark.sql.parser.HoodieSqlBaseBaseVisitor
|
||||
import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface, ParserUtils}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
/**
 * The AstBuilder for HoodieSqlParser to parse the ANTLR AST tree into a LogicalPlan.
 * Here we only handle the extended sql syntax (e.g. MERGE INTO, UPDATE, DELETE). For
 * other sql syntax we use the delegate sql parser, which is the SparkSqlParser.
 */
class HoodieSqlAstBuilder(conf: SQLConf, delegate: ParserInterface) extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging {

  import ParserUtils._

  /** Entry point: dispatch the single statement node to the matching visit method. */
  override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
    ctx.statement().accept(this).asInstanceOf[LogicalPlan]
  }

  /** Unwrap the statement-level MERGE INTO node and delegate to [[visitMergeInto]]. */
  override def visitMergeIntoTable (ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
    visitMergeInto(ctx.mergeInto())
  }

  /**
   * Build a [[MergeIntoTable]] plan from a MERGE INTO parse tree:
   * resolves target/source relations (the source may be a subquery parsed by the
   * delegate parser), applies table aliases, and collects the WHEN MATCHED
   * (update/delete) and WHEN NOT MATCHED (insert) actions.
   */
  override def visitMergeInto(ctx: MergeIntoContext): LogicalPlan = withOrigin(ctx) {
    val target = UnresolvedRelation(visitTableIdentifier(ctx.target))
    val source = if (ctx.source != null) {
      // Source is a plain table reference.
      UnresolvedRelation(visitTableIdentifier(ctx.source))
    } else {
      // Source is a subquery: re-serialize its tree and let the delegate
      // (SparkSqlParser) parse it, since we do not implement full query syntax.
      val queryText = treeToString(ctx.subquery)
      delegate.parsePlan(queryText)
    }
    // tableAlias(0) belongs to the target, tableAlias(1) to the source.
    val aliasedTarget =
      if (ctx.tableAlias(0) != null) mayApplyAliasPlan(ctx.tableAlias(0), target) else target
    val aliasedSource =
      if (ctx.tableAlias(1) != null) mayApplyAliasPlan(ctx.tableAlias(1), source) else source

    val mergeCondition = expression(ctx.mergeCondition().condition)

    // At most one update and one delete clause are supported.
    if (ctx.matchedClauses().size() > 2) {
      throw new ParseException("There should be at most 2 'WHEN MATCHED' clauses.",
        ctx.matchedClauses.get(2))
    }

    // Each matched clause may carry a delete action, an update action, or both.
    val matchedClauses: Seq[MergeAction] = ctx.matchedClauses().asScala.flatMap {
      c =>
        val deleteCtx = c.deleteClause()
        val deleteClause = if (deleteCtx != null) {
          // Optional AND guard on the delete.
          val deleteCond = if (deleteCtx.deleteCond != null) {
            Some(expression(deleteCtx.deleteCond))
          } else {
            None
          }
          Some(DeleteAction(deleteCond))
        } else {
          None
        }
        val updateCtx = c.updateClause()
        val updateClause = if (updateCtx != null) {
          val updateAction = updateCtx.updateAction()
          // Optional AND guard on the update.
          val updateCond = if (updateCtx.updateCond != null) {
            Some(expression(updateCtx.updateCond))
          } else {
            None
          }
          if (updateAction.ASTERISK() != null) {
            // UPDATE SET *: empty assignments stand for "all columns".
            Some(UpdateAction(updateCond, Seq.empty))
          } else {
            val assignments = withAssignments(updateAction.assignmentList())
            Some(UpdateAction(updateCond, assignments))
          }
        } else {
          None
        }
        (deleteClause ++ updateClause).toSeq
    }
    // WHEN NOT MATCHED clauses always translate to insert actions.
    val notMatchedClauses: Seq[InsertAction] = ctx.notMatchedClause().asScala.map {
      notMatchedClause =>
        val insertCtx = notMatchedClause.insertClause()
        val insertAction = insertCtx.insertAction()
        // Optional AND guard on the insert.
        val insertCond = if (insertCtx.insertCond != null) {
          Some(expression(insertCtx.insertCond))
        } else {
          None
        }
        if (insertAction.ASTERISK() != null) {
          // INSERT *: empty assignments stand for "all columns".
          InsertAction(insertCond, Seq.empty)
        } else {
          val attrList = insertAction.columns.qualifiedName().asScala
            .map(attr => UnresolvedAttribute(visitQualifiedName(attr)))
          // Reject duplicate column names in the insert column list.
          val attrSet = scala.collection.mutable.Set[UnresolvedAttribute]()
          attrList.foreach(attr => {
            if (attrSet.contains(attr)) {
              throw new ParseException(s"find duplicate field :'${attr.name}'",
                insertAction.columns)
            }
            attrSet += attr
          })
          val valueList = insertAction.expression().asScala.map(expression)
          // The column list and the VALUES list must be the same length.
          if (attrList.size != valueList.size) {
            throw new ParseException("The columns of source and target tables are not equal: " +
              s"target: $attrList, source: $valueList", insertAction)
          }
          // Pair each target column with its source value expression.
          val assignments = attrList.zip(valueList).map(kv => Assignment(kv._1, kv._2))
          InsertAction(insertCond, assignments)
        }
    }
    MergeIntoTable(aliasedTarget, aliasedSource, mergeCondition,
      matchedClauses, notMatchedClauses)
  }

  /** Convert a SET assignment list (col = expr, ...) into [[Assignment]] nodes. */
  private def withAssignments(assignCtx: AssignmentListContext): Seq[Assignment] =
    withOrigin(assignCtx) {
      assignCtx.assignment().asScala.map { assign =>
        Assignment(UnresolvedAttribute(visitQualifiedName(assign.key)),
          expression(assign.value))
      }
    }

  /** Build an [[UpdateTable]] plan from an UPDATE statement parse tree. */
  override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
    val updateStmt = ctx.updateTableStmt()
    val table = UnresolvedRelation(visitTableIdentifier(updateStmt.tableIdentifier()))
    // Missing WHERE means the update applies to every row.
    val condition = if (updateStmt.where != null) Some(expression(updateStmt.where)) else None
    val assignments = withAssignments(updateStmt.assignmentList())
    UpdateTable(table, assignments, condition)
  }

  /** Build a [[DeleteFromTable]] plan from a DELETE statement parse tree. */
  override def visitDeleteTable (ctx: DeleteTableContext): LogicalPlan = withOrigin(ctx) {
    val deleteStmt = ctx.deleteTableStmt()
    val table = UnresolvedRelation(visitTableIdentifier(deleteStmt.tableIdentifier()))
    // Missing WHERE means the delete applies to every row.
    val condition = if (deleteStmt.where != null) Some(expression(deleteStmt.where)) else None
    DeleteFromTable(table, condition)
  }

  /**
   * Convert the tree to string by joining all leaf tokens with single spaces.
   * Used to hand sub-trees (subqueries, expressions) back to the delegate parser.
   * NOTE(review): this loses original whitespace/quoting layout — assumed safe
   * because the delegate re-parses token text only.
   */
  private def treeToString(tree: ParseTree): String = {
    if (tree.getChildCount == 0) {
      tree.getText
    } else {
      (for (i <- 0 until tree.getChildCount) yield {
        treeToString(tree.getChild(i))
      }).mkString(" ")
    }
  }

  /**
   * Parse the expression tree to a spark sql Expression.
   * Here we re-serialize the tree and use the delegate (SparkSqlParser) to parse it.
   */
  private def expression(tree: ParseTree): Expression = {
    val expressionText = treeToString(tree)
    delegate.parseExpression(expressionText)
  }

  // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder
  /**
   * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and
   * column aliases for a [[LogicalPlan]].
   */
  protected def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = {
    if (tableAlias.strictIdentifier != null) {
      val subquery = SubqueryAlias(tableAlias.strictIdentifier.getText, plan)
      if (tableAlias.identifierList != null) {
        // Alias also renames columns: t AS alias(c1, c2, ...)
        val columnNames = visitIdentifierList(tableAlias.identifierList)
        UnresolvedSubqueryColumnAliases(columnNames, subquery)
      } else {
        subquery
      }
    } else {
      plan
    }
  }
  /**
   * Parse a qualified name to a multipart name.
   */
  override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) {
    ctx.identifier.asScala.map(_.getText)
  }

  /**
   * Create a Sequence of Strings for a parenthesis enclosed alias list.
   */
  override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
    visitIdentifierSeq(ctx.identifierSeq)
  }

  /**
   * Create a Sequence of Strings for an identifier list.
   */
  override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
    ctx.identifier.asScala.map(_.getText)
  }

  /* ********************************************************************************************
   * Table Identifier parsing
   * ******************************************************************************************** */
  /**
   * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
   */
  override def visitTableIdentifier(ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
    TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
  }
}
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.parser
|
||||
|
||||
import org.antlr.v4.runtime._
|
||||
import org.antlr.v4.runtime.atn.PredictionMode
|
||||
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
|
||||
import org.antlr.v4.runtime.tree.TerminalNodeImpl
|
||||
import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext}
|
||||
import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.trees.Origin
|
||||
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{AnalysisException, SparkSession}
|
||||
|
||||
/**
 * ParserInterface implementation that first tries the Hoodie extended grammar
 * (MERGE INTO / UPDATE / DELETE) and falls back to the delegate (SparkSqlParser)
 * for everything else. Intended to be installed via the session's parser-extension hook.
 */
class HoodieSqlParser(session: SparkSession, delegate: ParserInterface)
  extends ParserInterface with Logging {

  // Lazy: SQLConf and the AST builder are only materialized on first parse.
  private lazy val conf = session.sqlContext.conf
  private lazy val builder = new HoodieSqlAstBuilder(conf, delegate)

  /**
   * Parse with the Hoodie grammar; if the visitor does not produce a LogicalPlan
   * (i.e. the statement is not one of the extended forms), delegate the full text
   * to the underlying Spark parser.
   */
  override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
    builder.visit(parser.singleStatement()) match {
      case plan: LogicalPlan => plan
      case _=> delegate.parsePlan(sqlText)
    }
  }

  // The following entry points carry no Hoodie extensions; delegate unchanged.
  override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)

  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)

  override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)

  override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)

  /**
   * Run the ANTLR lexer/parser over `command` and apply `toResult`.
   * Forked from Spark's AbstractSqlParser#parse: tries the faster SLL prediction
   * mode first and retries in full LL mode when SLL parsing is ambiguous/fails.
   */
  protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = {
    logDebug(s"Parsing command: $command")

    // Upper-case the input stream so keyword matching is case-insensitive.
    val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
    lexer.removeErrorListeners()
    lexer.addErrorListener(ParseErrorListener)
    lexer.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced

    val tokenStream = new CommonTokenStream(lexer)
    val parser = new HoodieSqlBaseParser(tokenStream)
    // PostProcessor rewrites quoted identifiers / non-reserved keywords during the walk.
    parser.addParseListener(PostProcessor)
    parser.removeErrorListeners()
    parser.addErrorListener(ParseErrorListener)
    parser.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced

    try {
      try {
        // first, try parsing with potentially faster SLL mode
        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
        toResult(parser)
      }
      catch {
        case e: ParseCancellationException =>
          // if we fail, parse with LL mode
          tokenStream.seek(0) // rewind input stream
          parser.reset()

          // Try Again.
          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
          toResult(parser)
      }
    }
    catch {
      // ParseException that already knows its command text: rethrow untouched.
      case e: ParseException if e.command.isDefined =>
        throw e
      // Attach the offending command text for a better error message.
      case e: ParseException =>
        throw e.withCommand(command)
      // Surface analysis errors as parse errors positioned at the failure origin.
      case e: AnalysisException =>
        val position = Origin(e.line, e.startPosition)
        throw new ParseException(Option(command), e.message, position, position)
    }
  }
}
|
||||
|
||||
/**
 * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
 * Wraps a CharStream and upper-cases characters as the lexer looks ahead,
 * making keyword recognition case-insensitive while `getText` still returns
 * the original (case-preserved) input for identifiers and literals.
 */
class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
  // Pure delegation for everything except LA() and getText().
  override def consume(): Unit = wrapped.consume
  override def getSourceName(): String = wrapped.getSourceName
  override def index(): Int = wrapped.index
  override def mark(): Int = wrapped.mark
  override def release(marker: Int): Unit = wrapped.release(marker)
  override def seek(where: Int): Unit = wrapped.seek(where)
  override def size(): Int = wrapped.size

  override def getText(interval: Interval): String = {
    // ANTLR 4.7's CodePointCharStream implementations have bugs when
    // getText() is called with an empty stream, or intervals where
    // the start > end. See
    // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
    // that is not yet in a released ANTLR artifact.
    if (size() > 0 && (interval.b - interval.a >= 0)) {
      wrapped.getText(interval)
    } else {
      ""
    }
  }
  // scalastyle:off
  override def LA(i: Int): Int = {
    // scalastyle:on
    // Upper-case the lookahead character, but pass through the sentinel values
    // 0 and EOF unchanged.
    val la = wrapped.LA(i)
    if (la == 0 || la == IntStream.EOF) la
    else Character.toUpperCase(la)
  }
}
|
||||
|
||||
/**
 * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
 * A parse listener that rewrites certain tokens in-place while the parse tree
 * is being built: it strips back-ticks from quoted identifiers and re-types
 * non-reserved keywords as plain identifiers.
 */
case object PostProcessor extends HoodieSqlBaseBaseListener {

  /** Remove the back ticks from an Identifier. */
  override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
    // stripMargins = 1 drops the single back tick on each side.
    replaceTokenByIdentifier(ctx, 1) { token =>
      // Remove the double back ticks in the string.
      token.setText(token.getText.replace("``", "`"))
      token
    }
  }

  /** Treat non-reserved keywords as Identifiers. */
  override def exitNonReserved(ctx: NonReservedContext): Unit = {
    replaceTokenByIdentifier(ctx, 0)(identity)
  }

  /**
   * Replace the last child of ctx's parent with an IDENTIFIER terminal derived
   * from ctx's token, trimming `stripMargins` characters from both ends of the
   * token's index range and applying `f` to the new token.
   */
  private def replaceTokenByIdentifier(
      ctx: ParserRuleContext,
      stripMargins: Int)(
      f: CommonToken => CommonToken = identity): Unit = {
    val parent = ctx.getParent
    // Drop the original terminal node; it is re-added below with a new token type.
    parent.removeLastChild()
    val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
    val newToken = new CommonToken(
      new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
      HoodieSqlBaseParser.IDENTIFIER,
      token.getChannel,
      token.getStartIndex + stripMargins,
      token.getStopIndex - stripMargins)
    parent.addChild(new TerminalNodeImpl(f(newToken)))
  }
}
|
||||
Reference in New Issue
Block a user