[HUDI-3161] Add Call Produce Command for Spark SQL (#4535)
This commit is contained in:
@@ -14,59 +14,197 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
grammar HoodieSqlCommon;
|
||||
|
||||
grammar HoodieSqlCommon;
|
||||
|
||||
@lexer::members {
|
||||
/**
|
||||
* Verify whether current token is a valid decimal token (which contains dot).
|
||||
* Returns true if the character that follows the token is not a digit or letter or underscore.
|
||||
*
|
||||
* For example:
|
||||
* For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'.
|
||||
* For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'.
|
||||
* For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'.
|
||||
* For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed
|
||||
* by a space. 34.E2 is a valid decimal token because it is followed by symbol '+'
|
||||
* which is not a digit or letter or underscore.
|
||||
*/
|
||||
public boolean isValidDecimal() {
|
||||
int nextChar = _input.LA(1);
|
||||
if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' ||
|
||||
nextChar == '_') {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
singleStatement
|
||||
: statement EOF
|
||||
;
|
||||
|
||||
statement
|
||||
: compactionStatement #compactionCommand
|
||||
| .*? #passThrough
|
||||
: compactionStatement #compactionCommand
|
||||
| CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call
|
||||
| .*? #passThrough
|
||||
;
|
||||
|
||||
compactionStatement
|
||||
: operation = (RUN | SCHEDULE) COMPACTION ON tableIdentifier (AT instantTimestamp = NUMBER)? #compactionOnTable
|
||||
| operation = (RUN | SCHEDULE) COMPACTION ON path = STRING (AT instantTimestamp = NUMBER)? #compactionOnPath
|
||||
| SHOW COMPACTION ON tableIdentifier (LIMIT limit = NUMBER)? #showCompactionOnTable
|
||||
| SHOW COMPACTION ON path = STRING (LIMIT limit = NUMBER)? #showCompactionOnPath
|
||||
: operation = (RUN | SCHEDULE) COMPACTION ON tableIdentifier (AT instantTimestamp = INTEGER_VALUE)? #compactionOnTable
|
||||
| operation = (RUN | SCHEDULE) COMPACTION ON path = STRING (AT instantTimestamp = INTEGER_VALUE)? #compactionOnPath
|
||||
| SHOW COMPACTION ON tableIdentifier (LIMIT limit = INTEGER_VALUE)? #showCompactionOnTable
|
||||
| SHOW COMPACTION ON path = STRING (LIMIT limit = INTEGER_VALUE)? #showCompactionOnPath
|
||||
;
|
||||
|
||||
tableIdentifier
|
||||
: (db=IDENTIFIER '.')? table=IDENTIFIER
|
||||
;
|
||||
|
||||
callArgument
|
||||
: expression #positionalArgument
|
||||
| identifier '=>' expression #namedArgument
|
||||
;
|
||||
|
||||
expression
|
||||
: constant
|
||||
| stringMap
|
||||
;
|
||||
|
||||
constant
|
||||
: number #numericLiteral
|
||||
| booleanValue #booleanLiteral
|
||||
| STRING+ #stringLiteral
|
||||
| identifier STRING #typeConstructor
|
||||
;
|
||||
|
||||
stringMap
|
||||
: MAP '(' constant (',' constant)* ')'
|
||||
;
|
||||
|
||||
booleanValue
|
||||
: TRUE | FALSE
|
||||
;
|
||||
|
||||
number
|
||||
: MINUS? EXPONENT_VALUE #exponentLiteral
|
||||
| MINUS? DECIMAL_VALUE #decimalLiteral
|
||||
| MINUS? INTEGER_VALUE #integerLiteral
|
||||
| MINUS? BIGINT_LITERAL #bigIntLiteral
|
||||
| MINUS? SMALLINT_LITERAL #smallIntLiteral
|
||||
| MINUS? TINYINT_LITERAL #tinyIntLiteral
|
||||
| MINUS? DOUBLE_LITERAL #doubleLiteral
|
||||
| MINUS? FLOAT_LITERAL #floatLiteral
|
||||
| MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
|
||||
;
|
||||
|
||||
multipartIdentifier
|
||||
: parts+=identifier ('.' parts+=identifier)*
|
||||
;
|
||||
|
||||
identifier
|
||||
: IDENTIFIER #unquotedIdentifier
|
||||
| quotedIdentifier #quotedIdentifierAlternative
|
||||
| nonReserved #unquotedIdentifier
|
||||
;
|
||||
|
||||
quotedIdentifier
|
||||
: BACKQUOTED_IDENTIFIER
|
||||
;
|
||||
|
||||
nonReserved
|
||||
: CALL | COMPACTION | RUN | SCHEDULE | ON | SHOW | LIMIT
|
||||
;
|
||||
|
||||
ALL: 'ALL';
|
||||
AT: 'AT';
|
||||
CALL: 'CALL';
|
||||
COMPACTION: 'COMPACTION';
|
||||
RUN: 'RUN';
|
||||
SCHEDULE: 'SCHEDULE';
|
||||
ON: 'ON';
|
||||
SHOW: 'SHOW';
|
||||
LIMIT: 'LIMIT';
|
||||
MAP: 'MAP';
|
||||
NULL: 'NULL';
|
||||
TRUE: 'TRUE';
|
||||
FALSE: 'FALSE';
|
||||
INTERVAL: 'INTERVAL';
|
||||
TO: 'TO';
|
||||
|
||||
NUMBER
|
||||
: DIGIT+
|
||||
;
|
||||
|
||||
IDENTIFIER
|
||||
: (LETTER | DIGIT | '_')+
|
||||
;
|
||||
PLUS: '+';
|
||||
MINUS: '-';
|
||||
|
||||
STRING
|
||||
: '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
|
||||
| '"' ( ~('"'|'\\') | ('\\' .) )* '"'
|
||||
;
|
||||
|
||||
BIGINT_LITERAL
|
||||
: DIGIT+ 'L'
|
||||
;
|
||||
|
||||
SMALLINT_LITERAL
|
||||
: DIGIT+ 'S'
|
||||
;
|
||||
|
||||
TINYINT_LITERAL
|
||||
: DIGIT+ 'Y'
|
||||
;
|
||||
|
||||
INTEGER_VALUE
|
||||
: DIGIT+
|
||||
;
|
||||
|
||||
EXPONENT_VALUE
|
||||
: DIGIT+ EXPONENT
|
||||
| DECIMAL_DIGITS EXPONENT {isValidDecimal()}?
|
||||
;
|
||||
|
||||
DECIMAL_VALUE
|
||||
: DECIMAL_DIGITS {isValidDecimal()}?
|
||||
;
|
||||
|
||||
FLOAT_LITERAL
|
||||
: DIGIT+ EXPONENT? 'F'
|
||||
| DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}?
|
||||
;
|
||||
|
||||
DOUBLE_LITERAL
|
||||
: DIGIT+ EXPONENT? 'D'
|
||||
| DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}?
|
||||
;
|
||||
|
||||
BIGDECIMAL_LITERAL
|
||||
: DIGIT+ EXPONENT? 'BD'
|
||||
| DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}?
|
||||
;
|
||||
|
||||
IDENTIFIER
|
||||
: (LETTER | DIGIT | '_')+
|
||||
;
|
||||
|
||||
BACKQUOTED_IDENTIFIER
|
||||
: '`' ( ~'`' | '``' )* '`'
|
||||
;
|
||||
|
||||
fragment DECIMAL_DIGITS
|
||||
: DIGIT+ '.' DIGIT*
|
||||
| '.' DIGIT+
|
||||
;
|
||||
|
||||
fragment EXPONENT
|
||||
: 'E' [+-]? DIGIT+
|
||||
;
|
||||
|
||||
fragment DIGIT
|
||||
: [0-9]
|
||||
;
|
||||
: [0-9]
|
||||
;
|
||||
|
||||
fragment LETTER
|
||||
: [A-Z]
|
||||
;
|
||||
: [A-Z]
|
||||
;
|
||||
|
||||
SIMPLE_COMMENT
|
||||
: '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN)
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
|
||||
case class CallCommand(name: Seq[String], args: Seq[CallArgument]) extends Command {
|
||||
override def children: Seq[LogicalPlan] = Seq.empty
|
||||
|
||||
def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): CallCommand = {
|
||||
this
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An argument in a CALL statement.
|
||||
*/
|
||||
sealed trait CallArgument {
|
||||
def expr: Expression
|
||||
}
|
||||
|
||||
/**
|
||||
* An argument in a CALL statement identified by name.
|
||||
*/
|
||||
case class NamedArgument(name: String, expr: Expression) extends CallArgument
|
||||
|
||||
/**
|
||||
* An argument in a CALL statement identified by position.
|
||||
*/
|
||||
case class PositionalArgument(expr: Expression) extends CallArgument
|
||||
@@ -22,7 +22,7 @@ import org.apache.hudi.common.model.HoodieRecord
|
||||
import org.apache.hudi.common.util.ReflectionUtils
|
||||
import org.apache.hudi.{HoodieSparkUtils, SparkAdapterSupport}
|
||||
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, Literal, NamedExpression}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, GenericInternalRow, Literal, NamedExpression}
|
||||
import org.apache.spark.sql.catalyst.plans.Inner
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
@@ -31,10 +31,12 @@ import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
|
||||
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{getTableIdentifier, removeMetaFields}
|
||||
import org.apache.spark.sql.hudi.HoodieSqlUtils._
|
||||
import org.apache.spark.sql.hudi.command._
|
||||
import org.apache.spark.sql.hudi.command.procedures.{HoodieProcedures, Procedure, ProcedureArgs}
|
||||
import org.apache.spark.sql.hudi.{HoodieOptionConfig, HoodieSqlCommonUtils}
|
||||
import org.apache.spark.sql.types.StringType
|
||||
import org.apache.spark.sql.{AnalysisException, SparkSession}
|
||||
|
||||
import java.util
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
object HoodieAnalysis {
|
||||
@@ -79,6 +81,7 @@ object HoodieAnalysis {
|
||||
|
||||
/**
|
||||
* Rule for convert the logical plan to command.
|
||||
*
|
||||
* @param sparkSession
|
||||
*/
|
||||
case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan]
|
||||
@@ -133,26 +136,69 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan]
|
||||
// Convert to CompactionShowHoodiePathCommand
|
||||
case CompactionShowOnPath(path, limit) =>
|
||||
CompactionShowHoodiePathCommand(path, limit)
|
||||
case _=> plan
|
||||
// Convert to HoodieCallProcedureCommand
|
||||
case c@CallCommand(_, _) =>
|
||||
val procedure: Option[Procedure] = loadProcedure(c.name)
|
||||
val input = buildProcedureArgs(c.args)
|
||||
if (procedure.nonEmpty) {
|
||||
CallProcedureHoodieCommand(procedure.get, input)
|
||||
} else {
|
||||
c
|
||||
}
|
||||
case _ => plan
|
||||
}
|
||||
}
|
||||
|
||||
private def loadProcedure(name: Seq[String]): Option[Procedure] = {
|
||||
val procedure: Option[Procedure] = if (name.nonEmpty) {
|
||||
val builder = HoodieProcedures.newBuilder(name.last)
|
||||
if (builder != null) {
|
||||
Option(builder.build)
|
||||
} else {
|
||||
throw new AnalysisException(s"procedure: ${name.last} is not exists")
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
procedure
|
||||
}
|
||||
|
||||
private def buildProcedureArgs(exprs: Seq[CallArgument]): ProcedureArgs = {
|
||||
val values = new Array[Any](exprs.size)
|
||||
var isNamedArgs: Boolean = false
|
||||
val map = new util.LinkedHashMap[String, Int]()
|
||||
for (index <- exprs.indices) {
|
||||
exprs(index) match {
|
||||
case expr: NamedArgument =>
|
||||
map.put(expr.name, index)
|
||||
values(index) = expr.expr.eval()
|
||||
isNamedArgs = true
|
||||
case _ =>
|
||||
map.put(index.toString, index)
|
||||
values(index) = exprs(index).expr.eval()
|
||||
isNamedArgs = false
|
||||
}
|
||||
}
|
||||
ProcedureArgs(isNamedArgs, map, new GenericInternalRow(values))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Rule for resolve hoodie's extended syntax or rewrite some logical plan.
|
||||
*
|
||||
* @param sparkSession
|
||||
*/
|
||||
case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan]
|
||||
with SparkAdapterSupport {
|
||||
private lazy val analyzer = sparkSession.sessionState.analyzer
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
|
||||
// Resolve merge into
|
||||
case mergeInto @ MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions)
|
||||
if sparkAdapter.isHoodieTable(target, sparkSession) && target.resolved =>
|
||||
|
||||
val resolver = sparkSession.sessionState.conf.resolver
|
||||
val resolvedSource = analyzer.execute(source)
|
||||
|
||||
def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = {
|
||||
if (assignments.isEmpty) {
|
||||
true
|
||||
@@ -363,6 +409,7 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
|
||||
/**
|
||||
* Check if the the query of insert statement has already append the meta fields to avoid
|
||||
* duplicate append.
|
||||
*
|
||||
* @param query
|
||||
* @return
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||
import org.apache.spark.sql.hudi.command.procedures.{Procedure, ProcedureArgs}
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
|
||||
import scala.collection.Seq
|
||||
|
||||
case class CallProcedureHoodieCommand(
|
||||
procedure: Procedure,
|
||||
args: ProcedureArgs) extends HoodieLeafRunnableCommand {
|
||||
|
||||
override def output: Seq[Attribute] = procedure.outputType.toAttributes
|
||||
|
||||
override def run(sparkSession: SparkSession): Seq[Row] = {
|
||||
procedure.call(args)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
import org.apache.hudi.client.SparkRDDWriteClient
|
||||
import org.apache.hudi.client.common.HoodieSparkEngineContext
|
||||
import org.apache.hudi.common.model.HoodieRecordPayload
|
||||
import org.apache.hudi.config.{HoodieIndexConfig, HoodieWriteConfig}
|
||||
import org.apache.hudi.index.HoodieIndex.IndexType
|
||||
import org.apache.spark.api.java.JavaSparkContext
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
abstract class BaseProcedure extends Procedure {
|
||||
val INVALID_ARG_INDEX: Int = -1
|
||||
|
||||
val spark: SparkSession = SparkSession.active
|
||||
val jsc = new JavaSparkContext(spark.sparkContext)
|
||||
|
||||
protected def sparkSession: SparkSession = spark
|
||||
|
||||
protected def createHoodieClient(jsc: JavaSparkContext, basePath: String): SparkRDDWriteClient[_ <: HoodieRecordPayload[_ <: AnyRef]] = {
|
||||
val config = getWriteConfig(basePath)
|
||||
new SparkRDDWriteClient(new HoodieSparkEngineContext(jsc), config)
|
||||
}
|
||||
|
||||
protected def getWriteConfig(basePath: String): HoodieWriteConfig = {
|
||||
HoodieWriteConfig.newBuilder
|
||||
.withPath(basePath)
|
||||
.withIndexConfig(HoodieIndexConfig.newBuilder.withIndexType(IndexType.BLOOM).build)
|
||||
.withRollbackUsingMarkers(false)
|
||||
.build
|
||||
}
|
||||
|
||||
protected def checkArgs(target: Array[ProcedureParameter], args: ProcedureArgs): Unit = {
|
||||
val internalRow = args.internalRow
|
||||
for (i <- target.indices) {
|
||||
if (target(i).required) {
|
||||
var argsIndex: Integer = null
|
||||
if (args.isNamedArgs) {
|
||||
argsIndex = getArgsIndex(target(i).name, args)
|
||||
} else {
|
||||
argsIndex = getArgsIndex(i.toString, args)
|
||||
}
|
||||
assert(-1 != argsIndex && internalRow.get(argsIndex, target(i).dataType) != null,
|
||||
s"Argument: ${target(i).name} is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected def getArgsIndex(key: String, args: ProcedureArgs): Integer = {
|
||||
args.map.getOrDefault(key, INVALID_ARG_INDEX)
|
||||
}
|
||||
|
||||
protected def getArgValueOrDefault(args: ProcedureArgs, parameter: ProcedureParameter): Any = {
|
||||
var argsIndex: Int = INVALID_ARG_INDEX
|
||||
if (args.isNamedArgs) {
|
||||
argsIndex = getArgsIndex(parameter.name, args)
|
||||
} else {
|
||||
argsIndex = getArgsIndex(parameter.index.toString, args)
|
||||
}
|
||||
if (argsIndex.equals(INVALID_ARG_INDEX)) parameter.default else getInternalRowValue(args.internalRow, argsIndex, parameter.dataType)
|
||||
}
|
||||
|
||||
protected def getInternalRowValue(row: InternalRow, index: Int, dataType: DataType): Any = {
|
||||
dataType match {
|
||||
case StringType => row.getString(index)
|
||||
case BinaryType => row.getBinary(index)
|
||||
case BooleanType => row.getBoolean(index)
|
||||
case CalendarIntervalType => row.getInterval(index)
|
||||
case DoubleType => row.getDouble(index)
|
||||
case d: DecimalType => row.getDecimal(index, d.precision, d.scale)
|
||||
case FloatType => row.getFloat(index)
|
||||
case ByteType => row.getByte(index)
|
||||
case IntegerType => row.getInt(index)
|
||||
case LongType => row.getLong(index)
|
||||
case ShortType => row.getShort(index)
|
||||
case NullType => null
|
||||
case _ =>
|
||||
throw new UnsupportedOperationException(s"type: ${dataType.typeName} not supported")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
import com.google.common.collect.ImmutableMap
|
||||
|
||||
import java.util
|
||||
import java.util.Locale
|
||||
import java.util.function.Supplier
|
||||
|
||||
object HoodieProcedures {
|
||||
private val BUILDERS: util.Map[String, Supplier[ProcedureBuilder]] = initProcedureBuilders
|
||||
|
||||
def newBuilder(name: String): ProcedureBuilder = {
|
||||
val builderSupplier: Supplier[ProcedureBuilder] = BUILDERS.get(name.toLowerCase(Locale.ROOT))
|
||||
if (builderSupplier != null) builderSupplier.get else null
|
||||
}
|
||||
|
||||
private def initProcedureBuilders: util.Map[String, Supplier[ProcedureBuilder]] = {
|
||||
val mapBuilder: ImmutableMap.Builder[String, Supplier[ProcedureBuilder]] = ImmutableMap.builder()
|
||||
mapBuilder.put(ShowCommitsProcedure.NAME, ShowCommitsProcedure.builder)
|
||||
mapBuilder.put(ShowCommitsMetadataProcedure.NAME, ShowCommitsMetadataProcedure.builder)
|
||||
mapBuilder.put(RollbackToInstantTimeProcedure.NAME, RollbackToInstantTimeProcedure.builder)
|
||||
mapBuilder.build
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
import java.util
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
|
||||
* An interface representing a stored procedure available for execution.
|
||||
*/
|
||||
trait Procedure {
|
||||
/**
|
||||
* Returns the input parameters of this procedure.
|
||||
*/
|
||||
def parameters: Array[ProcedureParameter]
|
||||
|
||||
/**
|
||||
* Returns the type of rows produced by this procedure.
|
||||
*/
|
||||
def outputType: StructType
|
||||
|
||||
/**
|
||||
* Executes this procedure.
|
||||
* <p>
|
||||
* Spark will align the provided arguments according to the input parameters
|
||||
* defined in {@link #parameters ( )} either by position or by name before execution.
|
||||
* <p>
|
||||
* Implementations may provide a summary of execution by returning one or many rows
|
||||
* as a result. The schema of output rows must match the defined output type
|
||||
* in {@link #outputType ( )}.
|
||||
*
|
||||
* @param args input arguments
|
||||
* @return the result of executing this procedure with the given arguments
|
||||
*/
|
||||
def call(args: ProcedureArgs): Seq[Row]
|
||||
|
||||
/**
|
||||
* Returns the description of this procedure.
|
||||
*/
|
||||
def description: String = this.getClass.toString
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
|
||||
import java.util
|
||||
|
||||
case class ProcedureArgs(isNamedArgs: Boolean,
|
||||
map: util.LinkedHashMap[String, Int],
|
||||
internalRow: InternalRow) {
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
trait ProcedureBuilder {
|
||||
def build: Procedure
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
/**
|
||||
* An input parameter of a {@link Procedure stored procedure}.
|
||||
*/
|
||||
abstract class ProcedureParameter {
|
||||
def index: Int
|
||||
|
||||
/**
|
||||
* Returns the name of this parameter.
|
||||
*/
|
||||
def name: String
|
||||
|
||||
/**
|
||||
* Returns the type of this parameter.
|
||||
*/
|
||||
def dataType: DataType
|
||||
|
||||
/**
|
||||
* Returns true if this parameter is required.
|
||||
*/
|
||||
def required: Boolean
|
||||
|
||||
/**
|
||||
* this parameter's default value.
|
||||
*/
|
||||
def default: Any
|
||||
}
|
||||
|
||||
object ProcedureParameter {
|
||||
/**
|
||||
* Creates a required input parameter.
|
||||
*
|
||||
* @param name the name of the parameter
|
||||
* @param dataType the type of the parameter
|
||||
* @return the constructed stored procedure parameter
|
||||
*/
|
||||
def required(index: Int, name: String, dataType: DataType, default: Any): ProcedureParameterImpl = {
|
||||
ProcedureParameterImpl(index, name, dataType, default, required = true)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an optional input parameter.
|
||||
*
|
||||
* @param name the name of the parameter.
|
||||
* @param dataType the type of the parameter.
|
||||
* @return the constructed optional stored procedure parameter
|
||||
*/
|
||||
def optional(index: Int, name: String, dataType: DataType, default: Any): ProcedureParameterImpl = {
|
||||
ProcedureParameterImpl(index, name, dataType, default, required = false)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
import java.util.Objects
|
||||
|
||||
case class ProcedureParameterImpl(index: Int, name: String, dataType: DataType, default: Any, required: Boolean)
|
||||
extends ProcedureParameter {
|
||||
|
||||
override def equals(other: Any): Boolean = {
|
||||
val that = other.asInstanceOf[ProcedureParameterImpl]
|
||||
val rtn = if (this == other) {
|
||||
true
|
||||
} else if (other == null || (getClass ne other.getClass)) {
|
||||
false
|
||||
} else {
|
||||
index == that.index && required == that.required && default == that.default && Objects.equals(name, that.name) && Objects.equals(dataType, that.dataType)
|
||||
}
|
||||
rtn
|
||||
}
|
||||
|
||||
override def hashCode: Int = Seq(index, name, dataType, required, default).hashCode()
|
||||
|
||||
override def toString: String = s"ProcedureParameter(index='$index',name='$name', type=$dataType, required=$required, default=$default)"
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
import org.apache.hudi.common.table.HoodieTableMetaClient
|
||||
import org.apache.hudi.common.table.timeline.HoodieTimeline
|
||||
import org.apache.hudi.common.table.timeline.versioning.TimelineLayoutVersion
|
||||
import org.apache.hudi.common.util.Option
|
||||
import org.apache.hudi.exception.HoodieException
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
|
||||
import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}
|
||||
|
||||
import java.util.function.Supplier
|
||||
|
||||
class RollbackToInstantTimeProcedure extends BaseProcedure with ProcedureBuilder {
  // Both arguments are mandatory: the table to roll back and the target instant.
  private val PARAMETERS = Array[ProcedureParameter](
    ProcedureParameter.required(0, "table", DataTypes.StringType, None),
    ProcedureParameter.required(1, "instant_time", DataTypes.StringType, None))

  // Single-column result row: whether the rollback call succeeded.
  private val OUTPUT_TYPE = new StructType(Array[StructField](
    StructField("rollback_result", DataTypes.BooleanType, nullable = true, Metadata.empty))
  )

  def parameters: Array[ProcedureParameter] = PARAMETERS

  def outputType: StructType = OUTPUT_TYPE

  /**
   * Rolls the Hudi table back to the given instant time.
   *
   * Validates that the instant exists as a completed commit on the active
   * timeline before invoking the write client's rollback.
   *
   * @param args procedure arguments (`table`, `instant_time`)
   * @return a single row containing the rollback result as a Boolean
   * @throws HoodieException if the instant is not a completed commit
   */
  override def call(args: ProcedureArgs): Seq[Row] = {
    super.checkArgs(PARAMETERS, args)

    val table = getArgValueOrDefault(args, PARAMETERS(0)).asInstanceOf[String]
    val instantTime = getArgValueOrDefault(args, PARAMETERS(1)).asInstanceOf[String]

    val hoodieCatalogTable = HoodieCatalogTable(sparkSession, new TableIdentifier(table))
    val basePath = hoodieCatalogTable.tableLocation
    val client = createHoodieClient(jsc, basePath)
    val config = getWriteConfig(basePath)
    val metaClient = HoodieTableMetaClient.builder
      .setConf(jsc.hadoopConfiguration)
      .setBasePath(config.getBasePath)
      .setLoadActiveTimelineOnLoad(false)
      .setConsistencyGuardConfig(config.getConsistencyGuardConfig)
      .setLayoutVersion(Option.of(new TimelineLayoutVersion(config.getTimelineLayoutVersion)))
      .build

    // Fail fast when the requested instant is not a completed commit on the timeline.
    val activeTimeline = metaClient.getActiveTimeline
    val completedTimeline: HoodieTimeline = activeTimeline.getCommitsTimeline.filterCompletedInstants
    if (!completedTimeline.containsInstant(instantTime)) {
      throw new HoodieException(s"Commit $instantTime not found in Commits $completedTimeline")
    }

    // rollback already returns a Boolean; the original `if (...) true else false`
    // was redundant.
    Seq(Row(client.rollback(instantTime)))
  }

  override def build: Procedure = new RollbackToInstantTimeProcedure()
}
|
||||
|
||||
object RollbackToInstantTimeProcedure {
  // Name under which this procedure is registered and invoked via CALL.
  val NAME: String = "rollback_to_instant"

  // Supplier used by the procedure registry to lazily construct instances.
  def builder: Supplier[ProcedureBuilder] = {
    new Supplier[ProcedureBuilder] {
      override def get(): RollbackToInstantTimeProcedure = new RollbackToInstantTimeProcedure
    }
  }
}
|
||||
@@ -0,0 +1,163 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi.command.procedures
|
||||
|
||||
import org.apache.hudi.common.model.HoodieCommitMetadata
|
||||
import org.apache.hudi.common.table.HoodieTableMetaClient
|
||||
import org.apache.hudi.common.table.timeline.{HoodieDefaultTimeline, HoodieInstant}
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
|
||||
import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}
|
||||
|
||||
import java.util
|
||||
import java.util.Collections
|
||||
import java.util.function.Supplier
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
class ShowCommitsProcedure(includeExtraMetadata: Boolean) extends BaseProcedure with ProcedureBuilder {
  // NOTE(review): this field is never read anywhere in this class — candidate for
  // removal once external callers are confirmed not to use it.
  var sortByFieldParameter: ProcedureParameter = _

  // `table` is mandatory; `limit` caps the number of returned rows (default 10).
  private val PARAMETERS = Array[ProcedureParameter](
    ProcedureParameter.required(0, "table", DataTypes.StringType, None),
    ProcedureParameter.optional(1, "limit", DataTypes.IntegerType, 10)
  )

  // Per-commit summary schema (used when includeExtraMetadata == false).
  private val OUTPUT_TYPE = new StructType(Array[StructField](
    StructField("commit_time", DataTypes.StringType, nullable = true, Metadata.empty),
    StructField("total_bytes_written", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_files_added", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_files_updated", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_partitions_written", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_records_written", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_update_records_written", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_errors", DataTypes.LongType, nullable = true, Metadata.empty)
  ))

  // Per-file-write schema (used when includeExtraMetadata == true); one row per
  // write stat rather than one row per commit.
  private val METADATA_OUTPUT_TYPE = new StructType(Array[StructField](
    StructField("commit_time", DataTypes.StringType, nullable = true, Metadata.empty),
    StructField("action", DataTypes.StringType, nullable = true, Metadata.empty),
    StructField("partition", DataTypes.StringType, nullable = true, Metadata.empty),
    StructField("file_id", DataTypes.StringType, nullable = true, Metadata.empty),
    StructField("previous_commit", DataTypes.StringType, nullable = true, Metadata.empty),
    StructField("num_writes", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("num_inserts", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("num_deletes", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("num_update_writes", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_errors", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_log_blocks", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_corrupt_logblocks", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_rollback_blocks", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_log_records", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_updated_records_compacted", DataTypes.LongType, nullable = true, Metadata.empty),
    StructField("total_bytes_written", DataTypes.LongType, nullable = true, Metadata.empty)
  ))

  def parameters: Array[ProcedureParameter] = PARAMETERS

  def outputType: StructType = if (includeExtraMetadata) METADATA_OUTPUT_TYPE else OUTPUT_TYPE

  /**
   * Lists the most recent completed commits of the given table.
   *
   * @param args procedure arguments (`table`, optional `limit`)
   * @return commit rows, newest first, at most `limit` of them
   */
  override def call(args: ProcedureArgs): Seq[Row] = {
    super.checkArgs(PARAMETERS, args)

    val table = getArgValueOrDefault(args, PARAMETERS(0)).asInstanceOf[String]
    val limit = getArgValueOrDefault(args, PARAMETERS(1)).asInstanceOf[Int]

    val hoodieCatalogTable = HoodieCatalogTable(sparkSession, new TableIdentifier(table))
    val basePath = hoodieCatalogTable.tableLocation
    val metaClient = HoodieTableMetaClient.builder.setConf(jsc.hadoopConfiguration()).setBasePath(basePath).build

    val activeTimeline = metaClient.getActiveTimeline
    if (includeExtraMetadata) {
      getCommitsWithMetadata(activeTimeline, limit)
    } else {
      getCommits(activeTimeline, limit)
    }
  }

  override def build: Procedure = new ShowCommitsProcedure(includeExtraMetadata)

  // One output row per (partition, file) write stat of each completed commit.
  private def getCommitsWithMetadata(timeline: HoodieDefaultTimeline,
                                     limit: Int): Seq[Row] = {
    val (rows: util.ArrayList[Row], newCommits: util.ArrayList[HoodieInstant]) = getSortCommits(timeline)

    for (i <- 0 until newCommits.size) {
      val commit = newCommits.get(i)
      val commitMetadata = HoodieCommitMetadata.fromBytes(timeline.getInstantDetails(commit).get, classOf[HoodieCommitMetadata])
      // Use explicit JavaConverters (asScala) instead of the deprecated
      // scala.collection.JavaConversions implicits.
      for (partitionWriteStat <- commitMetadata.getPartitionToWriteStats.entrySet.asScala) {
        for (hoodieWriteStat <- partitionWriteStat.getValue.asScala) {
          rows.add(Row(
            commit.getTimestamp, commit.getAction, hoodieWriteStat.getPartitionPath,
            hoodieWriteStat.getFileId, hoodieWriteStat.getPrevCommit, hoodieWriteStat.getNumWrites,
            hoodieWriteStat.getNumInserts, hoodieWriteStat.getNumDeletes, hoodieWriteStat.getNumUpdateWrites,
            hoodieWriteStat.getTotalWriteErrors, hoodieWriteStat.getTotalLogBlocks, hoodieWriteStat.getTotalCorruptLogBlock,
            hoodieWriteStat.getTotalRollbackBlocks, hoodieWriteStat.getTotalLogRecords,
            hoodieWriteStat.getTotalUpdatedRecordsCompacted, hoodieWriteStat.getTotalWriteBytes))
        }
      }
    }

    // Simpler than the original stream().limit(..).toArray().map(cast) chain.
    rows.asScala.take(limit).toList
  }

  // Returns an empty row buffer plus the completed commit instants, newest first.
  private def getSortCommits(timeline: HoodieDefaultTimeline): (util.ArrayList[Row], util.ArrayList[HoodieInstant]) = {
    val rows = new util.ArrayList[Row]
    // timeline can be read from multiple files. So sort is needed instead of reversing the collection
    val commits: util.List[HoodieInstant] = timeline.getCommitsTimeline.filterCompletedInstants
      .getInstants.iterator().asScala.toList.asJava
    val newCommits = new util.ArrayList[HoodieInstant](commits)
    Collections.sort(newCommits, HoodieInstant.COMPARATOR.reversed)
    (rows, newCommits)
  }

  // One output row per completed commit, summarizing its metadata totals.
  def getCommits(timeline: HoodieDefaultTimeline,
                 limit: Int): Seq[Row] = {
    val (rows: util.ArrayList[Row], newCommits: util.ArrayList[HoodieInstant]) = getSortCommits(timeline)

    for (i <- 0 until newCommits.size) {
      val commit = newCommits.get(i)
      val commitMetadata = HoodieCommitMetadata.fromBytes(timeline.getInstantDetails(commit).get, classOf[HoodieCommitMetadata])
      rows.add(Row(commit.getTimestamp, commitMetadata.fetchTotalBytesWritten, commitMetadata.fetchTotalFilesInsert,
        commitMetadata.fetchTotalFilesUpdated, commitMetadata.fetchTotalPartitionsWritten,
        commitMetadata.fetchTotalRecordsWritten, commitMetadata.fetchTotalUpdateRecordsWritten,
        commitMetadata.fetchTotalWriteErrors))
    }

    rows.asScala.take(limit).toList
  }
}
|
||||
|
||||
object ShowCommitsProcedure {
  // Registered procedure name for the summary (no extra metadata) variant.
  val NAME = "show_commits"

  def builder: Supplier[ProcedureBuilder] = {
    new Supplier[ProcedureBuilder] {
      // false => per-commit summary rows only.
      override def get() = new ShowCommitsProcedure(false)
    }
  }
}
|
||||
|
||||
object ShowCommitsMetadataProcedure {
  // Registered procedure name for the extended (per-write-stat) variant.
  val NAME = "show_commits_metadata"

  def builder: Supplier[ProcedureBuilder] = {
    new Supplier[ProcedureBuilder] {
      // true => one row per file write stat of each commit.
      override def get() = new ShowCommitsProcedure(true)
    }
  }
}
|
||||
|
||||
|
||||
@@ -17,22 +17,39 @@
|
||||
|
||||
package org.apache.spark.sql.parser
|
||||
|
||||
import org.antlr.v4.runtime.ParserRuleContext
|
||||
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
|
||||
import org.apache.hudi.SparkAdapterSupport
|
||||
import org.apache.hudi.spark.sql.parser.{HoodieSqlCommonBaseVisitor, HoodieSqlCommonParser}
|
||||
import org.apache.hudi.spark.sql.parser.HoodieSqlCommonParser.{CompactionOnPathContext, CompactionOnTableContext, ShowCompactionOnPathContext, ShowCompactionOnTableContext, SingleStatementContext, TableIdentifierContext}
|
||||
import org.apache.hudi.spark.sql.parser.HoodieSqlCommonBaseVisitor
|
||||
import org.apache.hudi.spark.sql.parser.HoodieSqlCommonParser._
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
|
||||
import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{CompactionOperation, CompactionPath, CompactionShowOnPath, CompactionShowOnTable, CompactionTable, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
|
||||
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface, ParserUtils}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
class HoodieSqlCommonAstBuilder(session: SparkSession, delegate: ParserInterface)
|
||||
extends HoodieSqlCommonBaseVisitor[AnyRef] with Logging with SparkAdapterSupport {
|
||||
|
||||
import ParserUtils._
|
||||
|
||||
/**
|
||||
* Override the default behavior for all visit methods. This will only return a non-null result
|
||||
* when the context has only one child. This is done because there is no generic method to
|
||||
* combine the results of the context children. In all other cases null is returned.
|
||||
*/
|
||||
override def visitChildren(node: RuleNode): AnyRef = {
  // There is no generic way to merge results from multiple siblings, so only a
  // single-child node is delegated; anything else yields null.
  if (node.getChildCount != 1) {
    null
  } else {
    node.getChild(0).accept(this)
  }
}
|
||||
|
||||
override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
  // Unwrap the single statement and visit it as a logical plan.
  val statement = ctx.statement()
  statement.accept(this).asInstanceOf[LogicalPlan]
}
|
||||
@@ -72,4 +89,62 @@ class HoodieSqlCommonAstBuilder(session: SparkSession, delegate: ParserInterface
|
||||
override def visitTableIdentifier(ctx: TableIdentifierContext): LogicalPlan = withOrigin(ctx) {
  // Build an unresolved relation from the table name and its optional database.
  val database = Option(ctx.db).map(_.getText)
  UnresolvedRelation(TableIdentifier(ctx.table.getText, database))
}
|
||||
|
||||
override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) {
  // The grammar permits zero arguments; an argument-less CALL is rejected here.
  if (ctx.callArgument().isEmpty) {
    // Fixed error-message grammar ("Procedures arguments is" -> "Procedure arguments
    // are") and dropped the pointless `s` interpolator on a constant string.
    throw new ParseException("Procedure arguments are empty", ctx)
  }

  // Procedure name parts (e.g. CALL a.b.func -> Seq("a", "b", "func")) and the
  // visited positional/named arguments.
  val name: Seq[String] = ctx.multipartIdentifier().parts.asScala.map(_.getText)
  val args: Seq[CallArgument] = ctx.callArgument().asScala.map(typedVisit[CallArgument])
  CallCommand(name, args)
}
|
||||
|
||||
/**
|
||||
* Return a multi-part identifier as Seq[String].
|
||||
*/
|
||||
override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) {
  // Each identifier part becomes one element of the resulting sequence.
  val parts = ctx.parts.asScala
  parts.map(_.getText)
}
|
||||
|
||||
/**
|
||||
* Create a positional argument in a stored procedure call.
|
||||
*/
|
||||
override def visitPositionalArgument(ctx: PositionalArgumentContext): CallArgument = withOrigin(ctx) {
  // A positional argument carries just the visited expression.
  PositionalArgument(typedVisit[Expression](ctx.expression))
}
|
||||
|
||||
/**
|
||||
* Create a named argument in a stored procedure call.
|
||||
*/
|
||||
override def visitNamedArgument(ctx: NamedArgumentContext): CallArgument = withOrigin(ctx) {
  // A named argument pairs the identifier text with its visited expression.
  NamedArgument(ctx.identifier.getText, typedVisit[Expression](ctx.expression))
}
|
||||
|
||||
// Parse a constant's raw text with the delegate Spark parser and narrow the
// result to a Literal.
// NOTE(review): not marked `override` — presumably the generated base visitor
// declares a different return type; confirm against HoodieSqlCommonBaseVisitor.
def visitConstant(ctx: ConstantContext): Literal = {
  delegate.parseExpression(ctx.getText).asInstanceOf[Literal]
}
|
||||
|
||||
// Delegate expression parsing to the main Spark parser.
override def visitExpression(ctx: ExpressionContext): Expression = {
  // reconstruct the SQL string and parse it using the main Spark parser
  // while we can avoid the logic to build Spark expressions, we still have to parse them
  // we cannot call ctx.getText directly since it will not render spaces correctly
  // that's why we need to recurse down the tree in reconstructSqlString
  val sqlString = reconstructSqlString(ctx)
  delegate.parseExpression(sqlString)
}
|
||||
|
||||
private def reconstructSqlString(ctx: ParserRuleContext): String = {
  // Re-render the parse tree with single spaces between tokens, recursing into
  // nested rule contexts; terminal nodes contribute their raw text.
  val pieces = ctx.children.asScala.map { child =>
    child match {
      case rule: ParserRuleContext => reconstructSqlString(rule)
      case terminal: TerminalNode => terminal.getText
    }
  }
  pieces.mkString(" ")
}
|
||||
|
||||
// Visit `ctx` and cast the result to the caller-expected type T.
private def typedVisit[T](ctx: ParseTree): T = {
  ctx.accept(this).asInstanceOf[T]
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi
|
||||
|
||||
import com.google.common.collect.ImmutableList
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{CallCommand, NamedArgument, PositionalArgument}
|
||||
import org.apache.spark.sql.types.{DataType, DataTypes}
|
||||
|
||||
import java.math.BigDecimal
|
||||
import scala.collection.JavaConverters
|
||||
|
||||
// Parser-level tests for the CALL command: verifies that positional, named and
// mixed argument forms produce the expected CallCommand plan.
class TestCallCommandParser extends TestHoodieSqlBase {
  private val parser = spark.sessionState.sqlParser

  test("Test Call Produce with Positional Arguments") {
    val call = parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("c", "n", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)

    assertResult(7)(call.args.size)

    checkArg(call, 0, 1, DataTypes.IntegerType)
    checkArg(call, 1, "2", DataTypes.StringType)
    checkArg(call, 2, 3L, DataTypes.LongType)
    checkArg(call, 3, true, DataTypes.BooleanType)
    checkArg(call, 4, 1.0D, DataTypes.DoubleType)
    checkArg(call, 5, new BigDecimal("9.0e1"), DataTypes.createDecimalType(2, 0))
    checkArg(call, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1))
  }

  test("Test Call Produce with Named Arguments") {
    val call = parser.parsePlan("CALL system.func(c1 => 1, c2 => '2', c3 => true)").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)

    assertResult(3)(call.args.size)

    checkArg(call, 0, "c1", 1, DataTypes.IntegerType)
    checkArg(call, 1, "c2", "2", DataTypes.StringType)
    checkArg(call, 2, "c3", true, DataTypes.BooleanType)
  }

  test("Test Call Produce with Var Substitution") {
    // NOTE(review): presumably TestHoodieSqlBase sets spark.extra.prop to "value"
    // so that variable substitution applies — confirm against the base class.
    val call = parser.parsePlan("CALL system.func('${spark.extra.prop}')").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)

    assertResult(1)(call.args.size)

    checkArg(call, 0, "value", DataTypes.StringType)
  }

  test("Test Call Produce with Mixed Arguments") {
    val call = parser.parsePlan("CALL system.func(c1 => 1, '2')").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)

    assertResult(2)(call.args.size)

    checkArg(call, 0, "c1", 1, DataTypes.IntegerType)
    checkArg(call, 1, "2", DataTypes.StringType)
  }

  test("Test Call Parse Error") {
    checkParseExceptionContain("CALL cat.system radish kebab")("mismatched input 'CALL' expecting")
  }

  // Asserts that parsing `sql` fails with a message containing `errorMsg`.
  protected def checkParseExceptionContain(sql: String)(errorMsg: String): Unit = {
    var hasException = false
    try {
      parser.parsePlan(sql)
    } catch {
      case e: Throwable =>
        assertResult(true)(e.getMessage.contains(errorMsg))
        hasException = true
    }
    assertResult(true)(hasException)
  }

  // Positional-argument overload: no name expected.
  private def checkArg(call: CallCommand, index: Int, expectedValue: Any, expectedType: DataType): Unit = {
    checkArg(call, index, null, expectedValue, expectedType)
  }

  // Checks the argument at `index`: its kind (named vs positional), its name when
  // named, and its expression's data type.
  // NOTE(review): only the data types of the expected and actual expressions are
  // compared — `expectedValue` is never asserted against the parsed literal's
  // value. Consider strengthening once the var-substitution behavior is pinned.
  private def checkArg(call: CallCommand, index: Int, expectedName: String, expectedValue: Any, expectedType: DataType): Unit = {
    if (expectedName != null) {
      val arg = checkCast(call.args.apply(index), classOf[NamedArgument])
      assertResult(expectedName)(arg.name)
    }
    else {
      val arg = call.args.apply(index)
      checkCast(arg, classOf[PositionalArgument])
    }
    val expectedExpr = toSparkLiteral(expectedValue, expectedType)
    val actualExpr = call.args.apply(index).expr
    assertResult(expectedExpr.dataType)(actualExpr.dataType)
  }

  // Builds the expected Spark Literal from a plain value and data type.
  private def toSparkLiteral(value: Any, dataType: DataType) = Literal.apply(value, dataType)

  // Asserts `value` is an instance of `expectedClass` and returns it cast.
  private def checkCast[T](value: Any, expectedClass: Class[T]) = {
    assertResult(true)(expectedClass.isInstance(value))
    expectedClass.cast(value)
  }
}
|
||||
@@ -0,0 +1,132 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hudi
|
||||
|
||||
// End-to-end tests for the built-in CALL procedures: show_commits,
// show_commits_metadata and rollback_to_instant.
class TestCallProcedure extends TestHoodieSqlBase {

  test("Test Call show_commits Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      createTable(tableName, s"${tmp.getCanonicalPath}/$tableName")
      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")

      // Check required fields
      checkExceptionContain(s"""call show_commits(limit => 10)""")(
        s"Argument: table is required")

      // collect commits for table: one per insert above
      val commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(2) {
        commits.length
      }
    }
  }

  test("Test Call show_commits_metadata Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      createTable(tableName, s"${tmp.getCanonicalPath}/$tableName")
      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")

      // Check required fields
      checkExceptionContain(s"""call show_commits_metadata(limit => 10)""")(
        s"Argument: table is required")

      // collect commits for table: single insert yields a single row
      val commits = spark.sql(s"""call show_commits_metadata(table => '$tableName', limit => 10)""").collect()
      assertResult(1) {
        commits.length
      }
    }
  }

  test("Test Call rollback_to_instant Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      createTable(tableName, s"${tmp.getCanonicalPath}/$tableName")
      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")
      spark.sql(s"insert into $tableName select 3, 'a3', 30, 2000")

      // Check required fields
      checkExceptionContain(s"""call rollback_to_instant(table => '$tableName')""")(
        s"Argument: instant_time is required")

      // 3 commits are left before rollback
      var commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(3){commits.length}

      // Call rollback_to_instant Procedure with Named Arguments
      var instant_time = commits(0).get(0).toString
      checkAnswer(s"""call rollback_to_instant(table => '$tableName', instant_time => '$instant_time')""")(Seq(true))
      // Call rollback_to_instant Procedure with Positional Arguments
      instant_time = commits(1).get(0).toString
      checkAnswer(s"""call rollback_to_instant('$tableName', '$instant_time')""")(Seq(true))

      // 1 commits are left after rollback
      commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(1){commits.length}
    }
  }

  // Shared DDL for every test in this suite (was duplicated verbatim three
  // times): a simple hudi table keyed by `id` with `ts` as the pre-combine
  // field, stored at `location`.
  private def createTable(tableName: String, location: String): Unit = {
    spark.sql(
      s"""
         |create table $tableName (
         |  id int,
         |  name string,
         |  price double,
         |  ts long
         |) using hudi
         | location '$location'
         | tblproperties (
         |  primaryKey = 'id',
         |  preCombineField = 'ts'
         | )
       """.stripMargin)
  }
}
|
||||
@@ -244,4 +244,4 @@
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
</project>
|
||||
|
||||
Reference in New Issue
Block a user