1
0

[HUDI-3161] Add Call Produce Command for Spark SQL (#4535)

This commit is contained in:
ForwardXu
2022-02-24 23:45:37 +08:00
committed by GitHub
parent 943b99775b
commit 521338b4d9
17 changed files with 1228 additions and 28 deletions

View File

@@ -14,59 +14,197 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
grammar HoodieSqlCommon;
grammar HoodieSqlCommon;
@lexer::members {
/**
* Verify whether current token is a valid decimal token (which contains dot).
* Returns true if the character that follows the token is not a digit or letter or underscore.
*
* For example:
* For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'.
* For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'.
* For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'.
* For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed
* by a space. 34.E2 is a valid decimal token because it is followed by symbol '+'
* which is not a digit or letter or underscore.
*/
public boolean isValidDecimal() {
int nextChar = _input.LA(1);
if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' ||
nextChar == '_') {
return false;
} else {
return true;
}
}
}
singleStatement
: statement EOF
;
statement
: compactionStatement #compactionCommand
| .*? #passThrough
: compactionStatement #compactionCommand
| CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call
| .*? #passThrough
;
compactionStatement
: operation = (RUN | SCHEDULE) COMPACTION ON tableIdentifier (AT instantTimestamp = NUMBER)? #compactionOnTable
| operation = (RUN | SCHEDULE) COMPACTION ON path = STRING (AT instantTimestamp = NUMBER)? #compactionOnPath
| SHOW COMPACTION ON tableIdentifier (LIMIT limit = NUMBER)? #showCompactionOnTable
| SHOW COMPACTION ON path = STRING (LIMIT limit = NUMBER)? #showCompactionOnPath
: operation = (RUN | SCHEDULE) COMPACTION ON tableIdentifier (AT instantTimestamp = INTEGER_VALUE)? #compactionOnTable
| operation = (RUN | SCHEDULE) COMPACTION ON path = STRING (AT instantTimestamp = INTEGER_VALUE)? #compactionOnPath
| SHOW COMPACTION ON tableIdentifier (LIMIT limit = INTEGER_VALUE)? #showCompactionOnTable
| SHOW COMPACTION ON path = STRING (LIMIT limit = INTEGER_VALUE)? #showCompactionOnPath
;
tableIdentifier
: (db=IDENTIFIER '.')? table=IDENTIFIER
;
callArgument
: expression #positionalArgument
| identifier '=>' expression #namedArgument
;
expression
: constant
| stringMap
;
constant
: number #numericLiteral
| booleanValue #booleanLiteral
| STRING+ #stringLiteral
| identifier STRING #typeConstructor
;
stringMap
: MAP '(' constant (',' constant)* ')'
;
booleanValue
: TRUE | FALSE
;
number
: MINUS? EXPONENT_VALUE #exponentLiteral
| MINUS? DECIMAL_VALUE #decimalLiteral
| MINUS? INTEGER_VALUE #integerLiteral
| MINUS? BIGINT_LITERAL #bigIntLiteral
| MINUS? SMALLINT_LITERAL #smallIntLiteral
| MINUS? TINYINT_LITERAL #tinyIntLiteral
| MINUS? DOUBLE_LITERAL #doubleLiteral
| MINUS? FLOAT_LITERAL #floatLiteral
| MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
;
multipartIdentifier
: parts+=identifier ('.' parts+=identifier)*
;
identifier
: IDENTIFIER #unquotedIdentifier
| quotedIdentifier #quotedIdentifierAlternative
| nonReserved #unquotedIdentifier
;
quotedIdentifier
: BACKQUOTED_IDENTIFIER
;
nonReserved
: CALL | COMPACTION | RUN | SCHEDULE | ON | SHOW | LIMIT
;
ALL: 'ALL';
AT: 'AT';
CALL: 'CALL';
COMPACTION: 'COMPACTION';
RUN: 'RUN';
SCHEDULE: 'SCHEDULE';
ON: 'ON';
SHOW: 'SHOW';
LIMIT: 'LIMIT';
MAP: 'MAP';
NULL: 'NULL';
TRUE: 'TRUE';
FALSE: 'FALSE';
INTERVAL: 'INTERVAL';
TO: 'TO';
NUMBER
: DIGIT+
;
IDENTIFIER
: (LETTER | DIGIT | '_')+
;
PLUS: '+';
MINUS: '-';
STRING
: '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
| '"' ( ~('"'|'\\') | ('\\' .) )* '"'
;
BIGINT_LITERAL
: DIGIT+ 'L'
;
SMALLINT_LITERAL
: DIGIT+ 'S'
;
TINYINT_LITERAL
: DIGIT+ 'Y'
;
INTEGER_VALUE
: DIGIT+
;
EXPONENT_VALUE
: DIGIT+ EXPONENT
| DECIMAL_DIGITS EXPONENT {isValidDecimal()}?
;
DECIMAL_VALUE
: DECIMAL_DIGITS {isValidDecimal()}?
;
FLOAT_LITERAL
: DIGIT+ EXPONENT? 'F'
| DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}?
;
DOUBLE_LITERAL
: DIGIT+ EXPONENT? 'D'
| DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}?
;
BIGDECIMAL_LITERAL
: DIGIT+ EXPONENT? 'BD'
| DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}?
;
IDENTIFIER
: (LETTER | DIGIT | '_')+
;
BACKQUOTED_IDENTIFIER
: '`' ( ~'`' | '``' )* '`'
;
fragment DECIMAL_DIGITS
: DIGIT+ '.' DIGIT*
| '.' DIGIT+
;
fragment EXPONENT
: 'E' [+-]? DIGIT+
;
fragment DIGIT
: [0-9]
;
: [0-9]
;
fragment LETTER
: [A-Z]
;
: [A-Z]
;
SIMPLE_COMMENT
: '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN)

View File

@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.Expression
/**
 * Logical plan for a `CALL procedure(args...)` statement.
 *
 * @param name the multi-part identifier of the procedure to invoke
 * @param args the arguments of the call, positional and/or named
 */
case class CallCommand(name: Seq[String], args: Seq[CallArgument]) extends Command {
  // A CALL statement is a leaf command: it never has child plans.
  override def children: Seq[LogicalPlan] = Seq.empty

  // With no children, any "new children" request returns this node unchanged.
  def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): CallCommand = {
    this
  }
}

/**
 * An argument in a CALL statement.
 */
sealed trait CallArgument {
  // Expression that produces this argument's value.
  def expr: Expression
}

/**
 * An argument in a CALL statement identified by name (`name => value`).
 */
case class NamedArgument(name: String, expr: Expression) extends CallArgument

/**
 * An argument in a CALL statement identified by position.
 */
case class PositionalArgument(expr: Expression) extends CallArgument

View File

@@ -22,7 +22,7 @@ import org.apache.hudi.common.model.HoodieRecord
import org.apache.hudi.common.util.ReflectionUtils
import org.apache.hudi.{HoodieSparkUtils, SparkAdapterSupport}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, GenericInternalRow, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -31,10 +31,12 @@ import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{getTableIdentifier, removeMetaFields}
import org.apache.spark.sql.hudi.HoodieSqlUtils._
import org.apache.spark.sql.hudi.command._
import org.apache.spark.sql.hudi.command.procedures.{HoodieProcedures, Procedure, ProcedureArgs}
import org.apache.spark.sql.hudi.{HoodieOptionConfig, HoodieSqlCommonUtils}
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{AnalysisException, SparkSession}
import java.util
import scala.collection.JavaConverters._
object HoodieAnalysis {
@@ -79,6 +81,7 @@ object HoodieAnalysis {
/**
* Rule for convert the logical plan to command.
*
* @param sparkSession
*/
case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan]
@@ -133,26 +136,69 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan]
// Convert to CompactionShowHoodiePathCommand
case CompactionShowOnPath(path, limit) =>
CompactionShowHoodiePathCommand(path, limit)
case _=> plan
// Convert to HoodieCallProcedureCommand
case c@CallCommand(_, _) =>
val procedure: Option[Procedure] = loadProcedure(c.name)
val input = buildProcedureArgs(c.args)
if (procedure.nonEmpty) {
CallProcedureHoodieCommand(procedure.get, input)
} else {
c
}
case _ => plan
}
}
/**
 * Resolves a stored procedure from the multi-part identifier of a CALL statement.
 * Only the last name part is used for the registry lookup.
 *
 * @param name the multi-part procedure identifier
 * @return the procedure, or None when the name is empty
 * @throws AnalysisException if no procedure is registered under the name
 */
private def loadProcedure(name: Seq[String]): Option[Procedure] = {
  if (name.isEmpty) {
    None
  } else {
    val builder = HoodieProcedures.newBuilder(name.last)
    if (builder == null) {
      // Fixed message grammar: was "is not exists".
      throw new AnalysisException(s"procedure: ${name.last} does not exist")
    }
    // Option(...) guards against a builder that returns null.
    Option(builder.build)
  }
}
/**
 * Evaluates the CALL arguments into a ProcedureArgs holder: one row slot per
 * argument plus a name/position -> slot map.
 */
private def buildProcedureArgs(exprs: Seq[CallArgument]): ProcedureArgs = {
  // One slot per argument, filled with the evaluated literal values.
  val values = new Array[Any](exprs.size)
  var isNamedArgs: Boolean = false
  // Maps argument name (or the positional index as a string) to its slot.
  val map = new util.LinkedHashMap[String, Int]()
  for (index <- exprs.indices) {
    exprs(index) match {
      case expr: NamedArgument =>
        map.put(expr.name, index)
        // eval() with no input row: arguments are expected to be literals.
        values(index) = expr.expr.eval()
        isNamedArgs = true
      case _ =>
        map.put(index.toString, index)
        values(index) = exprs(index).expr.eval()
        // NOTE(review): the flag is overwritten per argument, so mixing named
        // and positional arguments keeps only the LAST argument's kind —
        // confirm mixed calls are rejected elsewhere.
        isNamedArgs = false
    }
  }
  ProcedureArgs(isNamedArgs, map, new GenericInternalRow(values))
}
}
/**
* Rule for resolve hoodie's extended syntax or rewrite some logical plan.
*
* @param sparkSession
*/
case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan]
with SparkAdapterSupport {
private lazy val analyzer = sparkSession.sessionState.analyzer
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
// Resolve merge into
case mergeInto @ MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions)
if sparkAdapter.isHoodieTable(target, sparkSession) && target.resolved =>
val resolver = sparkSession.sessionState.conf.resolver
val resolvedSource = analyzer.execute(source)
def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = {
if (assignments.isEmpty) {
true
@@ -363,6 +409,7 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
/**
* Check if the query of the insert statement has already appended the meta
* fields, to avoid appending them twice.
*
* @param query
* @return
*/

View File

@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.hudi.command.procedures.{Procedure, ProcedureArgs}
import org.apache.spark.sql.{Row, SparkSession}
import scala.collection.Seq
/**
 * Runnable command that executes a resolved stored procedure with its arguments.
 *
 * @param procedure the procedure to execute
 * @param args      the aligned (positional or named) call arguments
 */
case class CallProcedureHoodieCommand(
    procedure: Procedure,
    args: ProcedureArgs) extends HoodieLeafRunnableCommand {

  // The command's output schema comes from the procedure's declared output type.
  override def output: Seq[Attribute] = procedure.outputType.toAttributes

  override def run(sparkSession: SparkSession): Seq[Row] = {
    procedure.call(args)
  }
}

View File

@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
import org.apache.hudi.client.SparkRDDWriteClient
import org.apache.hudi.client.common.HoodieSparkEngineContext
import org.apache.hudi.common.model.HoodieRecordPayload
import org.apache.hudi.config.{HoodieIndexConfig, HoodieWriteConfig}
import org.apache.hudi.index.HoodieIndex.IndexType
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
/**
 * Common plumbing for stored procedures: Spark/Hudi client construction and
 * argument lookup helpers over a ProcedureArgs holder.
 */
abstract class BaseProcedure extends Procedure {
  // Sentinel returned by getArgsIndex when an argument was not supplied.
  val INVALID_ARG_INDEX: Int = -1

  // Bound to the currently active session at construction time.
  val spark: SparkSession = SparkSession.active
  val jsc = new JavaSparkContext(spark.sparkContext)

  protected def sparkSession: SparkSession = spark

  /**
   * Creates a write client for the table at `basePath` using the default
   * write config from [[getWriteConfig]].
   */
  protected def createHoodieClient(jsc: JavaSparkContext, basePath: String): SparkRDDWriteClient[_ <: HoodieRecordPayload[_ <: AnyRef]] = {
    val config = getWriteConfig(basePath)
    new SparkRDDWriteClient(new HoodieSparkEngineContext(jsc), config)
  }

  /**
   * Default write config: BLOOM index, marker-based rollback disabled.
   */
  protected def getWriteConfig(basePath: String): HoodieWriteConfig = {
    HoodieWriteConfig.newBuilder
      .withPath(basePath)
      .withIndexConfig(HoodieIndexConfig.newBuilder.withIndexType(IndexType.BLOOM).build)
      .withRollbackUsingMarkers(false)
      .build
  }

  /**
   * Asserts every required parameter was supplied with a non-null value,
   * resolving by name for named calls and by position otherwise.
   */
  protected def checkArgs(target: Array[ProcedureParameter], args: ProcedureArgs): Unit = {
    val internalRow = args.internalRow
    for (i <- target.indices) {
      if (target(i).required) {
        var argsIndex: Integer = null
        if (args.isNamedArgs) {
          argsIndex = getArgsIndex(target(i).name, args)
        } else {
          argsIndex = getArgsIndex(i.toString, args)
        }
        assert(-1 != argsIndex && internalRow.get(argsIndex, target(i).dataType) != null,
          s"Argument: ${target(i).name} is required")
      }
    }
  }

  /** Returns the row slot for `key`, or INVALID_ARG_INDEX when absent. */
  protected def getArgsIndex(key: String, args: ProcedureArgs): Integer = {
    args.map.getOrDefault(key, INVALID_ARG_INDEX)
  }

  /**
   * Returns the supplied value for `parameter`, or its declared default when
   * the argument was not passed.
   */
  protected def getArgValueOrDefault(args: ProcedureArgs, parameter: ProcedureParameter): Any = {
    var argsIndex: Int = INVALID_ARG_INDEX
    if (args.isNamedArgs) {
      argsIndex = getArgsIndex(parameter.name, args)
    } else {
      argsIndex = getArgsIndex(parameter.index.toString, args)
    }
    if (argsIndex.equals(INVALID_ARG_INDEX)) parameter.default else getInternalRowValue(args.internalRow, argsIndex, parameter.dataType)
  }

  /**
   * Reads the value at `index` from `row` as the given Spark SQL data type.
   *
   * @throws UnsupportedOperationException for data types not handled below
   */
  protected def getInternalRowValue(row: InternalRow, index: Int, dataType: DataType): Any = {
    dataType match {
      case StringType => row.getString(index)
      case BinaryType => row.getBinary(index)
      case BooleanType => row.getBoolean(index)
      case CalendarIntervalType => row.getInterval(index)
      case DoubleType => row.getDouble(index)
      case d: DecimalType => row.getDecimal(index, d.precision, d.scale)
      case FloatType => row.getFloat(index)
      case ByteType => row.getByte(index)
      case IntegerType => row.getInt(index)
      case LongType => row.getLong(index)
      case ShortType => row.getShort(index)
      case NullType => null
      case _ =>
        throw new UnsupportedOperationException(s"type: ${dataType.typeName} not supported")
    }
  }
}

View File

@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
import com.google.common.collect.ImmutableMap
import java.util
import java.util.Locale
import java.util.function.Supplier
/**
 * Registry of the stored procedures available to the CALL command.
 */
object HoodieProcedures {
  // Procedure name (lower case) -> supplier of its builder.
  private val BUILDERS: util.Map[String, Supplier[ProcedureBuilder]] = initProcedureBuilders

  /**
   * Returns a builder for the named procedure, or null when the name is
   * unknown. Lookup is case-insensitive (names are stored lower-cased).
   */
  def newBuilder(name: String): ProcedureBuilder = {
    val builderSupplier: Supplier[ProcedureBuilder] = BUILDERS.get(name.toLowerCase(Locale.ROOT))
    if (builderSupplier != null) builderSupplier.get else null
  }

  // Builds the immutable registry; new procedures are registered here.
  private def initProcedureBuilders: util.Map[String, Supplier[ProcedureBuilder]] = {
    val mapBuilder: ImmutableMap.Builder[String, Supplier[ProcedureBuilder]] = ImmutableMap.builder()
    mapBuilder.put(ShowCommitsProcedure.NAME, ShowCommitsProcedure.builder)
    mapBuilder.put(ShowCommitsMetadataProcedure.NAME, ShowCommitsMetadataProcedure.builder)
    mapBuilder.put(RollbackToInstantTimeProcedure.NAME, RollbackToInstantTimeProcedure.builder)
    mapBuilder.build
  }
}

View File

@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
import java.util
import scala.collection.mutable
/**
* An interface representing a stored procedure available for execution.
*/
/**
 * An interface representing a stored procedure available for execution.
 */
trait Procedure {
  /**
   * Returns the input parameters of this procedure.
   */
  def parameters: Array[ProcedureParameter]

  /**
   * Returns the type of rows produced by this procedure.
   */
  def outputType: StructType

  /**
   * Executes this procedure.
   *
   * Spark aligns the provided arguments against [[parameters]] either by
   * position or by name before execution.
   *
   * Implementations may summarize execution by returning one or more rows;
   * the schema of the returned rows must match [[outputType]].
   *
   * @param args input arguments
   * @return the result of executing this procedure with the given arguments
   */
  def call(args: ProcedureArgs): Seq[Row]

  /**
   * Returns the description of this procedure (defaults to the class name).
   */
  def description: String = this.getClass.toString
}

View File

@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
import org.apache.spark.sql.catalyst.InternalRow
import java.util
/**
 * Holder for the evaluated arguments of a CALL statement.
 *
 * @param isNamedArgs whether arguments were passed by name (`name => value`)
 * @param map         argument name (or positional index as a string) -> slot in `internalRow`
 * @param internalRow the evaluated argument values, one slot per argument
 */
case class ProcedureArgs(isNamedArgs: Boolean,
                         map: util.LinkedHashMap[String, Int],
                         internalRow: InternalRow) {
}

View File

@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
/**
 * Factory for [[Procedure]] instances, registered in the procedure registry.
 */
trait ProcedureBuilder {
  /** Builds a ready-to-execute procedure instance. */
  def build: Procedure
}

View File

@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
import org.apache.spark.sql.types.DataType
/**
* An input parameter of a {@link Procedure stored procedure}.
*/
/**
 * An input parameter of a {@link Procedure stored procedure}.
 */
abstract class ProcedureParameter {
  /**
   * Returns the positional index of this parameter in the procedure signature.
   */
  def index: Int

  /**
   * Returns the name of this parameter.
   */
  def name: String

  /**
   * Returns the type of this parameter.
   */
  def dataType: DataType

  /**
   * Returns true if this parameter is required.
   */
  def required: Boolean

  /**
   * Returns this parameter's default value, used when no argument is supplied.
   */
  def default: Any
}

object ProcedureParameter {
  /**
   * Creates a required input parameter.
   *
   * @param index    the position of the parameter
   * @param name     the name of the parameter
   * @param dataType the type of the parameter
   * @param default  the default value of the parameter
   * @return the constructed stored procedure parameter
   */
  def required(index: Int, name: String, dataType: DataType, default: Any): ProcedureParameterImpl = {
    ProcedureParameterImpl(index, name, dataType, default, required = true)
  }

  /**
   * Creates an optional input parameter.
   *
   * @param index    the position of the parameter
   * @param name     the name of the parameter
   * @param dataType the type of the parameter
   * @param default  the default value of the parameter
   * @return the constructed optional stored procedure parameter
   */
  def optional(index: Int, name: String, dataType: DataType, default: Any): ProcedureParameterImpl = {
    ProcedureParameterImpl(index, name, dataType, default, required = false)
  }
}

View File

@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
import org.apache.spark.sql.types.DataType
import java.util.Objects
/**
 * Concrete procedure parameter.
 *
 * @param index    position of the parameter in the procedure signature
 * @param name     parameter name
 * @param dataType Spark SQL type of the parameter
 * @param default  default value used when the argument is absent
 * @param required whether the caller must supply this parameter
 */
case class ProcedureParameterImpl(index: Int, name: String, dataType: DataType, default: Any, required: Boolean)
  extends ProcedureParameter {

  /**
   * Structural equality over all fields.
   *
   * Fixes two defects in the original:
   *  - `this == other` re-entered this overridden equals, recursing infinitely;
   *    the intended fast path is reference equality (`eq`).
   *  - `other` was cast to ProcedureParameterImpl before the class check, so an
   *    unrelated argument type threw ClassCastException instead of returning false.
   */
  override def equals(other: Any): Boolean = {
    if (this eq other.asInstanceOf[AnyRef]) {
      true
    } else if (other == null || (getClass ne other.getClass)) {
      false
    } else {
      val that = other.asInstanceOf[ProcedureParameterImpl]
      index == that.index && required == that.required && default == that.default &&
        Objects.equals(name, that.name) && Objects.equals(dataType, that.dataType)
    }
  }

  // Consistent with equals: derived from the same fields.
  override def hashCode: Int = Seq(index, name, dataType, required, default).hashCode()

  override def toString: String = s"ProcedureParameter(index='$index',name='$name', type=$dataType, required=$required, default=$default)"
}

View File

@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
import org.apache.hudi.common.table.HoodieTableMetaClient
import org.apache.hudi.common.table.timeline.HoodieTimeline
import org.apache.hudi.common.table.timeline.versioning.TimelineLayoutVersion
import org.apache.hudi.common.util.Option
import org.apache.hudi.exception.HoodieException
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}
import java.util.function.Supplier
/**
 * Stored procedure `rollback_to_instant(table, instant_time)`: rolls back the
 * given Hudi table to the specified completed commit instant and reports the
 * rollback result as a single boolean row.
 */
class RollbackToInstantTimeProcedure extends BaseProcedure with ProcedureBuilder {
  // Both parameters are required: the table name and the instant to roll back.
  private val PARAMETERS = Array[ProcedureParameter](
    ProcedureParameter.required(0, "table", DataTypes.StringType, None),
    ProcedureParameter.required(1, "instant_time", DataTypes.StringType, None))

  // Single-column output: whether the rollback succeeded.
  private val OUTPUT_TYPE = new StructType(Array[StructField](
    StructField("rollback_result", DataTypes.BooleanType, nullable = true, Metadata.empty))
  )

  def parameters: Array[ProcedureParameter] = PARAMETERS

  def outputType: StructType = OUTPUT_TYPE

  /**
   * Executes the rollback.
   *
   * @param args the table name and instant time
   * @return one row containing the boolean rollback result
   * @throws HoodieException if the instant is not a completed commit
   */
  override def call(args: ProcedureArgs): Seq[Row] = {
    super.checkArgs(PARAMETERS, args)

    val table = getArgValueOrDefault(args, PARAMETERS(0)).asInstanceOf[String]
    val instantTime = getArgValueOrDefault(args, PARAMETERS(1)).asInstanceOf[String]

    val hoodieCatalogTable = HoodieCatalogTable(sparkSession, new TableIdentifier(table))
    val basePath = hoodieCatalogTable.tableLocation
    // NOTE(review): the write client is never closed here — confirm whether it
    // should be released in a try/finally.
    val client = createHoodieClient(jsc, basePath)
    val config = getWriteConfig(basePath)
    val metaClient = HoodieTableMetaClient.builder
      .setConf(jsc.hadoopConfiguration)
      .setBasePath(config.getBasePath)
      .setLoadActiveTimelineOnLoad(false)
      .setConsistencyGuardConfig(config.getConsistencyGuardConfig)
      .setLayoutVersion(Option.of(new TimelineLayoutVersion(config.getTimelineLayoutVersion)))
      .build

    val activeTimeline = metaClient.getActiveTimeline
    val completedTimeline: HoodieTimeline = activeTimeline.getCommitsTimeline.filterCompletedInstants
    if (!completedTimeline.containsInstant(instantTime)) {
      throw new HoodieException(s"Commit $instantTime not found in Commits $completedTimeline")
    }

    // rollback() already returns a Boolean; the original wrapped it in a
    // redundant `if (...) true else false`.
    Seq(Row(client.rollback(instantTime)))
  }

  override def build: Procedure = new RollbackToInstantTimeProcedure()
}

object RollbackToInstantTimeProcedure {
  // Name under which this procedure is registered in HoodieProcedures.
  val NAME: String = "rollback_to_instant"

  def builder: Supplier[ProcedureBuilder] = new Supplier[ProcedureBuilder] {
    override def get(): RollbackToInstantTimeProcedure = new RollbackToInstantTimeProcedure()
  }
}

View File

@@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.procedures
import org.apache.hudi.common.model.HoodieCommitMetadata
import org.apache.hudi.common.table.HoodieTableMetaClient
import org.apache.hudi.common.table.timeline.{HoodieDefaultTimeline, HoodieInstant}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable
import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}
import java.util
import java.util.Collections
import java.util.function.Supplier
import scala.collection.JavaConverters._
class ShowCommitsProcedure(includeExtraMetadata: Boolean) extends BaseProcedure with ProcedureBuilder {
var sortByFieldParameter: ProcedureParameter = _
private val PARAMETERS = Array[ProcedureParameter](
ProcedureParameter.required(0, "table", DataTypes.StringType, None),
ProcedureParameter.optional(1, "limit", DataTypes.IntegerType, 10)
)
private val OUTPUT_TYPE = new StructType(Array[StructField](
StructField("commit_time", DataTypes.StringType, nullable = true, Metadata.empty),
StructField("total_bytes_written", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_files_added", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_files_updated", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_partitions_written", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_records_written", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_update_records_written", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_errors", DataTypes.LongType, nullable = true, Metadata.empty)
))
private val METADATA_OUTPUT_TYPE = new StructType(Array[StructField](
StructField("commit_time", DataTypes.StringType, nullable = true, Metadata.empty),
StructField("action", DataTypes.StringType, nullable = true, Metadata.empty),
StructField("partition", DataTypes.StringType, nullable = true, Metadata.empty),
StructField("file_id", DataTypes.StringType, nullable = true, Metadata.empty),
StructField("previous_commit", DataTypes.StringType, nullable = true, Metadata.empty),
StructField("num_writes", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("num_inserts", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("num_deletes", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("num_update_writes", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_errors", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_log_blocks", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_corrupt_logblocks", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_rollback_blocks", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_log_records", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_updated_records_compacted", DataTypes.LongType, nullable = true, Metadata.empty),
StructField("total_bytes_written", DataTypes.LongType, nullable = true, Metadata.empty)
))
  // Declared parameters of this procedure (PARAMETERS is defined above; based on how
  // call() reads them, index 0 is presumably the table name and index 1 the row limit
  // — confirm against the PARAMETERS declaration).
  def parameters: Array[ProcedureParameter] = PARAMETERS
  // Output schema depends on whether per-file write-stat metadata was requested.
  def outputType: StructType = if (includeExtraMetadata) METADATA_OUTPUT_TYPE else OUTPUT_TYPE
override def call(args: ProcedureArgs): Seq[Row] = {
super.checkArgs(PARAMETERS, args)
val table = getArgValueOrDefault(args, PARAMETERS(0)).asInstanceOf[String]
val limit = getArgValueOrDefault(args, PARAMETERS(1)).asInstanceOf[Int]
val hoodieCatalogTable = HoodieCatalogTable(sparkSession, new TableIdentifier(table))
val basePath = hoodieCatalogTable.tableLocation
val metaClient = HoodieTableMetaClient.builder.setConf(jsc.hadoopConfiguration()).setBasePath(basePath).build
val activeTimeline = metaClient.getActiveTimeline
if (includeExtraMetadata) {
getCommitsWithMetadata(activeTimeline, limit)
} else {
getCommits(activeTimeline, limit)
}
}
  // Factory hook for the procedure registry: builds a fresh instance with the same metadata flag.
  override def build: Procedure = new ShowCommitsProcedure(includeExtraMetadata)
private def getCommitsWithMetadata(timeline: HoodieDefaultTimeline,
limit: Int): Seq[Row] = {
import scala.collection.JavaConversions._
val (rows: util.ArrayList[Row], newCommits: util.ArrayList[HoodieInstant]) = getSortCommits(timeline)
for (i <- 0 until newCommits.size) {
val commit = newCommits.get(i)
val commitMetadata = HoodieCommitMetadata.fromBytes(timeline.getInstantDetails(commit).get, classOf[HoodieCommitMetadata])
for (partitionWriteStat <- commitMetadata.getPartitionToWriteStats.entrySet) {
for (hoodieWriteStat <- partitionWriteStat.getValue) {
rows.add(Row(
commit.getTimestamp, commit.getAction, hoodieWriteStat.getPartitionPath,
hoodieWriteStat.getFileId, hoodieWriteStat.getPrevCommit, hoodieWriteStat.getNumWrites,
hoodieWriteStat.getNumInserts, hoodieWriteStat.getNumDeletes, hoodieWriteStat.getNumUpdateWrites,
hoodieWriteStat.getTotalWriteErrors, hoodieWriteStat.getTotalLogBlocks, hoodieWriteStat.getTotalCorruptLogBlock,
hoodieWriteStat.getTotalRollbackBlocks, hoodieWriteStat.getTotalLogRecords,
hoodieWriteStat.getTotalUpdatedRecordsCompacted, hoodieWriteStat.getTotalWriteBytes))
}
}
}
rows.stream().limit(limit).toArray().map(r => r.asInstanceOf[Row]).toList
}
private def getSortCommits(timeline: HoodieDefaultTimeline): (util.ArrayList[Row], util.ArrayList[HoodieInstant]) = {
val rows = new util.ArrayList[Row]
// timeline can be read from multiple files. So sort is needed instead of reversing the collection
val commits: util.List[HoodieInstant] = timeline.getCommitsTimeline.filterCompletedInstants
.getInstants.toArray().map(instant => instant.asInstanceOf[HoodieInstant]).toList.asJava
val newCommits = new util.ArrayList[HoodieInstant](commits)
Collections.sort(newCommits, HoodieInstant.COMPARATOR.reversed)
(rows, newCommits)
}
def getCommits(timeline: HoodieDefaultTimeline,
limit: Int): Seq[Row] = {
val (rows: util.ArrayList[Row], newCommits: util.ArrayList[HoodieInstant]) = getSortCommits(timeline)
for (i <- 0 until newCommits.size) {
val commit = newCommits.get(i)
val commitMetadata = HoodieCommitMetadata.fromBytes(timeline.getInstantDetails(commit).get, classOf[HoodieCommitMetadata])
rows.add(Row(commit.getTimestamp, commitMetadata.fetchTotalBytesWritten, commitMetadata.fetchTotalFilesInsert,
commitMetadata.fetchTotalFilesUpdated, commitMetadata.fetchTotalPartitionsWritten,
commitMetadata.fetchTotalRecordsWritten, commitMetadata.fetchTotalUpdateRecordsWritten,
commitMetadata.fetchTotalWriteErrors))
}
rows.stream().limit(limit).toArray().map(r => r.asInstanceOf[Row]).toList
}
}
object ShowCommitsProcedure {
  // SQL-visible name of this procedure: CALL show_commits(...)
  val NAME = "show_commits"
  // Builder for the procedure registry; plain show_commits omits per-file metadata.
  def builder: Supplier[ProcedureBuilder] = new Supplier[ProcedureBuilder] {
    override def get() = new ShowCommitsProcedure(false)
  }
}
object ShowCommitsMetadataProcedure {
  // SQL-visible name of this procedure: CALL show_commits_metadata(...)
  val NAME = "show_commits_metadata"
  // Builder for the procedure registry; includes per-file write-stat metadata in the output.
  def builder: Supplier[ProcedureBuilder] = new Supplier[ProcedureBuilder] {
    override def get() = new ShowCommitsProcedure(true)
  }
}

View File

@@ -17,22 +17,39 @@
package org.apache.spark.sql.parser
import org.antlr.v4.runtime.ParserRuleContext
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
import org.apache.hudi.SparkAdapterSupport
import org.apache.hudi.spark.sql.parser.{HoodieSqlCommonBaseVisitor, HoodieSqlCommonParser}
import org.apache.hudi.spark.sql.parser.HoodieSqlCommonParser.{CompactionOnPathContext, CompactionOnTableContext, ShowCompactionOnPathContext, ShowCompactionOnTableContext, SingleStatementContext, TableIdentifierContext}
import org.apache.hudi.spark.sql.parser.HoodieSqlCommonBaseVisitor
import org.apache.hudi.spark.sql.parser.HoodieSqlCommonParser._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils}
import org.apache.spark.sql.catalyst.plans.logical.{CompactionOperation, CompactionPath, CompactionShowOnPath, CompactionShowOnTable, CompactionTable, LogicalPlan}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface, ParserUtils}
import org.apache.spark.sql.catalyst.plans.logical._
import scala.collection.JavaConverters._
class HoodieSqlCommonAstBuilder(session: SparkSession, delegate: ParserInterface)
extends HoodieSqlCommonBaseVisitor[AnyRef] with Logging with SparkAdapterSupport {
import ParserUtils._
/**
* Override the default behavior for all visit methods. This will only return a non-null result
* when the context has only one child. This is done because there is no generic method to
* combine the results of the context children. In all other cases null is returned.
*/
override def visitChildren(node: RuleNode): AnyRef = {
if (node.getChildCount == 1) {
node.getChild(0).accept(this)
} else {
null
}
}
  /**
   * Entry point for a parsed statement: visit the single child statement and
   * expect it to produce a [[LogicalPlan]].
   */
  override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
    ctx.statement().accept(this).asInstanceOf[LogicalPlan]
  }
@@ -72,4 +89,62 @@ class HoodieSqlCommonAstBuilder(session: SparkSession, delegate: ParserInterface
override def visitTableIdentifier(ctx: TableIdentifierContext): LogicalPlan = withOrigin(ctx) {
UnresolvedRelation(TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)))
}
override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) {
if (ctx.callArgument().isEmpty) {
throw new ParseException(s"Procedures arguments is empty", ctx)
}
val name: Seq[String] = ctx.multipartIdentifier().parts.asScala.map(_.getText)
val args: Seq[CallArgument] = ctx.callArgument().asScala.map(typedVisit[CallArgument])
CallCommand(name, args)
}
  /**
   * Return a multi-part identifier as Seq[String],
   * e.g. {{{db.schema.proc}}} becomes {{{Seq("db", "schema", "proc")}}}.
   */
  override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) {
    ctx.parts.asScala.map(_.getText)
  }
/**
* Create a positional argument in a stored procedure call.
*/
override def visitPositionalArgument(ctx: PositionalArgumentContext): CallArgument = withOrigin(ctx) {
val expr = typedVisit[Expression](ctx.expression)
PositionalArgument(expr)
}
/**
* Create a named argument in a stored procedure call.
*/
override def visitNamedArgument(ctx: NamedArgumentContext): CallArgument = withOrigin(ctx) {
val name = ctx.identifier.getText
val expr = typedVisit[Expression](ctx.expression)
NamedArgument(name, expr)
}
  /**
   * Parse a constant token through the delegate Spark parser so the resulting
   * [[Literal]] follows Spark's own literal rules (type suffixes, decimals, etc.).
   */
  def visitConstant(ctx: ConstantContext): Literal = {
    delegate.parseExpression(ctx.getText).asInstanceOf[Literal]
  }
override def visitExpression(ctx: ExpressionContext): Expression = {
// reconstruct the SQL string and parse it using the main Spark parser
// while we can avoid the logic to build Spark expressions, we still have to parse them
// we cannot call ctx.getText directly since it will not render spaces correctly
// that's why we need to recurse down the tree in reconstructSqlString
val sqlString = reconstructSqlString(ctx)
delegate.parseExpression(sqlString)
}
private def reconstructSqlString(ctx: ParserRuleContext): String = {
ctx.children.asScala.map {
case c: ParserRuleContext => reconstructSqlString(c)
case t: TerminalNode => t.getText
}.mkString(" ")
}
  // Visit a subtree and cast the result to the caller's expected node type.
  private def typedVisit[T](ctx: ParseTree): T = {
    ctx.accept(this).asInstanceOf[T]
  }
}

View File

@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import com.google.common.collect.ImmutableList
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{CallCommand, NamedArgument, PositionalArgument}
import org.apache.spark.sql.types.{DataType, DataTypes}
import java.math.BigDecimal
import scala.collection.JavaConverters
class TestCallCommandParser extends TestHoodieSqlBase {
  // Use the session's full SQL parser so CALL statements go through the Hudi extensions.
  private val parser = spark.sessionState.sqlParser

  test("Test Call Produce with Positional Arguments") {
    val call = parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("c", "n", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)
    assertResult(7)(call.args.size)

    // One (value, type) pair per positional argument, in declaration order.
    val expectedArgs: Seq[(Any, DataType)] = Seq(
      (1, DataTypes.IntegerType),
      ("2", DataTypes.StringType),
      (3L, DataTypes.LongType),
      (true, DataTypes.BooleanType),
      (1.0D, DataTypes.DoubleType),
      (new BigDecimal("9.0e1"), DataTypes.createDecimalType(2, 0)),
      (new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1)))
    expectedArgs.zipWithIndex.foreach { case ((value, dataType), pos) =>
      checkArg(call, pos, value, dataType)
    }
  }

  test("Test Call Produce with Named Arguments") {
    val call = parser.parsePlan("CALL system.func(c1 => 1, c2 => '2', c3 => true)").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)
    assertResult(3)(call.args.size)

    checkArg(call, 0, "c1", 1, DataTypes.IntegerType)
    checkArg(call, 1, "c2", "2", DataTypes.StringType)
    checkArg(call, 2, "c3", true, DataTypes.BooleanType)
  }

  test("Test Call Produce with Var Substitution") {
    val call = parser.parsePlan("CALL system.func('${spark.extra.prop}')").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)
    assertResult(1)(call.args.size)

    checkArg(call, 0, "value", DataTypes.StringType)
  }

  test("Test Call Produce with Mixed Arguments") {
    val call = parser.parsePlan("CALL system.func(c1 => 1, '2')").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)
    assertResult(2)(call.args.size)

    checkArg(call, 0, "c1", 1, DataTypes.IntegerType)
    checkArg(call, 1, "2", DataTypes.StringType)
  }

  test("Test Call Parse Error") {
    checkParseExceptionContain("CALL cat.system radish kebab")("mismatched input 'CALL' expecting")
  }

  // Asserts that parsing `sql` throws and that the message contains `errorMsg`.
  protected def checkParseExceptionContain(sql: String)(errorMsg: String): Unit = {
    val thrown =
      try {
        parser.parsePlan(sql)
        false
      } catch {
        case e: Throwable =>
          assertResult(true)(e.getMessage.contains(errorMsg))
          true
      }
    assertResult(true)(thrown)
  }

  // Positional-argument variant: no expected name to verify.
  private def checkArg(call: CallCommand, index: Int, expectedValue: Any, expectedType: DataType): Unit =
    checkArg(call, index, null, expectedValue, expectedType)

  // Verifies the argument kind (named vs positional), its name when given, and the
  // data type of the parsed literal.
  // NOTE(review): only the literal's data type is compared here — the expected literal is
  // built but its *value* is never asserted; confirm whether value equality should be added.
  private def checkArg(call: CallCommand, index: Int, expectedName: String, expectedValue: Any, expectedType: DataType): Unit = {
    Option(expectedName) match {
      case Some(name) =>
        val named = checkCast(call.args.apply(index), classOf[NamedArgument])
        assertResult(name)(named.name)
      case None =>
        checkCast(call.args.apply(index), classOf[PositionalArgument])
    }
    val expectedExpr = toSparkLiteral(expectedValue, expectedType)
    val actualExpr = call.args.apply(index).expr
    assertResult(expectedExpr.dataType)(actualExpr.dataType)
  }

  // Builds the Spark literal the parser is expected to produce.
  private def toSparkLiteral(value: Any, dataType: DataType): Literal = Literal.apply(value, dataType)

  // Asserts `value` is an instance of `expectedClass` and returns it downcast.
  private def checkCast[T](value: Any, expectedClass: Class[T]): T = {
    assertResult(true)(expectedClass.isInstance(value))
    expectedClass.cast(value)
  }
}

View File

@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
class TestCallProcedure extends TestHoodieSqlBase {

  test("Test Call show_commits Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      createTable(tableName, s"${tmp.getCanonicalPath}/$tableName")

      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")

      // Check required fields
      checkExceptionContain(s"""call show_commits(limit => 10)""")(
        s"Argument: table is required")

      // collect commits for table
      val commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(2) {
        commits.length
      }
    }
  }

  test("Test Call show_commits_metadata Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      createTable(tableName, s"${tmp.getCanonicalPath}/$tableName")

      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")

      // Check required fields
      checkExceptionContain(s"""call show_commits_metadata(limit => 10)""")(
        s"Argument: table is required")

      // collect commits for table
      val commits = spark.sql(s"""call show_commits_metadata(table => '$tableName', limit => 10)""").collect()
      assertResult(1) {
        commits.length
      }
    }
  }

  test("Test Call rollback_to_instant Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      createTable(tableName, s"${tmp.getCanonicalPath}/$tableName")

      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")
      spark.sql(s"insert into $tableName select 3, 'a3', 30, 2000")

      // Check required fields
      checkExceptionContain(s"""call rollback_to_instant(table => '$tableName')""")(
        s"Argument: instant_time is required")

      // 3 commits are left before rollback
      var commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(3){commits.length}

      // Call rollback_to_instant Procedure with Named Arguments
      var instant_time = commits(0).get(0).toString
      checkAnswer(s"""call rollback_to_instant(table => '$tableName', instant_time => '$instant_time')""")(Seq(true))

      // Call rollback_to_instant Procedure with Positional Arguments
      instant_time = commits(1).get(0).toString
      checkAnswer(s"""call rollback_to_instant('$tableName', '$instant_time')""")(Seq(true))

      // 1 commits are left after rollback
      commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(1){commits.length}
    }
  }

  // Creates the Hudi table shared by every test in this suite (four-column schema,
  // primary key `id`, pre-combine field `ts`) at the given location.
  // Fix: this DDL was duplicated verbatim in all three tests.
  private def createTable(tableName: String, location: String): Unit = {
    spark.sql(
      s"""
         |create table $tableName (
         |  id int,
         |  name string,
         |  price double,
         |  ts long
         |) using hudi
         | location '$location'
         | tblproperties (
         |  primaryKey = 'id',
         |  preCombineField = 'ts'
         | )
       """.stripMargin)
  }
}

View File

@@ -244,4 +244,4 @@
</dependency>
</dependencies>
</project>
</project>