1
0

[HUDI-3161] Add Call Produce Command for Spark SQL (#4535)

This commit is contained in:
ForwardXu
2022-02-24 23:45:37 +08:00
committed by GitHub
parent 943b99775b
commit 521338b4d9
17 changed files with 1228 additions and 28 deletions

View File

@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import com.google.common.collect.ImmutableList

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{CallCommand, NamedArgument, PositionalArgument}
import org.apache.spark.sql.types.{DataType, DataTypes}

import java.math.BigDecimal

import scala.collection.JavaConverters
import scala.util.control.NonFatal
/**
 * Tests that Spark SQL `CALL ...` statements parse into [[CallCommand]] plans with the
 * expected procedure name and argument list (positional, named, and mixed forms).
 */
class TestCallCommandParser extends TestHoodieSqlBase {
  private val parser = spark.sessionState.sqlParser

  test("Test Call Produce with Positional Arguments") {
    val call = parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("c", "n", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)
    assertResult(7)(call.args.size)

    checkArg(call, 0, 1, DataTypes.IntegerType)
    checkArg(call, 1, "2", DataTypes.StringType)
    checkArg(call, 2, 3L, DataTypes.LongType)
    checkArg(call, 3, true, DataTypes.BooleanType)
    checkArg(call, 4, 1.0D, DataTypes.DoubleType)
    checkArg(call, 5, new BigDecimal("9.0e1"), DataTypes.createDecimalType(2, 0))
    checkArg(call, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1))
  }

  test("Test Call Produce with Named Arguments") {
    val call = parser.parsePlan("CALL system.func(c1 => 1, c2 => '2', c3 => true)").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)
    assertResult(3)(call.args.size)

    checkArg(call, 0, "c1", 1, DataTypes.IntegerType)
    checkArg(call, 1, "c2", "2", DataTypes.StringType)
    checkArg(call, 2, "c3", true, DataTypes.BooleanType)
  }

  test("Test Call Produce with Var Substitution") {
    // "${spark.extra.prop}" is substituted by the parser from the session conf
    // (assumes the test SparkSession sets spark.extra.prop = "value" — see TestHoodieSqlBase).
    val call = parser.parsePlan("CALL system.func('${spark.extra.prop}')").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)
    assertResult(1)(call.args.size)

    checkArg(call, 0, "value", DataTypes.StringType)
  }

  test("Test Call Produce with Mixed Arguments") {
    val call = parser.parsePlan("CALL system.func(c1 => 1, '2')").asInstanceOf[CallCommand]

    assertResult(ImmutableList.of("system", "func"))(JavaConverters.seqAsJavaListConverter(call.name).asJava)
    assertResult(2)(call.args.size)

    checkArg(call, 0, "c1", 1, DataTypes.IntegerType)
    checkArg(call, 1, "2", DataTypes.StringType)
  }

  test("Test Call Parse Error") {
    checkParseExceptionContain("CALL cat.system radish kebab")("mismatched input 'CALL' expecting")
  }

  /**
   * Asserts that parsing `sql` fails and that the failure message contains `errorMsg`.
   *
   * Uses NonFatal so that fatal VM errors (OutOfMemoryError, etc.) still propagate
   * instead of being swallowed as an "expected" parse failure.
   */
  protected def checkParseExceptionContain(sql: String)(errorMsg: String): Unit = {
    val failed = try {
      parser.parsePlan(sql)
      false // parsing unexpectedly succeeded
    } catch {
      case NonFatal(e) =>
        assertResult(true)(e.getMessage.contains(errorMsg))
        true
    }
    assertResult(true)(failed)
  }

  /** Checks a positional argument at `index`: its literal value and data type. */
  private def checkArg(call: CallCommand, index: Int, expectedValue: Any, expectedType: DataType): Unit = {
    checkArg(call, index, null, expectedValue, expectedType)
  }

  /**
   * Checks the argument at `index`. A non-null `expectedName` means the argument must be
   * a [[NamedArgument]] with that name; otherwise it must be a [[PositionalArgument]].
   */
  private def checkArg(
      call: CallCommand,
      index: Int,
      expectedName: String,
      expectedValue: Any,
      expectedType: DataType): Unit = {
    if (expectedName != null) {
      val arg = checkCast(call.args.apply(index), classOf[NamedArgument])
      assertResult(expectedName)(arg.name)
    } else {
      checkCast(call.args.apply(index), classOf[PositionalArgument])
    }

    val expectedExpr = toSparkLiteral(expectedValue, expectedType)
    val actualExpr = call.args.apply(index).expr
    assertResult(expectedExpr.dataType)(actualExpr.dataType)
    // Previously only the data type was compared, leaving the parsed literal value
    // unchecked; also verify the value itself (Literal.apply converts to Catalyst
    // internal form, so both sides are comparable).
    assertResult(expectedExpr.value)(actualExpr.value)
  }

  /** Builds a Catalyst Literal in internal representation for comparison. */
  private def toSparkLiteral(value: Any, dataType: DataType) = Literal.apply(value, dataType)

  /** Asserts `value` is an instance of `expectedClass` and returns it cast. */
  private def checkCast[T](value: Any, expectedClass: Class[T]): T = {
    assertResult(true)(expectedClass.isInstance(value))
    expectedClass.cast(value)
  }
}

View File

@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
/**
 * End-to-end tests for stored-procedure style `CALL` statements:
 * show_commits, show_commits_metadata and rollback_to_instant.
 */
class TestCallProcedure extends TestHoodieSqlBase {

  /** Creates a Hudi test table with the standard (id, name, price, ts) schema at basePath/tableName. */
  private def createTestTable(tableName: String, basePath: String): Unit = {
    spark.sql(
      s"""
         |create table $tableName (
         | id int,
         | name string,
         | price double,
         | ts long
         |) using hudi
         | location '$basePath/$tableName'
         | tblproperties (
         | primaryKey = 'id',
         | preCombineField = 'ts'
         | )
       """.stripMargin)
  }

  test("Test Call show_commits Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      createTestTable(tableName, tmp.getCanonicalPath)

      // Two inserts -> two commits on the timeline.
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")

      // The table argument is mandatory.
      checkExceptionContain(s"""call show_commits(limit => 10)""")(
        s"Argument: table is required")

      val commits = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(2)(commits.length)
    }
  }

  test("Test Call show_commits_metadata Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      createTestTable(tableName, tmp.getCanonicalPath)

      // A single insert -> one commit on the timeline.
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")

      // The table argument is mandatory.
      checkExceptionContain(s"""call show_commits_metadata(limit => 10)""")(
        s"Argument: table is required")

      val commits = spark.sql(s"""call show_commits_metadata(table => '$tableName', limit => 10)""").collect()
      assertResult(1)(commits.length)
    }
  }

  test("Test Call rollback_to_instant Procedure") {
    withTempDir { tmp =>
      val tableName = generateTableName
      createTestTable(tableName, tmp.getCanonicalPath)

      // Three inserts -> three commits on the timeline.
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")
      spark.sql(s"insert into $tableName select 3, 'a3', 30, 2000")

      // The instant_time argument is mandatory.
      checkExceptionContain(s"""call rollback_to_instant(table => '$tableName')""")(
        s"Argument: instant_time is required")

      // All three commits are visible before any rollback.
      val before = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(3)(before.length)

      // Roll back one commit using named arguments...
      val firstInstant = before(0).get(0).toString
      checkAnswer(s"""call rollback_to_instant(table => '$tableName', instant_time => '$firstInstant')""")(Seq(true))

      // ...and another using positional arguments.
      val secondInstant = before(1).get(0).toString
      checkAnswer(s"""call rollback_to_instant('$tableName', '$secondInstant')""")(Seq(true))

      // Only one commit remains after the two rollbacks.
      val after = spark.sql(s"""call show_commits(table => '$tableName', limit => 10)""").collect()
      assertResult(1)(after.length)
    }
  }
}