diff --git a/hudi-spark-datasource/hudi-spark/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlCommon.g4 b/hudi-spark-datasource/hudi-spark/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlCommon.g4 index f2c5a562a..0cde14a4e 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlCommon.g4 +++ b/hudi-spark-datasource/hudi-spark/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlCommon.g4 @@ -14,59 +14,197 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -grammar HoodieSqlCommon; + + grammar HoodieSqlCommon; + + @lexer::members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } +} singleStatement : statement EOF ; statement - : compactionStatement #compactionCommand - | .*? #passThrough + : compactionStatement #compactionCommand + | CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call + | .*? #passThrough ; compactionStatement - : operation = (RUN | SCHEDULE) COMPACTION ON tableIdentifier (AT instantTimestamp = NUMBER)? 
#compactionOnTable - | operation = (RUN | SCHEDULE) COMPACTION ON path = STRING (AT instantTimestamp = NUMBER)? #compactionOnPath - | SHOW COMPACTION ON tableIdentifier (LIMIT limit = NUMBER)? #showCompactionOnTable - | SHOW COMPACTION ON path = STRING (LIMIT limit = NUMBER)? #showCompactionOnPath + : operation = (RUN | SCHEDULE) COMPACTION ON tableIdentifier (AT instantTimestamp = INTEGER_VALUE)? #compactionOnTable + | operation = (RUN | SCHEDULE) COMPACTION ON path = STRING (AT instantTimestamp = INTEGER_VALUE)? #compactionOnPath + | SHOW COMPACTION ON tableIdentifier (LIMIT limit = INTEGER_VALUE)? #showCompactionOnTable + | SHOW COMPACTION ON path = STRING (LIMIT limit = INTEGER_VALUE)? #showCompactionOnPath ; tableIdentifier : (db=IDENTIFIER '.')? table=IDENTIFIER ; + callArgument + : expression #positionalArgument + | identifier '=>' expression #namedArgument + ; + + expression + : constant + | stringMap + ; + + constant + : number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + | identifier STRING #typeConstructor + ; + + stringMap + : MAP '(' constant (',' constant)* ')' + ; + + booleanValue + : TRUE | FALSE + ; + + number + : MINUS? EXPONENT_VALUE #exponentLiteral + | MINUS? DECIMAL_VALUE #decimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + + multipartIdentifier + : parts+=identifier ('.' 
parts+=identifier)* + ; + + identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + + quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + + nonReserved + : CALL | COMPACTION | RUN | SCHEDULE | ON | SHOW | LIMIT + ; + ALL: 'ALL'; AT: 'AT'; + CALL: 'CALL'; COMPACTION: 'COMPACTION'; RUN: 'RUN'; SCHEDULE: 'SCHEDULE'; ON: 'ON'; SHOW: 'SHOW'; LIMIT: 'LIMIT'; + MAP: 'MAP'; + NULL: 'NULL'; + TRUE: 'TRUE'; + FALSE: 'FALSE'; + INTERVAL: 'INTERVAL'; + TO: 'TO'; - NUMBER - : DIGIT+ - ; - - IDENTIFIER - : (LETTER | DIGIT | '_')+ - ; + PLUS: '+'; + MINUS: '-'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' | '"' ( ~('"'|'\\') | ('\\' .) )* '"' ; + BIGINT_LITERAL + : DIGIT+ 'L' + ; + SMALLINT_LITERAL + : DIGIT+ 'S' + ; + + TINYINT_LITERAL + : DIGIT+ 'Y' + ; + + INTEGER_VALUE + : DIGIT+ + ; + + EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + + DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + + FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + + DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + + BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + + IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + + BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + + fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + + fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; fragment DIGIT - : [0-9] - ; + : [0-9] + ; fragment LETTER - : [A-Z] - ; + : [A-Z] + ; SIMPLE_COMMENT : '--' ~[\r\n]* '\r'? '\n'? 
-> channel(HIDDEN) diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala new file mode 100644 index 000000000..df2a95375 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +case class CallCommand(name: Seq[String], args: Seq[CallArgument]) extends Command { + override def children: Seq[LogicalPlan] = Seq.empty + + def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): CallCommand = { + this + } +} + +/** + * An argument in a CALL statement. + */ +sealed trait CallArgument { + def expr: Expression +} + +/** + * An argument in a CALL statement identified by name. + */ +case class NamedArgument(name: String, expr: Expression) extends CallArgument + +/** + * An argument in a CALL statement identified by position. 
+ */ +case class PositionalArgument(expr: Expression) extends CallArgument diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala index 12bbd6485..28f8a92e9 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala @@ -22,7 +22,7 @@ import org.apache.hudi.common.model.HoodieRecord import org.apache.hudi.common.util.ReflectionUtils import org.apache.hudi.{HoodieSparkUtils, SparkAdapterSupport} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, GenericInternalRow, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -31,10 +31,12 @@ import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{getTableIdentifier, removeMetaFields} import org.apache.spark.sql.hudi.HoodieSqlUtils._ import org.apache.spark.sql.hudi.command._ +import org.apache.spark.sql.hudi.command.procedures.{HoodieProcedures, Procedure, ProcedureArgs} import org.apache.spark.sql.hudi.{HoodieOptionConfig, HoodieSqlCommonUtils} import org.apache.spark.sql.types.StringType import org.apache.spark.sql.{AnalysisException, SparkSession} +import java.util import scala.collection.JavaConverters._ object HoodieAnalysis { @@ -79,6 +81,7 @@ object HoodieAnalysis { /** * Rule for convert the logical plan to command. 
+ * * @param sparkSession */ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan] @@ -133,26 +136,69 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan] // Convert to CompactionShowHoodiePathCommand case CompactionShowOnPath(path, limit) => CompactionShowHoodiePathCommand(path, limit) - case _=> plan + // Convert to HoodieCallProcedureCommand + case c@CallCommand(_, _) => + val procedure: Option[Procedure] = loadProcedure(c.name) + val input = buildProcedureArgs(c.args) + if (procedure.nonEmpty) { + CallProcedureHoodieCommand(procedure.get, input) + } else { + c + } + case _ => plan } } + + private def loadProcedure(name: Seq[String]): Option[Procedure] = { + val procedure: Option[Procedure] = if (name.nonEmpty) { + val builder = HoodieProcedures.newBuilder(name.last) + if (builder != null) { + Option(builder.build) + } else { + throw new AnalysisException(s"procedure: ${name.last} is not exists") + } + } else { + None + } + procedure + } + + private def buildProcedureArgs(exprs: Seq[CallArgument]): ProcedureArgs = { + val values = new Array[Any](exprs.size) + var isNamedArgs: Boolean = false + val map = new util.LinkedHashMap[String, Int]() + for (index <- exprs.indices) { + exprs(index) match { + case expr: NamedArgument => + map.put(expr.name, index) + values(index) = expr.expr.eval() + isNamedArgs = true + case _ => + map.put(index.toString, index) + values(index) = exprs(index).expr.eval() + isNamedArgs = false + } + } + ProcedureArgs(isNamedArgs, map, new GenericInternalRow(values)) + } } /** * Rule for resolve hoodie's extended syntax or rewrite some logical plan. 
+ * * @param sparkSession */ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan] with SparkAdapterSupport { private lazy val analyzer = sparkSession.sessionState.analyzer - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { // Resolve merge into case mergeInto @ MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions) if sparkAdapter.isHoodieTable(target, sparkSession) && target.resolved => - val resolver = sparkSession.sessionState.conf.resolver val resolvedSource = analyzer.execute(source) + def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = { if (assignments.isEmpty) { true @@ -363,6 +409,7 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi /** * Check if the the query of insert statement has already append the meta fields to avoid * duplicate append. + * * @param query * @return */ diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CallProcedureHoodieCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CallProcedureHoodieCommand.scala new file mode 100644 index 000000000..f63f4115e --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CallProcedureHoodieCommand.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.command + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.hudi.command.procedures.{Procedure, ProcedureArgs} +import org.apache.spark.sql.{Row, SparkSession} + +import scala.collection.Seq + +case class CallProcedureHoodieCommand( + procedure: Procedure, + args: ProcedureArgs) extends HoodieLeafRunnableCommand { + + override def output: Seq[Attribute] = procedure.outputType.toAttributes + + override def run(sparkSession: SparkSession): Seq[Row] = { + procedure.call(args) + } +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/BaseProcedure.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/BaseProcedure.scala new file mode 100644 index 000000000..e64df997d --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/BaseProcedure.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.command.procedures + +import org.apache.hudi.client.SparkRDDWriteClient +import org.apache.hudi.client.common.HoodieSparkEngineContext +import org.apache.hudi.common.model.HoodieRecordPayload +import org.apache.hudi.config.{HoodieIndexConfig, HoodieWriteConfig} +import org.apache.hudi.index.HoodieIndex.IndexType +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +abstract class BaseProcedure extends Procedure { + val INVALID_ARG_INDEX: Int = -1 + + val spark: SparkSession = SparkSession.active + val jsc = new JavaSparkContext(spark.sparkContext) + + protected def sparkSession: SparkSession = spark + + protected def createHoodieClient(jsc: JavaSparkContext, basePath: String): SparkRDDWriteClient[_ <: HoodieRecordPayload[_ <: AnyRef]] = { + val config = getWriteConfig(basePath) + new SparkRDDWriteClient(new HoodieSparkEngineContext(jsc), config) + } + + protected def getWriteConfig(basePath: String): HoodieWriteConfig = { + HoodieWriteConfig.newBuilder + .withPath(basePath) + .withIndexConfig(HoodieIndexConfig.newBuilder.withIndexType(IndexType.BLOOM).build) + .withRollbackUsingMarkers(false) + .build + } + + protected def checkArgs(target: Array[ProcedureParameter], args: ProcedureArgs): Unit = { + val internalRow = args.internalRow + for (i <- target.indices) { + if (target(i).required) { + var argsIndex: Integer = null + if (args.isNamedArgs) { + argsIndex = 
getArgsIndex(target(i).name, args) + } else { + argsIndex = getArgsIndex(i.toString, args) + } + assert(-1 != argsIndex && internalRow.get(argsIndex, target(i).dataType) != null, + s"Argument: ${target(i).name} is required") + } + } + } + + protected def getArgsIndex(key: String, args: ProcedureArgs): Integer = { + args.map.getOrDefault(key, INVALID_ARG_INDEX) + } + + protected def getArgValueOrDefault(args: ProcedureArgs, parameter: ProcedureParameter): Any = { + var argsIndex: Int = INVALID_ARG_INDEX + if (args.isNamedArgs) { + argsIndex = getArgsIndex(parameter.name, args) + } else { + argsIndex = getArgsIndex(parameter.index.toString, args) + } + if (argsIndex.equals(INVALID_ARG_INDEX)) parameter.default else getInternalRowValue(args.internalRow, argsIndex, parameter.dataType) + } + + protected def getInternalRowValue(row: InternalRow, index: Int, dataType: DataType): Any = { + dataType match { + case StringType => row.getString(index) + case BinaryType => row.getBinary(index) + case BooleanType => row.getBoolean(index) + case CalendarIntervalType => row.getInterval(index) + case DoubleType => row.getDouble(index) + case d: DecimalType => row.getDecimal(index, d.precision, d.scale) + case FloatType => row.getFloat(index) + case ByteType => row.getByte(index) + case IntegerType => row.getInt(index) + case LongType => row.getLong(index) + case ShortType => row.getShort(index) + case NullType => null + case _ => + throw new UnsupportedOperationException(s"type: ${dataType.typeName} not supported") + } + } +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/HoodieProcedures.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/HoodieProcedures.scala new file mode 100644 index 000000000..7b919fcef --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/HoodieProcedures.scala @@ -0,0 +1,41 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.command.procedures + +import com.google.common.collect.ImmutableMap + +import java.util +import java.util.Locale +import java.util.function.Supplier + +object HoodieProcedures { + private val BUILDERS: util.Map[String, Supplier[ProcedureBuilder]] = initProcedureBuilders + + def newBuilder(name: String): ProcedureBuilder = { + val builderSupplier: Supplier[ProcedureBuilder] = BUILDERS.get(name.toLowerCase(Locale.ROOT)) + if (builderSupplier != null) builderSupplier.get else null + } + + private def initProcedureBuilders: util.Map[String, Supplier[ProcedureBuilder]] = { + val mapBuilder: ImmutableMap.Builder[String, Supplier[ProcedureBuilder]] = ImmutableMap.builder() + mapBuilder.put(ShowCommitsProcedure.NAME, ShowCommitsProcedure.builder) + mapBuilder.put(ShowCommitsMetadataProcedure.NAME, ShowCommitsMetadataProcedure.builder) + mapBuilder.put(RollbackToInstantTimeProcedure.NAME, RollbackToInstantTimeProcedure.builder) + mapBuilder.build + } +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/Procedure.scala 
b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/Procedure.scala new file mode 100644 index 000000000..f34e30615 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/Procedure.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.command.procedures + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.StructType + +import java.util +import scala.collection.mutable + +/** + * An interface representing a stored procedure available for execution. + */ +trait Procedure { + /** + * Returns the input parameters of this procedure. + */ + def parameters: Array[ProcedureParameter] + + /** + * Returns the type of rows produced by this procedure. + */ + def outputType: StructType + + /** + * Executes this procedure. + *
+ * Spark will align the provided arguments according to the input parameters + * defined in {@link #parameters()} either by position or by name before execution. +
+ * Implementations may provide a summary of execution by returning one or many rows + * as a result. The schema of output rows must match the defined output type + * in {@link #outputType ( )}. + * + * @param args input arguments + * @return the result of executing this procedure with the given arguments + */ + def call(args: ProcedureArgs): Seq[Row] + + /** + * Returns the description of this procedure. + */ + def description: String = this.getClass.toString +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureArgs.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureArgs.scala new file mode 100644 index 000000000..5c462c1b8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureArgs.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hudi.command.procedures + +import org.apache.spark.sql.catalyst.InternalRow + +import java.util + +case class ProcedureArgs(isNamedArgs: Boolean, + map: util.LinkedHashMap[String, Int], + internalRow: InternalRow) { +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureBuilder.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureBuilder.scala new file mode 100644 index 000000000..b2ecd0a30 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureBuilder.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hudi.command.procedures + +trait ProcedureBuilder { + def build: Procedure +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureParameter.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureParameter.scala new file mode 100644 index 000000000..a9ad252bd --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureParameter.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.command.procedures + +import org.apache.spark.sql.types.DataType + +/** + * An input parameter of a {@link Procedure stored procedure}. + */ +abstract class ProcedureParameter { + def index: Int + + /** + * Returns the name of this parameter. + */ + def name: String + + /** + * Returns the type of this parameter. + */ + def dataType: DataType + + /** + * Returns true if this parameter is required. + */ + def required: Boolean + + /** + * this parameter's default value. + */ + def default: Any +} + +object ProcedureParameter { + /** + * Creates a required input parameter. 
+ * + * @param name the name of the parameter + * @param dataType the type of the parameter + * @return the constructed stored procedure parameter + */ + def required(index: Int, name: String, dataType: DataType, default: Any): ProcedureParameterImpl = { + ProcedureParameterImpl(index, name, dataType, default, required = true) + } + + /** + * Creates an optional input parameter. + * + * @param name the name of the parameter. + * @param dataType the type of the parameter. + * @return the constructed optional stored procedure parameter + */ + def optional(index: Int, name: String, dataType: DataType, default: Any): ProcedureParameterImpl = { + ProcedureParameterImpl(index, name, dataType, default, required = false) + } +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureParameterImpl.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureParameterImpl.scala new file mode 100644 index 000000000..a7f411704 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ProcedureParameterImpl.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.command.procedures + +import org.apache.spark.sql.types.DataType + +import java.util.Objects + +case class ProcedureParameterImpl(index: Int, name: String, dataType: DataType, default: Any, required: Boolean) + extends ProcedureParameter { + + override def equals(other: Any): Boolean = { + val that = other.asInstanceOf[ProcedureParameterImpl] + val rtn = if (this == other) { + true + } else if (other == null || (getClass ne other.getClass)) { + false + } else { + index == that.index && required == that.required && default == that.default && Objects.equals(name, that.name) && Objects.equals(dataType, that.dataType) + } + rtn + } + + override def hashCode: Int = Seq(index, name, dataType, required, default).hashCode() + + override def toString: String = s"ProcedureParameter(index='$index',name='$name', type=$dataType, required=$required, default=$default)" +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RollbackToInstantTimeProcedure.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RollbackToInstantTimeProcedure.scala new file mode 100644 index 000000000..5414e8db6 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/RollbackToInstantTimeProcedure.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.command.procedures + +import org.apache.hudi.common.table.HoodieTableMetaClient +import org.apache.hudi.common.table.timeline.HoodieTimeline +import org.apache.hudi.common.table.timeline.versioning.TimelineLayoutVersion +import org.apache.hudi.common.util.Option +import org.apache.hudi.exception.HoodieException +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable +import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType} + +import java.util.function.Supplier + +class RollbackToInstantTimeProcedure extends BaseProcedure with ProcedureBuilder { + private val PARAMETERS = Array[ProcedureParameter]( + ProcedureParameter.required(0, "table", DataTypes.StringType, None), + ProcedureParameter.required(1, "instant_time", DataTypes.StringType, None)) + + private val OUTPUT_TYPE = new StructType(Array[StructField]( + StructField("rollback_result", DataTypes.BooleanType, nullable = true, Metadata.empty)) + ) + + def parameters: Array[ProcedureParameter] = PARAMETERS + + def outputType: StructType = OUTPUT_TYPE + + override def call(args: ProcedureArgs): Seq[Row] = { + super.checkArgs(PARAMETERS, args) + + val table = getArgValueOrDefault(args, PARAMETERS(0)).asInstanceOf[String] + val instantTime = getArgValueOrDefault(args, PARAMETERS(1)).asInstanceOf[String] + + val hoodieCatalogTable = HoodieCatalogTable(sparkSession, new TableIdentifier(table)) + val basePath = hoodieCatalogTable.tableLocation + 
val client = createHoodieClient(jsc, basePath) + val config = getWriteConfig(basePath) + val metaClient = HoodieTableMetaClient.builder + .setConf(jsc.hadoopConfiguration) + .setBasePath(config.getBasePath) + .setLoadActiveTimelineOnLoad(false) + .setConsistencyGuardConfig(config.getConsistencyGuardConfig) + .setLayoutVersion(Option.of(new TimelineLayoutVersion(config.getTimelineLayoutVersion))) + .build + + val activeTimeline = metaClient.getActiveTimeline + val completedTimeline: HoodieTimeline = activeTimeline.getCommitsTimeline.filterCompletedInstants + val filteredTimeline = completedTimeline.containsInstant(instantTime) + if (!filteredTimeline) { + throw new HoodieException(s"Commit $instantTime not found in Commits $completedTimeline") + } + + val result = if (client.rollback(instantTime)) true else false + val outputRow = Row(result) + + Seq(outputRow) + } + + override def build: Procedure = new RollbackToInstantTimeProcedure() +} + +object RollbackToInstantTimeProcedure { + val NAME: String = "rollback_to_instant" + + def builder: Supplier[ProcedureBuilder] = new Supplier[ProcedureBuilder] { + override def get(): RollbackToInstantTimeProcedure = new RollbackToInstantTimeProcedure() + } +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ShowCommitsProcedure.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ShowCommitsProcedure.scala new file mode 100644 index 000000000..da089baba --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/procedures/ShowCommitsProcedure.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.command.procedures + +import org.apache.hudi.common.model.HoodieCommitMetadata +import org.apache.hudi.common.table.HoodieTableMetaClient +import org.apache.hudi.common.table.timeline.{HoodieDefaultTimeline, HoodieInstant} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable +import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType} + +import java.util +import java.util.Collections +import java.util.function.Supplier +import scala.collection.JavaConverters._ + +class ShowCommitsProcedure(includeExtraMetadata: Boolean) extends BaseProcedure with ProcedureBuilder { + var sortByFieldParameter: ProcedureParameter = _ + + private val PARAMETERS = Array[ProcedureParameter]( + ProcedureParameter.required(0, "table", DataTypes.StringType, None), + ProcedureParameter.optional(1, "limit", DataTypes.IntegerType, 10) + ) + + private val OUTPUT_TYPE = new StructType(Array[StructField]( + StructField("commit_time", DataTypes.StringType, nullable = true, Metadata.empty), + StructField("total_bytes_written", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_files_added", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_files_updated", DataTypes.LongType, nullable = true, Metadata.empty), + 
StructField("total_partitions_written", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_records_written", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_update_records_written", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_errors", DataTypes.LongType, nullable = true, Metadata.empty) + )) + + private val METADATA_OUTPUT_TYPE = new StructType(Array[StructField]( + StructField("commit_time", DataTypes.StringType, nullable = true, Metadata.empty), + StructField("action", DataTypes.StringType, nullable = true, Metadata.empty), + StructField("partition", DataTypes.StringType, nullable = true, Metadata.empty), + StructField("file_id", DataTypes.StringType, nullable = true, Metadata.empty), + StructField("previous_commit", DataTypes.StringType, nullable = true, Metadata.empty), + StructField("num_writes", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("num_inserts", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("num_deletes", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("num_update_writes", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_errors", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_log_blocks", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_corrupt_logblocks", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_rollback_blocks", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_log_records", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_updated_records_compacted", DataTypes.LongType, nullable = true, Metadata.empty), + StructField("total_bytes_written", DataTypes.LongType, nullable = true, Metadata.empty) + )) + + def parameters: Array[ProcedureParameter] = PARAMETERS + + def outputType: StructType = if (includeExtraMetadata) METADATA_OUTPUT_TYPE else 
OUTPUT_TYPE + + override def call(args: ProcedureArgs): Seq[Row] = { + super.checkArgs(PARAMETERS, args) + + val table = getArgValueOrDefault(args, PARAMETERS(0)).asInstanceOf[String] + val limit = getArgValueOrDefault(args, PARAMETERS(1)).asInstanceOf[Int] + + val hoodieCatalogTable = HoodieCatalogTable(sparkSession, new TableIdentifier(table)) + val basePath = hoodieCatalogTable.tableLocation + val metaClient = HoodieTableMetaClient.builder.setConf(jsc.hadoopConfiguration()).setBasePath(basePath).build + + val activeTimeline = metaClient.getActiveTimeline + if (includeExtraMetadata) { + getCommitsWithMetadata(activeTimeline, limit) + } else { + getCommits(activeTimeline, limit) + } + } + + override def build: Procedure = new ShowCommitsProcedure(includeExtraMetadata) + + private def getCommitsWithMetadata(timeline: HoodieDefaultTimeline, + limit: Int): Seq[Row] = { + import scala.collection.JavaConversions._ + + val (rows: util.ArrayList[Row], newCommits: util.ArrayList[HoodieInstant]) = getSortCommits(timeline) + + for (i <- 0 until newCommits.size) { + val commit = newCommits.get(i) + val commitMetadata = HoodieCommitMetadata.fromBytes(timeline.getInstantDetails(commit).get, classOf[HoodieCommitMetadata]) + for (partitionWriteStat <- commitMetadata.getPartitionToWriteStats.entrySet) { + for (hoodieWriteStat <- partitionWriteStat.getValue) { + rows.add(Row( + commit.getTimestamp, commit.getAction, hoodieWriteStat.getPartitionPath, + hoodieWriteStat.getFileId, hoodieWriteStat.getPrevCommit, hoodieWriteStat.getNumWrites, + hoodieWriteStat.getNumInserts, hoodieWriteStat.getNumDeletes, hoodieWriteStat.getNumUpdateWrites, + hoodieWriteStat.getTotalWriteErrors, hoodieWriteStat.getTotalLogBlocks, hoodieWriteStat.getTotalCorruptLogBlock, + hoodieWriteStat.getTotalRollbackBlocks, hoodieWriteStat.getTotalLogRecords, + hoodieWriteStat.getTotalUpdatedRecordsCompacted, hoodieWriteStat.getTotalWriteBytes)) + } + } + } + + rows.stream().limit(limit).toArray().map(r => 
r.asInstanceOf[Row]).toList + } + + private def getSortCommits(timeline: HoodieDefaultTimeline): (util.ArrayList[Row], util.ArrayList[HoodieInstant]) = { + val rows = new util.ArrayList[Row] + // timeline can be read from multiple files. So sort is needed instead of reversing the collection + val commits: util.List[HoodieInstant] = timeline.getCommitsTimeline.filterCompletedInstants + .getInstants.toArray().map(instant => instant.asInstanceOf[HoodieInstant]).toList.asJava + val newCommits = new util.ArrayList[HoodieInstant](commits) + Collections.sort(newCommits, HoodieInstant.COMPARATOR.reversed) + (rows, newCommits) + } + + def getCommits(timeline: HoodieDefaultTimeline, + limit: Int): Seq[Row] = { + val (rows: util.ArrayList[Row], newCommits: util.ArrayList[HoodieInstant]) = getSortCommits(timeline) + + for (i <- 0 until newCommits.size) { + val commit = newCommits.get(i) + val commitMetadata = HoodieCommitMetadata.fromBytes(timeline.getInstantDetails(commit).get, classOf[HoodieCommitMetadata]) + rows.add(Row(commit.getTimestamp, commitMetadata.fetchTotalBytesWritten, commitMetadata.fetchTotalFilesInsert, + commitMetadata.fetchTotalFilesUpdated, commitMetadata.fetchTotalPartitionsWritten, + commitMetadata.fetchTotalRecordsWritten, commitMetadata.fetchTotalUpdateRecordsWritten, + commitMetadata.fetchTotalWriteErrors)) + } + + rows.stream().limit(limit).toArray().map(r => r.asInstanceOf[Row]).toList + } +} + +object ShowCommitsProcedure { + val NAME = "show_commits" + + def builder: Supplier[ProcedureBuilder] = new Supplier[ProcedureBuilder] { + override def get() = new ShowCommitsProcedure(false) + } +} + +object ShowCommitsMetadataProcedure { + val NAME = "show_commits_metadata" + + def builder: Supplier[ProcedureBuilder] = new Supplier[ProcedureBuilder] { + override def get() = new ShowCommitsProcedure(true) + } +} + + diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/parser/HoodieSqlCommonAstBuilder.scala 
b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/parser/HoodieSqlCommonAstBuilder.scala index b1f5a32fe..3146740b1 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/parser/HoodieSqlCommonAstBuilder.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/parser/HoodieSqlCommonAstBuilder.scala @@ -17,22 +17,39 @@ package org.apache.spark.sql.parser +import org.antlr.v4.runtime.ParserRuleContext +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} import org.apache.hudi.SparkAdapterSupport -import org.apache.hudi.spark.sql.parser.{HoodieSqlCommonBaseVisitor, HoodieSqlCommonParser} -import org.apache.hudi.spark.sql.parser.HoodieSqlCommonParser.{CompactionOnPathContext, CompactionOnTableContext, ShowCompactionOnPathContext, ShowCompactionOnTableContext, SingleStatementContext, TableIdentifierContext} +import org.apache.hudi.spark.sql.parser.HoodieSqlCommonBaseVisitor +import org.apache.hudi.spark.sql.parser.HoodieSqlCommonParser._ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin -import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} -import org.apache.spark.sql.catalyst.plans.logical.{CompactionOperation, CompactionPath, CompactionShowOnPath, CompactionShowOnTable, CompactionTable, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface, ParserUtils} +import org.apache.spark.sql.catalyst.plans.logical._ + +import scala.collection.JavaConverters._ class HoodieSqlCommonAstBuilder(session: SparkSession, delegate: ParserInterface) extends HoodieSqlCommonBaseVisitor[AnyRef] with Logging with SparkAdapterSupport { import ParserUtils._ + /** 
+ * Override the default behavior for all visit methods. This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. + */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { ctx.statement().accept(this).asInstanceOf[LogicalPlan] } @@ -72,4 +89,62 @@ class HoodieSqlCommonAstBuilder(session: SparkSession, delegate: ParserInterface override def visitTableIdentifier(ctx: TableIdentifierContext): LogicalPlan = withOrigin(ctx) { UnresolvedRelation(TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))) } + + override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) { + if (ctx.callArgument().isEmpty) { + throw new ParseException(s"Procedures arguments is empty", ctx) + } + + val name: Seq[String] = ctx.multipartIdentifier().parts.asScala.map(_.getText) + val args: Seq[CallArgument] = ctx.callArgument().asScala.map(typedVisit[CallArgument]) + CallCommand(name, args) + } + + /** + * Return a multi-part identifier as Seq[String]. + */ + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + ctx.parts.asScala.map(_.getText) + } + + /** + * Create a positional argument in a stored procedure call. + */ + override def visitPositionalArgument(ctx: PositionalArgumentContext): CallArgument = withOrigin(ctx) { + val expr = typedVisit[Expression](ctx.expression) + PositionalArgument(expr) + } + + /** + * Create a named argument in a stored procedure call. 
+ */ + override def visitNamedArgument(ctx: NamedArgumentContext): CallArgument = withOrigin(ctx) { + val name = ctx.identifier.getText + val expr = typedVisit[Expression](ctx.expression) + NamedArgument(name, expr) + } + + def visitConstant(ctx: ConstantContext): Literal = { + delegate.parseExpression(ctx.getText).asInstanceOf[Literal] + } + + override def visitExpression(ctx: ExpressionContext): Expression = { + // reconstruct the SQL string and parse it using the main Spark parser + // while we can avoid the logic to build Spark expressions, we still have to parse them + // we cannot call ctx.getText directly since it will not render spaces correctly + // that's why we need to recurse down the tree in reconstructSqlString + val sqlString = reconstructSqlString(ctx) + delegate.parseExpression(sqlString) + } + + private def reconstructSqlString(ctx: ParserRuleContext): String = { + ctx.children.asScala.map { + case c: ParserRuleContext => reconstructSqlString(c) + case t: TerminalNode => t.getText + }.mkString(" ") + } + + private def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } } diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCallCommandParser.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCallCommandParser.scala new file mode 100644 index 000000000..9d1c02ad9 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCallCommandParser.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi + +import com.google.common.collect.ImmutableList +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{CallCommand, NamedArgument, PositionalArgument} +import org.apache.spark.sql.types.{DataType, DataTypes} + +import java.math.BigDecimal +import scala.collection.JavaConverters + +class TestCallCommandParser extends TestHoodieSqlBase { + private val parser = spark.sessionState.sqlParser + + test("Test Call Produce with Positional Arguments") { + val call = parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)").asInstanceOf[CallCommand] + assertResult(ImmutableList.of("c", "n", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava) + + assertResult(7)(call.args.size) + + checkArg(call, 0, 1, DataTypes.IntegerType) + checkArg(call, 1, "2", DataTypes.StringType) + checkArg(call, 2, 3L, DataTypes.LongType) + checkArg(call, 3, true, DataTypes.BooleanType) + checkArg(call, 4, 1.0D, DataTypes.DoubleType) + checkArg(call, 5, new BigDecimal("9.0e1"), DataTypes.createDecimalType(2, 0)) + checkArg(call, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1)) + } + + test("Test Call Produce with Named Arguments") { + val call = parser.parsePlan("CALL system.func(c1 => 1, c2 => '2', c3 => true)").asInstanceOf[CallCommand] + assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava) + + assertResult(3)(call.args.size) + + checkArg(call, 0, "c1", 1, DataTypes.IntegerType) + 
checkArg(call, 1, "c2", "2", DataTypes.StringType) + checkArg(call, 2, "c3", true, DataTypes.BooleanType) + } + + test("Test Call Produce with Var Substitution") { + val call = parser.parsePlan("CALL system.func('${spark.extra.prop}')").asInstanceOf[CallCommand] + assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava) + + assertResult(1)(call.args.size) + + checkArg(call, 0, "value", DataTypes.StringType) + } + + test("Test Call Produce with Mixed Arguments") { + val call = parser.parsePlan("CALL system.func(c1 => 1, '2')").asInstanceOf[CallCommand] + assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava) + + assertResult(2)(call.args.size) + + checkArg(call, 0, "c1", 1, DataTypes.IntegerType) + checkArg(call, 1, "2", DataTypes.StringType) + } + + test("Test Call Parse Error") { + checkParseExceptionContain("CALL cat.system radish kebab")("mismatched input 'CALL' expecting") + } + + protected def checkParseExceptionContain(sql: String)(errorMsg: String): Unit = { + var hasException = false + try { + parser.parsePlan(sql) + } catch { + case e: Throwable => + assertResult(true)(e.getMessage.contains(errorMsg)) + hasException = true + } + assertResult(true)(hasException) + } + + private def checkArg(call: CallCommand, index: Int, expectedValue: Any, expectedType: DataType): Unit = { + checkArg(call, index, null, expectedValue, expectedType) + } + + private def checkArg(call: CallCommand, index: Int, expectedName: String, expectedValue: Any, expectedType: DataType): Unit = { + if (expectedName != null) { + val arg = checkCast(call.args.apply(index), classOf[NamedArgument]) + assertResult(expectedName)(arg.name) + } + else { + val arg = call.args.apply(index) + checkCast(arg, classOf[PositionalArgument]) + } + val expectedExpr = toSparkLiteral(expectedValue, expectedType) + val actualExpr = call.args.apply(index).expr + 
assertResult(expectedExpr.dataType)(actualExpr.dataType) + } + + private def toSparkLiteral(value: Any, dataType: DataType) = Literal.apply(value, dataType) + + private def checkCast[T](value: Any, expectedClass: Class[T]) = { + assertResult(true)(expectedClass.isInstance(value)) + expectedClass.cast(value) + } +} diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCallProcedure.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCallProcedure.scala new file mode 100644 index 000000000..eb2c614df --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCallProcedure.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hudi + +class TestCallProcedure extends TestHoodieSqlBase { + + test("Test Call show_commits Procedure") { + withTempDir { tmp => + val tableName = generateTableName + // create table + spark.sql( + s""" + |create table $tableName ( + | id int, + | name string, + | price double, + | ts long + |) using hudi + | location '${tmp.getCanonicalPath}/$tableName' + | tblproperties ( + | primaryKey = 'id', + | preCombineField = 'ts' + | ) + """.stripMargin) + // insert data to table + spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000") + spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500") + + // Check required fields + checkExceptionContain(s"""call show_commits(limit => 10)""")( + s"Argument: table is required") + + // collect commits for table + val commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect() + assertResult(2) { + commits.length + } + } + } + + test("Test Call show_commits_metadata Procedure") { + withTempDir { tmp => + val tableName = generateTableName + // create table + spark.sql( + s""" + |create table $tableName ( + | id int, + | name string, + | price double, + | ts long + |) using hudi + | location '${tmp.getCanonicalPath}/$tableName' + | tblproperties ( + | primaryKey = 'id', + | preCombineField = 'ts' + | ) + """.stripMargin) + // insert data to table + spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000") + + // Check required fields + checkExceptionContain(s"""call show_commits_metadata(limit => 10)""")( + s"Argument: table is required") + + // collect commits for table + val commits = spark.sql(s"""call show_commits_metadata(table => '$tableName', limit => 10)""").collect() + assertResult(1) { + commits.length + } + } + } + + test("Test Call rollback_to_instant Procedure") { + withTempDir { tmp => + val tableName = generateTableName + // create table + spark.sql( + s""" + |create table $tableName ( + | id int, + | name string, + | price double, + | 
ts long + |) using hudi + | location '${tmp.getCanonicalPath}/$tableName' + | tblproperties ( + | primaryKey = 'id', + | preCombineField = 'ts' + | ) + """.stripMargin) + // insert data to table + spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000") + spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500") + spark.sql(s"insert into $tableName select 3, 'a3', 30, 2000") + + // Check required fields + checkExceptionContain(s"""call rollback_to_instant(table => '$tableName')""")( + s"Argument: instant_time is required") + + // 3 commits are left before rollback + var commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect() + assertResult(3){commits.length} + + // Call rollback_to_instant Procedure with Named Arguments + var instant_time = commits(0).get(0).toString + checkAnswer(s"""call rollback_to_instant(table => '$tableName', instant_time => '$instant_time')""")(Seq(true)) + // Call rollback_to_instant Procedure with Positional Arguments + instant_time = commits(1).get(0).toString + checkAnswer(s"""call rollback_to_instant('$tableName', '$instant_time')""")(Seq(true)) + + // 1 commits are left after rollback + commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect() + assertResult(1){commits.length} + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3-common/pom.xml b/hudi-spark-datasource/hudi-spark3-common/pom.xml index affa98737..8fd46c7b3 100644 --- a/hudi-spark-datasource/hudi-spark3-common/pom.xml +++ b/hudi-spark-datasource/hudi-spark3-common/pom.xml @@ -244,4 +244,4 @@ - \ No newline at end of file +