1
0

[HUDI-1659] Basic Implement Of Spark Sql Support For Hoodie (#2645)

Main functions:
Support create table for hoodie.
Support CTAS.
Support Insert for hoodie. Including dynamic partition and static partition insert.
Support MergeInto for hoodie.
Support DELETE
Support UPDATE
Both support spark2 & spark3 based on DataSourceV1.

Main changes:
Add sql parser for spark2.
Add HoodieAnalysis for sql resolve and logical plan rewrite.
Add commands implementation for CREATE TABLE、INSERT、MERGE INTO & CTAS.
In order to push down the update & insert logic to the HoodieRecordPayload for MergeInto, I made some changes to the
HoodieWriteHandler and other related classes.
1、Add the inputSchema for parsing the incoming record. This is because the inputSchema for MergeInto is different from the writeSchema, as there are some transforms in the update & insert expressions.
2、Add WRITE_SCHEMA to HoodieWriteConfig to pass the write schema for merge into.
3、Pass properties to HoodieRecordPayload#getInsertValue to pass the insert expression and table schema.


Verify this pull request
Add TestCreateTable to test creating hoodie tables and CTAS.
Add TestInsertTable to test inserting into hoodie tables.
Add TestMergeIntoTable to test merging into hoodie tables.
Add TestUpdateTable to test updating hoodie tables.
Add TestDeleteTable to test deleting from hoodie tables.
Add TestSqlStatement to test the currently supported ddl/dml statements.
This commit is contained in:
pengzhiwei
2021-06-08 14:24:32 +08:00
committed by GitHub
parent cf83f10f5b
commit f760ec543e
86 changed files with 7346 additions and 255 deletions

View File

@@ -21,10 +21,10 @@
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>hudi-spark-common</artifactId>
<artifactId>hudi-spark-common_${scala.binary.version}</artifactId>
<version>0.9.0-SNAPSHOT</version>
<name>hudi-spark-common</name>
<name>hudi-spark-common_${scala.binary.version}</name>
<packaging>jar</packaging>
<properties>

View File

@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
import org.apache.spark.sql.execution.datasources.SparkParsePartitionUtil
import org.apache.spark.sql.internal.SQLConf
/**
 * An interface that adapts the differences between Spark 2 and Spark 3
 * in some Spark-related classes. A concrete adapter is loaded reflectively
 * per Spark version (see SparkAdapterSupport).
 */
trait SparkAdapter extends Serializable {
/**
 * Create the SparkRowSerDe for the given encoder.
 */
def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe
/**
 * Convert an AliasIdentifier to a TableIdentifier.
 * NOTE(review): name looks like a typo of "toTableIdentifier"; kept as-is since it is a public interface.
 */
def toTableIdentify(aliasId: AliasIdentifier): TableIdentifier
/**
 * Convert an UnresolvedRelation to a TableIdentifier.
 */
def toTableIdentify(relation: UnresolvedRelation): TableIdentifier
/**
 * Create a Join logical plan.
 */
def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join
/**
 * Test whether the logical plan is an Insert-Into LogicalPlan.
 */
def isInsertInto(plan: LogicalPlan): Boolean
/**
 * Get the members of an Insert-Into LogicalPlan:
 * (table, partition spec, query, overwrite flag, ifPartitionNotExists flag).
 */
def getInsertIntoChildren(plan: LogicalPlan):
Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)]
/**
 * Create an Insert-Into LogicalPlan from its members.
 */
def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]],
query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan
/**
 * Create hoodie's extended spark sql parser. Defaults to None; an adapter
 * overrides this when the Spark version needs an extended parser.
 */
def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = None
/**
 * Create the SparkParsePartitionUtil for the given SQLConf.
 */
def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil
/**
 * Create a Like expression from the two operands.
 */
def createLike(left: Expression, right: Expression): Expression
}

View File

@@ -201,7 +201,7 @@
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-common</artifactId>
<artifactId>hudi-spark-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
@@ -210,15 +210,17 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark2_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark3_2.12</artifactId>
<artifactId>${hudi.spark.module}_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.hudi</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- Logging -->
@@ -272,6 +274,11 @@
<artifactId>spark-sql_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>

View File

@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi.sql;
import org.apache.avro.generic.IndexedRecord;
/**
 * An interface for CodeGen to evaluate expressions on a record
 * and return the results as an array with one element per expression.
 */
public interface IExpressionEvaluator {
/**
 * Evaluate the expressions against the given record.
 */
Object[] eval(IndexedRecord record);
/**
 * Get the generated code of the expressions. This is used for debugging.
 */
String getCode();
}

View File

@@ -71,7 +71,7 @@ case class HoodieFileIndex(
schemaSpec: Option[StructType],
options: Map[String, String],
@transient fileStatusCache: FileStatusCache = NoopCache)
extends FileIndex with Logging {
extends FileIndex with Logging with SparkAdapterSupport {
private val basePath = metaClient.getBasePath
@@ -247,7 +247,7 @@ case class HoodieFileIndex(
.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(SQLConf.get.sessionLocalTimeZone)
val sparkParsePartitionUtil = HoodieSparkUtils.createSparkParsePartitionUtil(spark
val sparkParsePartitionUtil = sparkAdapter.createSparkParsePartitionUtil(spark
.sessionState.conf)
// Convert partition path to PartitionRowPath
val partitionRowPaths = partitionPaths.map { partitionPath =>
@@ -323,17 +323,20 @@ case class HoodieFileIndex(
}
// Fetch the rest from the file system.
val fetchedPartition2Files =
spark.sparkContext.parallelize(pathToFetch, Math.min(pathToFetch.size, maxListParallelism))
.map { partitionRowPath =>
// Here we use a LocalEngineContext to get the files in the partition.
// We can do this because the TableMetadata.getAllFilesInPartition only rely on the
// hadoopConf of the EngineContext.
val engineContext = new HoodieLocalEngineContext(serializableConf.get())
val filesInPartition = FSUtils.getFilesInPartition(engineContext, metadataConfig,
if (pathToFetch.nonEmpty) {
spark.sparkContext.parallelize(pathToFetch, Math.min(pathToFetch.size, maxListParallelism))
.map { partitionRowPath =>
// Here we use a LocalEngineContext to get the files in the partition.
// We can do this because the TableMetadata.getAllFilesInPartition only rely on the
// hadoopConf of the EngineContext.
val engineContext = new HoodieLocalEngineContext(serializableConf.get())
val filesInPartition = FSUtils.getFilesInPartition(engineContext, metadataConfig,
basePath, partitionRowPath.fullPartitionPath(basePath))
(partitionRowPath, filesInPartition)
}.collect().map(f => f._1 -> f._2).toMap
(partitionRowPath, filesInPartition)
}.collect().map(f => f._1 -> f._2).toMap
} else {
Map.empty[PartitionRowPath, Array[FileStatus]]
}
// Update the fileStatusCache
fetchedPartition2Files.foreach {
case (partitionRowPath, filesInPartition) =>

View File

@@ -19,6 +19,7 @@ package org.apache.hudi
import java.util
import java.util.Properties
import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord
import org.apache.hadoop.conf.Configuration
@@ -45,6 +46,7 @@ import org.apache.spark.SPARK_VERSION
import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.hudi.HoodieSqlUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.SCHEMA_STRING_LENGTH_THRESHOLD
import org.apache.spark.sql.types.StructType
@@ -54,7 +56,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ListBuffer
import org.apache.hudi.common.table.HoodieTableConfig.{DEFAULT_ARCHIVELOG_FOLDER, HOODIE_ARCHIVELOG_FOLDER_PROP_NAME}
private[hudi] object HoodieSparkSqlWriter {
object HoodieSparkSqlWriter {
private val log = LogManager.getLogger(getClass)
private var tableExists: Boolean = false
@@ -450,11 +452,13 @@ private[hudi] object HoodieSparkSqlWriter {
// The following code refers to the spark code in
// https://github.com/apache/spark/blob/master/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
// Sync schema with meta fields
val schemaWithMetaFields = HoodieSqlUtils.addMetaFields(schema)
val partitionSet = parameters(HIVE_PARTITION_FIELDS_OPT_KEY)
.split(",").map(_.trim).filter(!_.isEmpty).toSet
val threshold = sqlConf.getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
val (partitionCols, dataCols) = schema.partition(c => partitionSet.contains(c.name))
val (partitionCols, dataCols) = schemaWithMetaFields.partition(c => partitionSet.contains(c.name))
val reOrderedType = StructType(dataCols ++ partitionCols)
val schemaParts = reOrderedType.json.grouped(threshold).toSeq

View File

@@ -21,23 +21,20 @@ package org.apache.hudi
import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.hudi.common.model.HoodieRecord
import org.apache.spark.SPARK_VERSION
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal}
import org.apache.spark.sql.execution.datasources.{FileStatusCache, InMemoryFileIndex, Spark2ParsePartitionUtil, Spark3ParsePartitionUtil, SparkParsePartitionUtil}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.execution.datasources.{FileStatusCache, InMemoryFileIndex}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import scala.collection.JavaConverters._
object HoodieSparkUtils {
object HoodieSparkUtils extends SparkAdapterSupport {
def getMetaSchema: StructType = {
StructType(HoodieRecord.HOODIE_META_COLUMNS.asScala.map(col => {
@@ -102,7 +99,7 @@ object HoodieSparkUtils {
// Use the Avro schema to derive the StructType which has the correct nullability information
val dataType = SchemaConverters.toSqlType(avroSchema).dataType.asInstanceOf[StructType]
val encoder = RowEncoder.apply(dataType).resolveAndBind()
val deserializer = HoodieSparkUtils.createRowSerDe(encoder)
val deserializer = sparkAdapter.createSparkRowSerDe(encoder)
df.queryExecution.toRdd.map(row => deserializer.deserializeRow(row))
.mapPartitions { records =>
if (records.isEmpty) Iterator.empty
@@ -113,24 +110,6 @@ object HoodieSparkUtils {
}
}
def createRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
// TODO remove Spark2RowSerDe if Spark 2.x support is dropped
if (SPARK_VERSION.startsWith("2.")) {
new Spark2RowSerDe(encoder)
} else {
new Spark3RowSerDe(encoder)
}
}
def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil = {
// TODO remove Spark2RowSerDe if Spark 2.x support is dropped
if (SPARK_VERSION.startsWith("2.")) {
new Spark2ParsePartitionUtil
} else {
new Spark3ParsePartitionUtil(conf)
}
}
/**
* Convert Filters to Catalyst Expressions and joined by And. If convert success return an
* Non-Empty Option[Expression],or else return None.
@@ -204,15 +183,15 @@ object HoodieSparkUtils {
case StringStartsWith(attribute, value) =>
val leftExp = toAttribute(attribute, tableSchema)
val rightExp = Literal.create(s"$value%")
org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
sparkAdapter.createLike(leftExp, rightExp)
case StringEndsWith(attribute, value) =>
val leftExp = toAttribute(attribute, tableSchema)
val rightExp = Literal.create(s"%$value")
org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
sparkAdapter.createLike(leftExp, rightExp)
case StringContains(attribute, value) =>
val leftExp = toAttribute(attribute, tableSchema)
val rightExp = Literal.create(s"%$value%")
org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
sparkAdapter.createLike(leftExp, rightExp)
case _=> null
}
)

View File

@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.JobConf
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FileStatusCache, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -61,8 +62,17 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
private val jobConf = new JobConf(conf)
// use schema from latest metadata, if not present, read schema from the data file
private val schemaUtil = new TableSchemaResolver(metaClient)
private val tableAvroSchema = schemaUtil.getTableAvroSchema
private val tableStructSchema = AvroConversionUtils.convertAvroSchemaToStructType(tableAvroSchema)
private lazy val tableAvroSchema = {
try {
schemaUtil.getTableAvroSchema
} catch {
case _: Throwable => // If there is no commit in the table, we can't get the schema
// with schemaUtil, use the userSchema instead.
SchemaConverters.toAvroType(userSchema)
}
}
private lazy val tableStructSchema = AvroConversionUtils.convertAvroSchemaToStructType(tableAvroSchema)
private val mergeType = optParams.getOrElse(
DataSourceReadOptions.REALTIME_MERGE_OPT_KEY,
DataSourceReadOptions.DEFAULT_REALTIME_MERGE_OPT_VAL)
@@ -163,18 +173,23 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
metaClient.getActiveTimeline.getCommitsTimeline
.filterCompletedInstants, fileStatuses.toArray)
val latestFiles: List[HoodieBaseFile] = fsView.getLatestBaseFiles.iterator().asScala.toList
val latestCommit = fsView.getLastInstant.get().getTimestamp
val fileGroup = HoodieRealtimeInputFormatUtils.groupLogsByBaseFile(conf, latestFiles.asJava).asScala
val fileSplits = fileGroup.map(kv => {
val baseFile = kv._1
val logPaths = if (kv._2.isEmpty) Option.empty else Option(kv._2.asScala.toList)
val filePath = MergeOnReadSnapshotRelation.getFilePath(baseFile.getFileStatus.getPath)
val partitionedFile = PartitionedFile(InternalRow.empty, filePath, 0, baseFile.getFileLen)
HoodieMergeOnReadFileSplit(Option(partitionedFile), logPaths, latestCommit,
metaClient.getBasePath, maxCompactionMemoryInBytes, mergeType)
}).toList
fileSplits
if (!fsView.getLastInstant.isPresent) { // Return empty list if the table has no commit
List.empty
} else {
val latestCommit = fsView.getLastInstant.get().getTimestamp
val fileGroup = HoodieRealtimeInputFormatUtils.groupLogsByBaseFile(conf, latestFiles.asJava).asScala
val fileSplits = fileGroup.map(kv => {
val baseFile = kv._1
val logPaths = if (kv._2.isEmpty) Option.empty else Option(kv._2.asScala.toList)
val filePath = MergeOnReadSnapshotRelation.getFilePath(baseFile.getFileStatus.getPath)
val partitionedFile = PartitionedFile(InternalRow.empty, filePath, 0, baseFile.getFileLen)
HoodieMergeOnReadFileSplit(Option(partitionedFile), logPaths, latestCommit,
metaClient.getBasePath, maxCompactionMemoryInBytes, mergeType)
}).toList
fileSplits
}
}
}
}

View File

@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi
import org.apache.spark.sql.hudi.{HoodieSqlUtils, SparkAdapter}
/**
 * Use the SparkAdapterSupport trait to get the SparkAdapter when we
 * need to adapt the differences between spark2 and spark3.
 */
trait SparkAdapterSupport {

  /**
   * The SparkAdapter implementation for the running Spark version, loaded
   * reflectively so only the matching spark2/spark3 module is needed at runtime.
   */
  lazy val sparkAdapter: SparkAdapter = {
    val adapterClass = if (HoodieSqlUtils.isSpark3) {
      "org.apache.spark.sql.adapter.Spark3Adapter"
    } else {
      "org.apache.spark.sql.adapter.Spark2Adapter"
    }
    // Class#newInstance is deprecated since Java 9 (it propagates checked
    // constructor exceptions unchecked); use the no-arg constructor instead.
    getClass.getClassLoader.loadClass(adapterClass)
      .getDeclaredConstructor().newInstance().asInstanceOf[SparkAdapter]
  }
}

View File

@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive
import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkConf
import org.apache.spark.sql.hive.client.HiveClient
/**
 * A thin accessor for creating a HiveClient, delegating to Spark's HiveUtils.
 * NOTE(review): placed in the org.apache.spark.sql.hive package — presumably
 * because HiveUtils is not accessible from outside that package; verify.
 */
object HiveClientUtils {
/**
 * Create a new HiveClient for metadata operations.
 */
def newClientForMetadata(conf: SparkConf, hadoopConf: Configuration): HiveClient = {
HiveUtils.newClientForMetadata(conf, hadoopConf)
}
}

View File

@@ -0,0 +1,213 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import org.apache.hudi.DataSourceWriteOptions
import org.apache.hudi.common.model.DefaultHoodieRecordPayload
import org.apache.hudi.common.table.HoodieTableConfig
/**
 * The HoodieOptionConfig defines short names for the hoodie
 * option keys and values used in SQL table options.
 */
object HoodieOptionConfig {

  /**
   * The short name for the value of COW_TABLE_TYPE_OPT_VAL.
   */
  val SQL_VALUE_TABLE_TYPE_COW = "cow"

  /**
   * The short name for the value of MOR_TABLE_TYPE_OPT_VAL.
   */
  val SQL_VALUE_TABLE_TYPE_MOR = "mor"

  val SQL_KEY_TABLE_PRIMARY_KEY: HoodieOption[String] = buildConf()
    .withSqlKey("primaryKey")
    .withHoodieKey(DataSourceWriteOptions.RECORDKEY_FIELD_OPT_KEY)
    .withTableConfigKey(HoodieTableConfig.HOODIE_TABLE_RECORDKEY_FIELDS)
    .build()

  val SQL_KEY_TABLE_TYPE: HoodieOption[String] = buildConf()
    .withSqlKey("type")
    .withHoodieKey(DataSourceWriteOptions.TABLE_TYPE_OPT_KEY)
    .withTableConfigKey(HoodieTableConfig.HOODIE_TABLE_TYPE_PROP_NAME)
    .defaultValue(SQL_VALUE_TABLE_TYPE_COW)
    .build()

  val SQL_KEY_PRECOMBINE_FIELD: HoodieOption[String] = buildConf()
    .withSqlKey("preCombineField")
    .withHoodieKey(DataSourceWriteOptions.PRECOMBINE_FIELD_OPT_KEY)
    .withTableConfigKey(HoodieTableConfig.HOODIE_TABLE_PRECOMBINE_FIELD)
    .build()

  val SQL_PAYLOAD_CLASS: HoodieOption[String] = buildConf()
    .withSqlKey("payloadClass")
    .withHoodieKey(DataSourceWriteOptions.PAYLOAD_CLASS_OPT_KEY)
    .withTableConfigKey(HoodieTableConfig.HOODIE_PAYLOAD_CLASS_PROP_NAME)
    .defaultValue(classOf[DefaultHoodieRecordPayload].getName)
    .build()

  /**
   * All the HoodieOption fields declared on this object, collected once by
   * reflection. Factored out so the scan is not repeated for every mapping.
   */
  private lazy val allOptions: Seq[HoodieOption[_]] =
    HoodieOptionConfig.getClass.getDeclaredFields
      .filter(f => f.getType == classOf[HoodieOption[_]])
      .map { f => f.setAccessible(true); f.get(HoodieOptionConfig).asInstanceOf[HoodieOption[_]] }
      .toSeq

  /**
   * The mapping of the sql short name key to the hoodie's config key.
   */
  private lazy val keyMapping: Map[String, String] =
    allOptions.map(option => option.sqlKeyName -> option.hoodieKeyName).toMap

  /**
   * The mapping of the sql short name key to the hoodie table config key
   * defined in HoodieTableConfig.
   */
  private lazy val keyTableConfigMapping: Map[String, String] =
    allOptions
      .filter(_.tableConfigKey.isDefined)
      .map(option => option.sqlKeyName -> option.tableConfigKey.get)
      .toMap

  private lazy val tableConfigKeyToSqlKey: Map[String, String] =
    keyTableConfigMapping.map(f => f._2 -> f._1)

  /**
   * Mapping of the short sql value to the hoodie's config value.
   */
  private val valueMapping: Map[String, String] = Map (
    SQL_VALUE_TABLE_TYPE_COW -> DataSourceWriteOptions.COW_TABLE_TYPE_OPT_VAL,
    SQL_VALUE_TABLE_TYPE_MOR -> DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL
  )

  private lazy val reverseValueMapping = valueMapping.map(f => f._2 -> f._1)

  /**
   * Mapping the sql's short name key/value in the options to the hoodie's config key/value.
   * @param options the sql options
   * @return the equivalent hoodie config key/values
   */
  def mappingSqlOptionToHoodieParam(options: Map[String, String]): Map[String, String] = {
    options.map (kv =>
      keyMapping.getOrElse(kv._1, kv._1) -> valueMapping.getOrElse(kv._2, kv._2))
  }

  /**
   * Mapping the sql options to the hoodie table config which is stored to the
   * hoodie.properties when creating the table.
   * @param options the sql options
   * @return the table config key/values, including defaults
   */
  def mappingSqlOptionToTableConfig(options: Map[String, String]): Map[String, String] = {
    defaultTableConfig ++
      options.filterKeys(k => keyTableConfigMapping.contains(k))
        .map(kv => keyTableConfigMapping(kv._1) -> valueMapping.getOrElse(kv._2, kv._2))
  }

  /**
   * Mapping the table config (loaded from the hoodie.properties) to the sql options.
   * @param options the table config key/values
   * @return the equivalent sql options
   */
  def mappingTableConfigToSqlOption(options: Map[String, String]): Map[String, String] = {
    options.filterKeys(k => tableConfigKeyToSqlKey.contains(k))
      .map(kv => tableConfigKeyToSqlKey(kv._1) -> reverseValueMapping.getOrElse(kv._2, kv._2))
  }

  // Table config entries that have both a table config key and a default value.
  private lazy val defaultTableConfig: Map[String, String] =
    allOptions
      .filter(option => option.tableConfigKey.isDefined && option.defaultValue.isDefined)
      .map(option => option.tableConfigKey.get ->
        valueMapping.getOrElse(option.defaultValue.get.toString, option.defaultValue.get.toString))
      .toMap

  /**
   * Get the primary key columns from the table options.
   * @param options the sql options
   * @return the primary key column names (may be empty)
   */
  def getPrimaryColumns(options: Map[String, String]): Array[String] = {
    val params = mappingSqlOptionToHoodieParam(options)
    params.get(DataSourceWriteOptions.RECORDKEY_FIELD_OPT_KEY)
      .map(_.split(",").filter(_.length > 0))
      .getOrElse(Array.empty)
  }

  /**
   * Get the table type from the table options.
   * @param options the sql options
   * @return the hoodie table type (defaults to DEFAULT_TABLE_TYPE_OPT_VAL)
   */
  def getTableType(options: Map[String, String]): String = {
    val params = mappingSqlOptionToHoodieParam(options)
    params.getOrElse(DataSourceWriteOptions.TABLE_TYPE_OPT_KEY,
      DataSourceWriteOptions.DEFAULT_TABLE_TYPE_OPT_VAL)
  }

  /** Get the precombine field from the table options, if one was set. */
  def getPreCombineField(options: Map[String, String]): Option[String] = {
    val params = mappingSqlOptionToHoodieParam(options)
    params.get(DataSourceWriteOptions.PRECOMBINE_FIELD_OPT_KEY)
  }

  /** Start building a new HoodieOption. */
  def buildConf[T](): HoodieOptions[T] = {
    new HoodieOptions[T]
  }
}
/**
 * An option entry binding a sql short key name to the corresponding hoodie
 * config key, with an optional default value and an optional table config key.
 */
case class HoodieOption[T](sqlKeyName: String, hoodieKeyName: String,
defaultValue: Option[T], tableConfigKey: Option[String] = None)
/**
 * A fluent builder for [[HoodieOption]]. All setters return this builder so
 * the calls can be chained; unset fields remain at their default (null/zero)
 * and become None in the built option.
 */
class HoodieOptions[T] {
  private var sqlKeyName: String = _
  private var hoodieKeyName: String = _
  private var tableConfigKey: String = _
  private var defaultValue: T = _

  /** Set the sql short key name. */
  def withSqlKey(sqlKeyName: String): HoodieOptions[T] = {
    this.sqlKeyName = sqlKeyName
    this
  }

  /** Set the hoodie config key this sql key maps to. */
  def withHoodieKey(hoodieKeyName: String): HoodieOptions[T] = {
    this.hoodieKeyName = hoodieKeyName
    this
  }

  /** Set the table config key persisted to hoodie.properties. */
  def withTableConfigKey(tableConfigKey: String): HoodieOptions[T] = {
    this.tableConfigKey = tableConfigKey
    this
  }

  /** Set the default value for this option. */
  def defaultValue(defaultValue: T): HoodieOptions[T] = {
    this.defaultValue = defaultValue
    this
  }

  /** Build the immutable HoodieOption; unset fields become None. */
  def build(): HoodieOption[T] = {
    HoodieOption(sqlKeyName, hoodieKeyName, Option(defaultValue), Option(tableConfigKey))
  }
}

View File

@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import org.apache.hudi.SparkAdapterSupport
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.hudi.analysis.HoodieAnalysis
/**
 * The Hoodie SparkSessionExtension for extending the syntax and adding the
 * analysis rules.
 */
class HoodieSparkSessionExtension extends (SparkSessionExtensions => Unit)
  with SparkAdapterSupport {

  override def apply(extensions: SparkSessionExtensions): Unit = {
    // For spark2 the adapter supplies an extended sql parser (spark3 returns
    // None). Use foreach instead of isDefined + get, and capture the factory
    // once rather than re-fetching it on every parser construction.
    sparkAdapter.createExtendedSparkParser.foreach { createParser =>
      extensions.injectParser { (session, parser) =>
        createParser(session, parser)
      }
    }
    // Register the hoodie resolution rules.
    HoodieAnalysis.customResolutionRules().foreach { rule =>
      extensions.injectResolutionRule { session =>
        rule(session)
      }
    }
    // Register the hoodie post-hoc resolution rules.
    HoodieAnalysis.customPostHocResolutionRules().foreach { rule =>
      extensions.injectPostHocResolutionRule { session =>
        rule(session)
      }
    }
  }
}

View File

@@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import scala.collection.JavaConverters._
import java.net.URI
import java.util.Locale
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hudi.SparkAdapterSupport
import org.apache.hudi.common.model.HoodieRecord
import org.apache.spark.SPARK_VERSION
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.expressions.{And, Cast, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, NullType, StringType, StructField, StructType}
import scala.collection.immutable.Map
object HoodieSqlUtils extends SparkAdapterSupport {
/**
 * Return true if the table's provider is "hudi" (case-insensitive).
 */
def isHoodieTable(table: CatalogTable): Boolean = {
  // Option.contains avoids the null produced by orNull when no provider is set.
  table.provider.map(_.toLowerCase(Locale.ROOT)).contains("hudi")
}
/** Check whether the table referenced by `tableId` is a hoodie table. */
def isHoodieTable(tableId: TableIdentifier, spark: SparkSession): Boolean = {
  isHoodieTable(spark.sessionState.catalog.getTableMetadata(tableId))
}
/**
 * Check whether a logical plan refers to a hoodie table, after stripping any
 * SubqueryAlias wrappers. Handles resolved LogicalRelations (which carry the
 * CatalogTable) and UnresolvedRelations (looked up through the catalog).
 */
def isHoodieTable(table: LogicalPlan, spark: SparkSession): Boolean = {
tripAlias(table) match {
case LogicalRelation(_, _, Some(tbl), _) => isHoodieTable(tbl)
case relation: UnresolvedRelation =>
isHoodieTable(sparkAdapter.toTableIdentify(relation), spark)
case _=> false
}
}
/** Recursively strip SubqueryAlias nodes and return the underlying plan. */
private def tripAlias(plan: LogicalPlan): LogicalPlan = {
  plan match {
    case SubqueryAlias(_, child: LogicalPlan) => tripAlias(child)
    case _ => plan
  }
}
/**
 * Add the hoodie meta fields (as string columns) to the front of the schema,
 * dropping any pre-existing fields with the same names to avoid duplicates.
 * @param schema the original table schema
 * @return the schema prefixed with the hoodie meta columns
 */
def addMetaFields(schema: StructType): StructType = {
  val metaColumnNames = HoodieRecord.HOODIE_META_COLUMNS.asScala
  // Keep only the data fields that do not collide with a meta column name.
  val nonMetaFields = schema.fields.filterNot(f => metaColumnNames.contains(f.name))
  StructType(metaColumnNames.map(StructField(_, StringType)) ++ nonMetaFields)
}
private lazy val metaFields = HoodieRecord.HOODIE_META_COLUMNS.asScala.toSet
/**
* Remove the meta fields from the schema.
* @param schema
* @return
*/
def removeMetaFields(schema: StructType): StructType = {
StructType(schema.fields.filterNot(f => isMetaField(f.name)))
}
def isMetaField(name: String): Boolean = {
metaFields.contains(name)
}
def removeMetaFields(df: DataFrame): DataFrame = {
val withoutMetaColumns = df.logicalPlan.output
.filterNot(attr => isMetaField(attr.name))
.map(new Column(_))
if (withoutMetaColumns.length != df.logicalPlan.output.size) {
df.select(withoutMetaColumns: _*)
} else {
df
}
}
/**
* Get the table location.
* @param tableId
* @param spark
* @return
*/
def getTableLocation(tableId: TableIdentifier, spark: SparkSession): Option[String] = {
val table = spark.sessionState.catalog.getTableMetadata(tableId)
getTableLocation(table, spark)
}
def getTableLocation(table: CatalogTable, sparkSession: SparkSession): Option[String] = {
val uri = if (table.tableType == CatalogTableType.MANAGED && isHoodieTable(table)) {
Some(sparkSession.sessionState.catalog.defaultTablePath(table.identifier))
} else {
table.storage.locationUri
}
val conf = sparkSession.sessionState.newHadoopConf()
uri.map(makePathQualified(_, conf))
.map(removePlaceHolder)
}
private def removePlaceHolder(path: String): String = {
if (path == null || path.length == 0) {
path
} else if (path.endsWith("-__PLACEHOLDER__")) {
path.substring(0, path.length() - 16)
} else {
path
}
}
def makePathQualified(path: URI, hadoopConf: Configuration): String = {
val hadoopPath = new Path(path)
val fs = hadoopPath.getFileSystem(hadoopConf)
fs.makeQualified(hadoopPath).toUri.toString
}
def castIfNeeded(child: Expression, dataType: DataType, conf: SQLConf): Expression = {
child match {
case Literal(nul, NullType) => Literal(nul, dataType)
case _ => if (child.dataType != dataType)
Cast(child, dataType, Option(conf.sessionLocalTimeZone)) else child
}
}
/**
* Split the expression to a sub expression seq by the AND operation.
* @param expression
* @return
*/
def splitByAnd(expression: Expression): Seq[Expression] = {
expression match {
case And(left, right) =>
splitByAnd(left) ++ splitByAnd(right)
case exp => Seq(exp)
}
}
/**
* Append the SparkSession config and table options to the baseConfig.
* We add the "spark" prefix to hoodie's config key.
* @param spark
* @param options
* @param baseConfig
* @return
*/
def withSparkConf(spark: SparkSession, options: Map[String, String])
(baseConfig: Map[String, String]): Map[String, String] = {
baseConfig ++ // Table options has the highest priority
(spark.sessionState.conf.getAllConfs ++ HoodieOptionConfig.mappingSqlOptionToHoodieParam(options))
.filterKeys(_.startsWith("hoodie."))
}
def isSpark3: Boolean = SPARK_VERSION.startsWith("3.")
}

View File

@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import java.io.ByteArrayOutputStream
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
object SerDeUtils {

  // One Kryo instance per thread, built from Spark's KryoSerializer
  // (Kryo instances are stateful, so they are not shared across threads).
  private val threadLocalKryo = new ThreadLocal[Kryo] {
    override protected def initialValue: Kryo =
      new KryoSerializer(new SparkConf(true)).newKryo()
  }

  /** Serialize any object (class tag included) to a byte array. */
  def toBytes(o: Any): Array[Byte] = {
    val buffer = new ByteArrayOutputStream(4096 * 5)
    val kryoOutput = new Output(buffer)
    try {
      threadLocalKryo.get.writeClassAndObject(kryoOutput, o)
      kryoOutput.flush()
    } finally {
      kryoOutput.clear()
      kryoOutput.close()
    }
    buffer.toByteArray
  }

  /** Deserialize a byte array produced by [[toBytes]] back to an object. */
  def toObject(bytes: Array[Byte]): Any =
    threadLocalKryo.get.readClassAndObject(new Input(bytes))
}

View File

@@ -0,0 +1,313 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.analysis
import org.apache.hudi.SparkAdapterSupport
import scala.collection.JavaConverters._
import org.apache.hudi.common.model.HoodieRecord
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, DeleteFromTable, InsertAction, LogicalPlan, MergeIntoTable, Project, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.CreateDataSourceTableCommand
import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
import org.apache.spark.sql.hudi.HoodieSqlUtils
import org.apache.spark.sql.hudi.HoodieSqlUtils._
import org.apache.spark.sql.hudi.command.{CreateHoodieTableAsSelectCommand, CreateHoodieTableCommand, DeleteHoodieTableCommand, InsertIntoHoodieTableCommand, MergeIntoHoodieTableCommand, UpdateHoodieTableCommand}
import org.apache.spark.sql.types.StringType
object HoodieAnalysis {

  /** Resolution rules hudi injects into spark's analyzer, in application order. */
  def customResolutionRules(): Seq[SparkSession => Rule[LogicalPlan]] = Seq(
    (session: SparkSession) => HoodieResolveReferences(session),
    (session: SparkSession) => HoodieAnalysis(session)
  )

  /** Post-hoc resolution rules hudi injects into spark's analyzer. */
  def customPostHocResolutionRules(): Seq[SparkSession => Rule[LogicalPlan]] = Seq(
    (session: SparkSession) => HoodiePostAnalysisRule(session)
  )
}
/**
 * Rule for convert the logical plan to command.
 *
 * Rewrites already-resolved DML/DDL plans that target a hoodie table into
 * hudi's own runnable commands; all other plans pass through unchanged.
 * @param sparkSession the active session, used for catalog lookups.
 */
case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan]
  with SparkAdapterSupport {

  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan match {
      // Convert to MergeIntoHoodieTableCommand
      case m @ MergeIntoTable(target, _, _, _, _)
        if m.resolved && isHoodieTable(target, sparkSession) =>
        MergeIntoHoodieTableCommand(m)

      // Convert to UpdateHoodieTableCommand
      case u @ UpdateTable(table, _, _)
        if u.resolved && isHoodieTable(table, sparkSession) =>
        UpdateHoodieTableCommand(u)

      // Convert to DeleteHoodieTableCommand
      case d @ DeleteFromTable(table, _)
        if d.resolved && isHoodieTable(table, sparkSession) =>
        DeleteHoodieTableCommand(d)

      // Convert to InsertIntoHoodieTableCommand. The version-specific adapter
      // recognizes and deconstructs the insert node (spark2 vs spark3 shapes differ).
      case l if sparkAdapter.isInsertInto(l) =>
        val (table, partition, query, overwrite, _) = sparkAdapter.getInsertIntoChildren(l).get
        table match {
          case relation: LogicalRelation if isHoodieTable(relation, sparkSession) =>
            new InsertIntoHoodieTableCommand(relation, query, partition, overwrite)
          case _ =>
            // Not a hoodie table: leave the original insert plan untouched.
            l
        }

      // Convert to CreateHoodieTableAsSelectCommand
      case CreateTable(table, mode, Some(query))
        if query.resolved && isHoodieTable(table) =>
        CreateHoodieTableAsSelectCommand(table, mode, query)

      case _ => plan
    }
  }
}
/**
 * Rule for resolve hoodie's extended syntax or rewrite some logical plan.
 *
 * Resolves expressions in MERGE/UPDATE/DELETE statements on hoodie tables
 * (spark's own analyzer cannot fully resolve them against hoodie's rewritten
 * plans) and appends the hoodie meta fields to INSERT queries so that the
 * column-count validation against the target table passes.
 * @param sparkSession the active session, used for catalog lookups and analysis.
 */
case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan]
  with SparkAdapterSupport {
  private lazy val analyzer = sparkSession.sessionState.analyzer

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
    // Resolve merge into
    case MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions)
      if isHoodieTable(target, sparkSession) && target.resolved =>
      // Run the full analyzer on the source sub-plan so its output attributes are resolved.
      val resolvedSource = analyzer.execute(source)

      // Detect "update set *" / "insert *" forms, which arrive as empty assignment lists
      // (spark2) or as meta-field-prefixed assignment lists (spark3, see below).
      def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = {
        if (assignments.isEmpty) {
          true
        } else {
          // This is a Hack for test if it is "update set *" or "insert *" for spark3.
          // As spark3's own ResolveReference will append first five columns of the target
          // table(which is the hoodie meta fields) to the assignments for "update set *" and
          // "insert *", so we test if the first five assignmentFieldNames is the meta fields
          // to judge if it is "update set *" or "insert *".
          // We can do this because under the normal case, we should not allow to update or set
          // the hoodie's meta field in sql statement, it is a system field, cannot set the value
          // by user.
          if (HoodieSqlUtils.isSpark3) {
            val assignmentFieldNames = assignments.map(_.key).map {
              case attr: AttributeReference =>
                attr.name
              case _ => ""
            }.toArray
            val metaFields = HoodieRecord.HOODIE_META_COLUMNS.asScala
            if (metaFields.mkString(",").startsWith(assignmentFieldNames.take(metaFields.length).mkString(","))) {
              true
            } else {
              false
            }
          } else {
            false
          }
        }
      }

      // Resolve an action's optional condition against the source, and its assignments:
      // star forms are expanded positionally (target fields zipped with source fields,
      // meta fields excluded on both sides); explicit assignments resolve the key against
      // the target and the value against source (falling back to target).
      def resolveConditionAssignments(condition: Option[Expression],
          assignments: Seq[Assignment]): (Option[Expression], Seq[Assignment]) = {
        val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_))
        val resolvedAssignments = if (isInsertOrUpdateStar(assignments)) {
          // assignments is empty means insert * or update set *
          // we fill assign all the source fields to the target fields
          target.output
            .filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
            .zip(resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name)))
            .map { case (targetAttr, sourceAttr) => Assignment(targetAttr, sourceAttr) }
        } else {
          assignments.map(assignment => {
            val resolvedKey = resolveExpressionFrom(target)(assignment.key)
            val resolvedValue = resolveExpressionFrom(resolvedSource, Some(target))(assignment.value)
            Assignment(resolvedKey, resolvedValue)
          })
        }
        (resolvedCondition, resolvedAssignments)
      }

      // Resolve the merge condition
      val resolvedMergeCondition = resolveExpressionFrom(resolvedSource, Some(target))(mergeCondition)

      // Resolve the matchedActions
      val resolvedMatchedActions = matchedActions.map {
        case UpdateAction(condition, assignments) =>
          val (resolvedCondition, resolvedAssignments) =
            resolveConditionAssignments(condition, assignments)
          UpdateAction(resolvedCondition, resolvedAssignments)
        case DeleteAction(condition) =>
          val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_))
          DeleteAction(resolvedCondition)
      }

      // Resolve the notMatchedActions
      val resolvedNotMatchedActions = notMatchedActions.map {
        case InsertAction(condition, assignments) =>
          val (resolvedCondition, resolvedAssignments) =
            resolveConditionAssignments(condition, assignments)
          InsertAction(resolvedCondition, resolvedAssignments)
      }

      // Return the resolved MergeIntoTable
      MergeIntoTable(target, resolvedSource, resolvedMergeCondition,
        resolvedMatchedActions, resolvedNotMatchedActions)

    // Resolve update table
    case UpdateTable(table, assignments, condition)
      if isHoodieTable(table, sparkSession) && table.resolved =>
      // Resolve condition
      val resolvedCondition = condition.map(resolveExpressionFrom(table)(_))
      // Resolve assignments
      val resolvedAssignments = assignments.map(assignment => {
        val resolvedKey = resolveExpressionFrom(table)(assignment.key)
        val resolvedValue = resolveExpressionFrom(table)(assignment.value)
        Assignment(resolvedKey, resolvedValue)
      })
      // Return the resolved UpdateTable
      UpdateTable(table, resolvedAssignments, resolvedCondition)

    // Resolve Delete Table
    case DeleteFromTable(table, condition)
      if isHoodieTable(table, sparkSession) && table.resolved =>
      // Resolve condition
      val resolvedCondition = condition.map(resolveExpressionFrom(table)(_))
      // Return the resolved DeleteTable
      DeleteFromTable(table, resolvedCondition)

    // Append the meta field to the insert query to walk through the validate for the
    // number of insert fields with the number of the target table fields.
    case l if sparkAdapter.isInsertInto(l) =>
      val (table, partition, query, overwrite, ifPartitionNotExists) =
        sparkAdapter.getInsertIntoChildren(l).get
      // Only rewrite once: skip when the query is unresolved, still contains a star,
      // or already carries the meta fields at its head.
      if (isHoodieTable(table, sparkSession) && query.resolved &&
        !containUnResolvedStar(query) &&
        !checkAlreadyAppendMetaField(query)) {
        // Meta fields are filled with null string literals; hudi assigns the real
        // values at write time.
        val metaFields = HoodieRecord.HOODIE_META_COLUMNS.asScala.map(
          Alias(Literal.create(null, StringType), _)()).toArray[NamedExpression]
        val newQuery = query match {
          case project: Project =>
            val withMetaFieldProjects =
              metaFields ++ project.projectList
            // Append the meta fields to the insert query.
            Project(withMetaFieldProjects, project.child)
          case _ =>
            val withMetaFieldProjects = metaFields ++ query.output
            Project(withMetaFieldProjects, query)
        }
        sparkAdapter.createInsertInto(table, partition, newQuery, overwrite, ifPartitionNotExists)
      } else {
        l
      }

    case p => p
  }

  /** True when the query's top-level projection still contains an unresolved star. */
  private def containUnResolvedStar(query: LogicalPlan): Boolean = {
    query match {
      case project: Project => project.projectList.exists(_.isInstanceOf[UnresolvedStar])
      case _ => false
    }
  }

  /**
   * Check if the the query of insert statement has already append the meta fields to avoid
   * duplicate append.
   * @param query the insert source query.
   * @return true when the query's leading output columns are exactly the hoodie meta fields.
   */
  private def checkAlreadyAppendMetaField(query: LogicalPlan): Boolean = {
    query.output.take(HoodieRecord.HOODIE_META_COLUMNS.size())
      .filter(isMetaField)
      .map {
        case AttributeReference(name, _, _, _) => name.toLowerCase
        case other => throw new IllegalArgumentException(s"$other should not be a hoodie meta field")
      }.toSet == HoodieRecord.HOODIE_META_COLUMNS.asScala.toSet
  }

  /** True when the expression is (an alias of or reference to) a hoodie meta column. */
  private def isMetaField(exp: Expression): Boolean = {
    val metaFields = HoodieRecord.HOODIE_META_COLUMNS.asScala.toSet
    exp match {
      case Alias(_, name) if metaFields.contains(name.toLowerCase) => true
      case AttributeReference(name, _, _, _) if metaFields.contains(name.toLowerCase) => true
      case _=> false
    }
  }

  /**
   * Resolve the expression.
   * 1、 Fake a a project for the expression based on the source plan
   * 2、 Resolve the fake project
   * 3、 Get the resolved expression from the faked project
   * @param left The left source plan for the expression.
   * @param right The right source plan for the expression (joined to left when present).
   * @param expression The expression to resolved.
   * @return The resolved expression.
   * @throws AnalysisException when any attribute remains unresolved afterwards.
   */
  private def resolveExpressionFrom(left: LogicalPlan, right: Option[LogicalPlan] = None)
      (expression: Expression): Expression = {
    // Fake a project for the expression based on the source plan.
    val fakeProject = if (right.isDefined) {
      Project(Seq(Alias(expression, "_c0")()),
        sparkAdapter.createJoin(left, right.get, Inner))
    } else {
      Project(Seq(Alias(expression, "_c0")()),
        left)
    }
    // Resolve the fake project
    val resolvedProject =
      analyzer.ResolveReferences.apply(fakeProject).asInstanceOf[Project]
    val unResolvedAttrs = resolvedProject.projectList.head.collect {
      case attr: UnresolvedAttribute => attr
    }
    if (unResolvedAttrs.nonEmpty) {
      throw new AnalysisException(s"Cannot resolve ${unResolvedAttrs.mkString(",")} in " +
        s"${expression.sql}, the input " + s"columns is: [${fakeProject.child.output.mkString(", ")}]")
    }
    // Fetch the resolved expression from the fake project.
    resolvedProject.projectList.head.asInstanceOf[Alias].child
  }
}
/**
 * Rule for rewrite some spark commands to hudi's implementation.
 * @param sparkSession the active session.
 */
case class HoodiePostAnalysisRule(sparkSession: SparkSession) extends Rule[LogicalPlan] {

  // Rewrite the CreateDataSourceTableCommand to CreateHoodieTableCommand;
  // every other plan passes through untouched.
  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
    case CreateDataSourceTableCommand(table, ignoreIfExists) if isHoodieTable(table) =>
      CreateHoodieTableCommand(table, ignoreIfExists)
    case other => other
  }
}

View File

@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.hudi.HoodieSqlUtils.getTableLocation
/**
 * Command for create table as query statement.
 *
 * Writes the query result into the (empty) table path first, then registers the
 * table in the catalog; on write failure the table path is cleaned up.
 */
case class CreateHoodieTableAsSelectCommand(
    table: CatalogTable,
    mode: SaveMode,
    query: LogicalPlan) extends DataWritingCommand {

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    assert(table.tableType != CatalogTableType.VIEW)
    assert(table.provider.isDefined)

    val sessionState = sparkSession.sessionState
    val db = table.identifier.database.getOrElse(sessionState.catalog.getCurrentDatabase)
    val tableIdentWithDB = table.identifier.copy(database = Some(db))
    val tableName = tableIdentWithDB.unquotedString

    if (sessionState.catalog.tableExists(tableIdentWithDB)) {
      assert(mode != SaveMode.Overwrite,
        s"Expect the table $tableName has been dropped when the save mode is Overwrite")
      if (mode == SaveMode.ErrorIfExists) {
        throw new RuntimeException(s"Table $tableName already exists. You need to drop it first.")
      }
      if (mode == SaveMode.Ignore) {
        // Since the table already exists and the save mode is Ignore, we will just return.
        // scalastyle:off
        return Seq.empty
        // scalastyle:on
      }
    }
    // BUGFIX: the original getOrElse returned the error-message string itself as the
    // table path when the location was missing; fail fast instead.
    val tablePath = getTableLocation(table, sparkSession)
      .getOrElse(throw new IllegalArgumentException(
        s"Missing path for table ${table.identifier}"))
    val conf = sparkSession.sessionState.newHadoopConf()
    assert(CreateHoodieTableCommand.isEmptyPath(tablePath, conf),
      s"Path '$tablePath' should be empty for CTAS")

    // ReOrder the query which move the partition columns to the last of the project list
    val reOrderedQuery = reOrderPartitionColumn(query, table.partitionColumnNames)
    val tableWithSchema = table.copy(schema = reOrderedQuery.schema)

    // Execute the insert query
    try {
      val success = InsertIntoHoodieTableCommand.run(sparkSession, tableWithSchema, reOrderedQuery, Map.empty,
        mode == SaveMode.Overwrite, refreshTable = false)
      if (success) {
        // If write success, create the table in catalog if it has not synced to the
        // catalog by the meta sync.
        if (!sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) {
          // Create the table
          val createTableCommand = CreateHoodieTableCommand(tableWithSchema, mode == SaveMode.Ignore)
          createTableCommand.createTableInCatalog(sparkSession, checkPathForManagedTable = false)
        }
      } else { // failed to insert data, clear table path
        clearTablePath(tablePath, conf)
      }
    } catch {
      case e: Throwable => // failed to insert data, clear table path
        clearTablePath(tablePath, conf)
        throw e
    }
    Seq.empty[Row]
  }

  /** Recursively delete the table path (cleanup after a failed insert). */
  private def clearTablePath(tablePath: String, conf: Configuration): Unit = {
    val path = new Path(tablePath)
    val fs = path.getFileSystem(conf)
    fs.delete(path, true)
  }

  override def outputColumnNames: Seq[String] = query.output.map(_.name)

  /**
   * Move the partition columns to the tail of the query's output, preserving the
   * relative order within each group; returns the query unchanged when the table
   * has no partition columns.
   */
  private def reOrderPartitionColumn(query: LogicalPlan,
      partitionColumns: Seq[String]): LogicalPlan = {
    if (partitionColumns.isEmpty) {
      query
    } else {
      val nonPartitionAttrs = query.output.filter(p => !partitionColumns.contains(p.name))
      val partitionAttrs = query.output.filter(p => partitionColumns.contains(p.name))
      val reorderAttrs = nonPartitionAttrs ++ partitionAttrs
      Project(reorderAttrs, query)
    }
  }
}

View File

@@ -0,0 +1,359 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import scala.collection.JavaConverters._
import java.util.{Locale, Properties}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hudi.{DataSourceWriteOptions, SparkAdapterSupport}
import org.apache.hudi.common.model.HoodieFileFormat
import org.apache.hudi.common.table.{HoodieTableMetaClient, TableSchemaResolver}
import org.apache.hudi.common.util.ValidationUtils
import org.apache.hudi.hadoop.HoodieParquetInputFormat
import org.apache.hudi.hadoop.realtime.HoodieParquetRealtimeInputFormat
import org.apache.hudi.hadoop.utils.HoodieInputFormatUtils
import org.apache.spark.{SPARK_VERSION, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.hive.HiveClientUtils
import org.apache.spark.sql.hive.HiveExternalCatalog._
import org.apache.spark.sql.hudi.HoodieSqlUtils._
import org.apache.spark.sql.hudi.HoodieOptionConfig
import org.apache.spark.sql.hudi.command.CreateHoodieTableCommand.{initTableIfNeed, tableExistsInPath, isEmptyPath}
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.internal.StaticSQLConf.SCHEMA_STRING_LENGTH_THRESHOLD
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import scala.collection.mutable
/**
* Command for create hoodie table.
* @param table
* @param ignoreIfExists
*/
case class CreateHoodieTableCommand(table: CatalogTable, ignoreIfExists: Boolean)
extends RunnableCommand with SparkAdapterSupport {
override def run(sparkSession: SparkSession): Seq[Row] = {
val tableName = table.identifier.unquotedString
val tableIsExists = sparkSession.sessionState.catalog.tableExists(table.identifier)
if (tableIsExists) {
if (ignoreIfExists) {
// scalastyle:off
return Seq.empty[Row]
// scalastyle:on
} else {
throw new IllegalArgumentException(s"Table $tableName already exists.")
}
}
// Create table in the catalog
val createTable = createTableInCatalog(sparkSession)
// Init the hoodie.properties
initTableIfNeed(sparkSession, createTable)
Seq.empty[Row]
}
def createTableInCatalog(sparkSession: SparkSession,
checkPathForManagedTable: Boolean = true): CatalogTable = {
assert(table.tableType != CatalogTableType.VIEW)
assert(table.provider.isDefined)
val sessionState = sparkSession.sessionState
val tableName = table.identifier.unquotedString
val path = getTableLocation(table, sparkSession)
.getOrElse(s"Missing path for table ${table.identifier}")
val conf = sparkSession.sessionState.newHadoopConf()
val isTableExists = tableExistsInPath(path, conf)
// Get the schema & table options
val (newSchema, tableOptions) = if (table.tableType == CatalogTableType.EXTERNAL &&
isTableExists) {
// If this is an external table & the table has already exists in the location,
// load the schema from the table meta.
val metaClient = HoodieTableMetaClient.builder()
.setBasePath(path)
.setConf(conf)
.build()
val schemaResolver = new TableSchemaResolver(metaClient)
val avroSchema = try Some(schemaResolver.getTableAvroSchema(false))
catch {
case _: Throwable => None
}
val tableSchema = avroSchema.map(SchemaConverters.toSqlType(_).dataType
.asInstanceOf[StructType])
// Get options from the external table
val options = HoodieOptionConfig.mappingTableConfigToSqlOption(
metaClient.getTableConfig.getProps.asScala.toMap)
val userSpecifiedSchema = table.schema
if (userSpecifiedSchema.isEmpty && tableSchema.isDefined) {
(addMetaFields(tableSchema.get), options)
} else if (userSpecifiedSchema.nonEmpty) {
(addMetaFields(userSpecifiedSchema), options)
} else {
throw new IllegalArgumentException(s"Missing schema for Create Table: $tableName")
}
} else {
assert(table.schema.nonEmpty, s"Missing schema for Create Table: $tableName")
// SPARK-19724: the default location of a managed table should be non-existent or empty.
if (checkPathForManagedTable && table.tableType == CatalogTableType.MANAGED
&& !isEmptyPath(path, conf)) {
throw new AnalysisException(s"Can not create the managed table('$tableName')" +
s". The associated location('$path') already exists.")
}
// Add the meta fields to the schema if this is a managed table or an empty external table.
(addMetaFields(table.schema), table.storage.properties)
}
val tableType = HoodieOptionConfig.getTableType(table.storage.properties)
val inputFormat = tableType match {
case DataSourceWriteOptions.COW_TABLE_TYPE_OPT_VAL =>
classOf[HoodieParquetInputFormat].getCanonicalName
case DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL =>
classOf[HoodieParquetRealtimeInputFormat].getCanonicalName
case _=> throw new IllegalArgumentException(s"UnKnow table type:$tableType")
}
val outputFormat = HoodieInputFormatUtils.getOutputFormatClassName(HoodieFileFormat.PARQUET)
val serdeFormat = HoodieInputFormatUtils.getSerDeClassName(HoodieFileFormat.PARQUET)
val newStorage = new CatalogStorageFormat(Some(new Path(path).toUri),
Some(inputFormat), Some(outputFormat), Some(serdeFormat),
table.storage.compressed, tableOptions + ("path" -> path))
val newDatabaseName = formatName(table.identifier.database
.getOrElse(sessionState.catalog.getCurrentDatabase))
val newTableName = formatName(table.identifier.table)
val newTableIdentifier = table.identifier
.copy(table = newTableName, database = Some(newDatabaseName))
val newTable = table.copy(identifier = newTableIdentifier,
schema = newSchema, storage = newStorage, createVersion = SPARK_VERSION)
// validate the table
validateTable(newTable)
// Create table in the catalog
val enableHive = "hive" == sessionState.conf.getConf(StaticSQLConf.CATALOG_IMPLEMENTATION)
if (enableHive) {
createHiveDataSourceTable(newTable, sparkSession)
} else {
sessionState.catalog.createTable(newTable, ignoreIfExists = false,
validateLocation = checkPathForManagedTable)
}
newTable
}
/**
* Create Hive table for hudi.
* Firstly, do some check for the schema & table.
* Secondly, append some table properties need for spark datasource table.
* Thirdly, create hive table using the HiveClient.
* @param table
* @param sparkSession
*/
private def createHiveDataSourceTable(table: CatalogTable, sparkSession: SparkSession): Unit = {
// check schema
verifyDataSchema(table.identifier, table.tableType, table.schema)
val dbName = table.identifier.database.get
// check database
val dbExists = sparkSession.sessionState.catalog.databaseExists(dbName)
if (!dbExists) {
throw new NoSuchDatabaseException(dbName)
}
// check table exists
if (sparkSession.sessionState.catalog.tableExists(table.identifier)) {
throw new TableAlreadyExistsException(dbName, table.identifier.table)
}
// append some table properties need for spark data source table.
val dataSourceProps = tableMetaToTableProps(sparkSession.sparkContext.conf,
table, table.schema)
val tableWithDataSourceProps = table.copy(properties = dataSourceProps)
val client = HiveClientUtils.newClientForMetadata(sparkSession.sparkContext.conf,
sparkSession.sessionState.newHadoopConf())
// create hive table.
client.createTable(tableWithDataSourceProps, ignoreIfExists)
}
private def formatName(name: String): String = {
if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
}
// This code is forked from org.apache.spark.sql.hive.HiveExternalCatalog#verifyDataSchema
private def verifyDataSchema(tableName: TableIdentifier,
tableType: CatalogTableType,
dataSchema: StructType): Unit = {
if (tableType != CatalogTableType.VIEW) {
val invalidChars = Seq(",", ":", ";")
def verifyNestedColumnNames(schema: StructType): Unit = schema.foreach { f =>
f.dataType match {
case st: StructType => verifyNestedColumnNames(st)
case _ if invalidChars.exists(f.name.contains) =>
val invalidCharsString = invalidChars.map(c => s"'$c'").mkString(", ")
val errMsg = "Cannot create a table having a nested column whose name contains " +
s"invalid characters ($invalidCharsString) in Hive metastore. Table: $tableName; " +
s"Column: ${f.name}"
throw new AnalysisException(errMsg)
case _ =>
}
}
dataSchema.foreach { f =>
f.dataType match {
// Checks top-level column names
case _ if f.name.contains(",") =>
throw new AnalysisException("Cannot create a table having a column whose name " +
s"contains commas in Hive metastore. Table: $tableName; Column: ${f.name}")
// Checks nested column names
case st: StructType =>
verifyNestedColumnNames(st)
case _ =>
}
}
}
}
// This code is forked from org.apache.spark.sql.hive.HiveExternalCatalog#tableMetaToTableProps
private def tableMetaToTableProps( sparkConf: SparkConf,
table: CatalogTable,
schema: StructType): Map[String, String] = {
val partitionColumns = table.partitionColumnNames
val bucketSpec = table.bucketSpec
val properties = new mutable.HashMap[String, String]
properties.put(DATASOURCE_PROVIDER, "hudi")
properties.put(CREATED_SPARK_VERSION, table.createVersion)
// Serialized JSON schema string may be too long to be stored into a single metastore table
// property. In this case, we split the JSON string and store each part as a separate table
// property.
val threshold = sparkConf.get(SCHEMA_STRING_LENGTH_THRESHOLD)
val schemaJsonString = schema.json
// Split the JSON string.
val parts = schemaJsonString.grouped(threshold).toSeq
properties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString)
parts.zipWithIndex.foreach { case (part, index) =>
properties.put(s"$DATASOURCE_SCHEMA_PART_PREFIX$index", part)
}
if (partitionColumns.nonEmpty) {
properties.put(DATASOURCE_SCHEMA_NUMPARTCOLS, partitionColumns.length.toString)
partitionColumns.zipWithIndex.foreach { case (partCol, index) =>
properties.put(s"$DATASOURCE_SCHEMA_PARTCOL_PREFIX$index", partCol)
}
}
if (bucketSpec.isDefined) {
val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get
properties.put(DATASOURCE_SCHEMA_NUMBUCKETS, numBuckets.toString)
properties.put(DATASOURCE_SCHEMA_NUMBUCKETCOLS, bucketColumnNames.length.toString)
bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) =>
properties.put(s"$DATASOURCE_SCHEMA_BUCKETCOL_PREFIX$index", bucketCol)
}
if (sortColumnNames.nonEmpty) {
properties.put(DATASOURCE_SCHEMA_NUMSORTCOLS, sortColumnNames.length.toString)
sortColumnNames.zipWithIndex.foreach { case (sortCol, index) =>
properties.put(s"$DATASOURCE_SCHEMA_SORTCOL_PREFIX$index", sortCol)
}
}
}
properties.toMap
}
/**
 * Sanity-check the catalog table definition before creating it: every column
 * referenced by the sql options (primary key, preCombine field) and every
 * partition column must exist in the schema, and the declared table type
 * (if any) must be cow or mor.
 */
private def validateTable(table: CatalogTable): Unit = {
  val storageOptions = table.storage.properties
  // schema.fieldIndex throws for an unknown column, which is exactly the
  // validation we want for each referenced column name.
  val referencedColumns =
    HoodieOptionConfig.getPrimaryColumns(storageOptions) ++
      HoodieOptionConfig.getPreCombineField(storageOptions) ++
      table.partitionColumnNames
  referencedColumns.foreach(column => table.schema.fieldIndex(column))
  // When a table type is declared it must be one of the two supported values.
  storageOptions.get(HoodieOptionConfig.SQL_KEY_TABLE_TYPE.sqlKeyName).foreach { declaredType =>
    val isSupported =
      declaredType.equalsIgnoreCase(HoodieOptionConfig.SQL_VALUE_TABLE_TYPE_COW) ||
        declaredType.equalsIgnoreCase(HoodieOptionConfig.SQL_VALUE_TABLE_TYPE_MOR)
    ValidationUtils.checkArgument(isSupported,
      s"'type' must be '${HoodieOptionConfig.SQL_VALUE_TABLE_TYPE_COW}' or " +
        s"'${HoodieOptionConfig.SQL_VALUE_TABLE_TYPE_MOR}'")
  }
}
}
object CreateHoodieTableCommand extends Logging {

  /**
   * Initialize the hoodie table if it does not exist yet, i.e. create the
   * hoodie meta folder (including hoodie.properties) under the table path.
   * No-op when the table path already contains hoodie metadata.
   *
   * @param sparkSession the active spark session
   * @param table the catalog table definition; must carry a location
   * @throws IllegalArgumentException if the table has no location
   */
  def initTableIfNeed(sparkSession: SparkSession, table: CatalogTable): Unit = {
    val location = getTableLocation(table, sparkSession).getOrElse(
      throw new IllegalArgumentException(s"Missing location for ${table.identifier}"))

    val conf = sparkSession.sessionState.newHadoopConf()
    // Only initialize when hoodie.properties is absent from the table path.
    if (!tableExistsInPath(location, conf)) {
      val tableName = table.identifier.table
      // Fixed grammar of the original message ("is not exists").
      logInfo(s"Table $tableName does not exist, start creating the hudi table")
      // Save all the table config to the hoodie.properties.
      val parameters = HoodieOptionConfig.mappingSqlOptionToTableConfig(table.storage.properties)
      val properties = new Properties()
      properties.putAll(parameters.asJava)
      HoodieTableMetaClient.withPropertyBuilder()
        .fromProperties(properties)
        .setTableName(tableName)
        .setTableCreateSchema(SchemaConverters.toAvroType(table.schema).toString())
        .setPartitionColumns(table.partitionColumnNames.mkString(","))
        .initTable(conf, location)
    }
  }

  /**
   * Check if the hoodie meta folder (which holds hoodie.properties) exists in
   * the table path, i.e. whether the path is already an initialized hoodie table.
   */
  def tableExistsInPath(tablePath: String, conf: Configuration): Boolean = {
    val basePath = new Path(tablePath)
    val fs = basePath.getFileSystem(conf)
    val metaPath = new Path(basePath, HoodieTableMetaClient.METAFOLDER_NAME)
    fs.exists(metaPath)
  }

  /**
   * Check if this is an empty table path: true when the path does not exist
   * at all, or exists but contains no files.
   */
  def isEmptyPath(tablePath: String, conf: Configuration): Boolean = {
    val basePath = new Path(tablePath)
    val fs = basePath.getFileSystem(conf)
    if (fs.exists(basePath)) {
      fs.listStatus(basePath).isEmpty
    } else {
      true
    }
  }
}

View File

@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import org.apache.hudi.{DataSourceWriteOptions, SparkAdapterSupport}
import org.apache.hudi.DataSourceWriteOptions.{HIVE_STYLE_PARTITIONING_OPT_KEY, HIVE_SUPPORT_TIMESTAMP, KEYGENERATOR_CLASS_OPT_KEY, OPERATION_OPT_KEY, PARTITIONPATH_FIELD_OPT_KEY, RECORDKEY_FIELD_OPT_KEY}
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.config.HoodieWriteConfig.TABLE_NAME
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, SubqueryAlias}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.hudi.HoodieOptionConfig
import org.apache.spark.sql.hudi.HoodieSqlUtils._
/**
 * Command for DELETE FROM on a hoodie table. Builds a DataFrame of the rows
 * matching the delete condition and writes it back with the hoodie "delete"
 * operation, which removes those records by key.
 */
case class DeleteHoodieTableCommand(deleteTable: DeleteFromTable) extends RunnableCommand
  with SparkAdapterSupport {

  private val table = deleteTable.table

  // The analyzer wraps the target relation in a SubqueryAlias; anything else is invalid.
  private val tableId = table match {
    case SubqueryAlias(name, _) => sparkAdapter.toTableIdentify(name)
    case _ => throw new IllegalArgumentException(s"Illegal table: $table")
  }

  override def run(sparkSession: SparkSession): Seq[Row] = {
    logInfo(s"start execute delete command for $tableId")

    // Remove meta fields from the data frame
    var df = removeMetaFields(Dataset.ofRows(sparkSession, table))
    // Only the rows matching the condition are deleted; no condition deletes all rows.
    if (deleteTable.condition.isDefined) {
      df = df.filter(Column(deleteTable.condition.get))
    }
    val config = buildHoodieConfig(sparkSession)
    df.write
      .format("hudi")
      .mode(SaveMode.Append)
      .options(config)
      .save()
    sparkSession.catalog.refreshTable(tableId.unquotedString)
    logInfo(s"Finish execute delete command for $tableId")
    Seq.empty[Row]
  }

  /**
   * Build the hoodie write config for the delete operation from the target
   * table's catalog metadata and sql options.
   * @throws IllegalArgumentException if the table has no location
   * @throws AssertionError if the table has no primary key defined
   */
  private def buildHoodieConfig(sparkSession: SparkSession): Map[String, String] = {
    val targetTable = sparkSession.sessionState.catalog
      .getTableMetadata(tableId)
    // BUGFIX: the original used getOrElse(message), which silently used the
    // error-message string itself as the table path. Fail fast instead.
    val path = getTableLocation(targetTable, sparkSession)
      .getOrElse(throw new IllegalArgumentException(s"Missing location for $tableId"))

    val primaryColumns = HoodieOptionConfig.getPrimaryColumns(targetTable.storage.properties)
    // Delete resolves records by key, so a primary key is mandatory.
    assert(primaryColumns.nonEmpty,
      s"There are no primary key defined in table $tableId, cannot execute delete operator")
    withSparkConf(sparkSession, targetTable.storage.properties) {
      Map(
        "path" -> path,
        KEYGENERATOR_CLASS_OPT_KEY -> classOf[SqlKeyGenerator].getCanonicalName,
        TABLE_NAME -> tableId.table,
        OPERATION_OPT_KEY -> DataSourceWriteOptions.DELETE_OPERATION_OPT_VAL,
        PARTITIONPATH_FIELD_OPT_KEY -> targetTable.partitionColumnNames.mkString(","),
        HIVE_SUPPORT_TIMESTAMP -> "true",
        HIVE_STYLE_PARTITIONING_OPT_KEY -> "true",
        HoodieWriteConfig.DELETE_PARALLELISM -> "200",
        SqlKeyGenerator.PARTITION_SCHEMA -> targetTable.partitionSchema.toDDL
      )
    }
  }
}

View File

@@ -0,0 +1,270 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import java.util.Properties
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericRecord, IndexedRecord}
import org.apache.hudi.common.model.{DefaultHoodieRecordPayload, HoodieRecord}
import org.apache.hudi.common.util.{Option => HOption}
import org.apache.hudi.exception.HoodieDuplicateKeyException
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.config.HoodieWriteConfig.TABLE_NAME
import org.apache.hudi.hive.MultiPartKeysValueExtractor
import org.apache.hudi.{HoodieSparkSqlWriter, HoodieWriterUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.hudi.{HoodieOptionConfig, HoodieSqlUtils}
import org.apache.spark.sql.hudi.HoodieSqlUtils._
import org.apache.spark.sql.internal.SQLConf
/**
 * Command for insert into hoodie table. Thin runnable wrapper that delegates
 * the actual write to [[InsertIntoHoodieTableCommand.run]].
 */
case class InsertIntoHoodieTableCommand(
    logicalRelation: LogicalRelation,
    query: LogicalPlan,
    partition: Map[String, Option[String]],
    overwrite: Boolean)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalogTable = logicalRelation.catalogTable
    assert(catalogTable.isDefined, "Missing catalog table")

    InsertIntoHoodieTableCommand.run(sparkSession, catalogTable.get, query, partition, overwrite)
    Seq.empty[Row]
  }
}
object InsertIntoHoodieTableCommand {

  /**
   * Run the insert query. We support both dynamic partition insert and static partition insert.
   * @param sparkSession The spark session.
   * @param table The insert table.
   * @param query The insert query.
   * @param insertPartitions The specified insert partition map.
   *                         e.g. "insert into h(dt = '2021') select id, name from src"
   *                         "dt" is the key in the map and "2021" is the partition value. If the
   *                         partition value has not specified(in the case of dynamic partition),
   *                         it is None in the map.
   * @param overwrite Whether to overwrite the table.
   * @param refreshTable Whether to refresh the table after insert finished.
   * @return true when the underlying hoodie write reports success.
   */
  def run(sparkSession: SparkSession, table: CatalogTable, query: LogicalPlan,
      insertPartitions: Map[String, Option[String]],
      overwrite: Boolean, refreshTable: Boolean = true): Boolean = {

    val config = buildHoodieInsertConfig(table, sparkSession, overwrite, insertPartitions)

    // Insert overwrite of a non-partitioned table truncates it (Overwrite mode);
    // insert into, and insert overwrite of a partition, both use Append mode.
    val mode = if (overwrite && table.partitionColumnNames.isEmpty) {
      SaveMode.Overwrite
    } else {
      SaveMode.Append
    }
    val parameters = HoodieWriterUtils.parametersWithWriteDefaults(config)
    val queryData = Dataset.ofRows(sparkSession, query)
    val conf = sparkSession.sessionState.conf
    val alignedQuery = alignOutputFields(queryData, table, insertPartitions, conf)
    val success =
      HoodieSparkSqlWriter.write(sparkSession.sqlContext, mode, parameters, alignedQuery)._1
    if (success && refreshTable) {
      sparkSession.catalog.refreshTable(table.identifier.unquotedString)
    }
    success
  }

  /**
   * Aligned the type and name of query's output fields with the result table's fields.
   * @param query The insert query which to aligned.
   * @param table The result table.
   * @param insertPartitions The insert partition map.
   * @param conf The SQLConf.
   * @return the query projected to match the target table's column names and types
   * @throws IllegalArgumentException when a static partition value is missing
   */
  private def alignOutputFields(
      query: DataFrame,
      table: CatalogTable,
      insertPartitions: Map[String, Option[String]],
      conf: SQLConf): DataFrame = {

    val targetPartitionSchema = table.partitionSchema
    // Partitions that carry an explicit value in the INSERT statement (static partitions).
    val staticPartitionValues = insertPartitions.filter(p => p._2.isDefined).mapValues(_.get)
    // Static partitions must be all-or-nothing: mixing static and dynamic is not supported.
    // (Fixed the malformed separator `"," + ""` in the original message.)
    assert(staticPartitionValues.isEmpty ||
      staticPartitionValues.size == targetPartitionSchema.size,
      s"Required partition columns is: ${targetPartitionSchema.json}, Current static partitions " +
        s"is: ${staticPartitionValues.mkString(",")}")

    val queryDataFields = if (staticPartitionValues.isEmpty) { // insert dynamic partition
      // The trailing query columns hold the dynamic partition values; drop them here,
      // they are re-projected below.
      query.logicalPlan.output.dropRight(targetPartitionSchema.fields.length)
    } else { // insert static partition
      query.logicalPlan.output
    }
    val targetDataSchema = table.dataSchema
    // Align the data fields of the query with the target table, casting where needed.
    val dataProjects = queryDataFields.zip(targetDataSchema.fields).map {
      case (dataAttr, targetField) =>
        val castAttr = castIfNeeded(dataAttr,
          targetField.dataType, conf)
        new Column(Alias(castAttr, targetField.name)())
    }

    val partitionProjects = if (staticPartitionValues.isEmpty) { // insert dynamic partitions
      // The partition attributes follow the data attributes in the query output,
      // so we init the partitionAttrPosition with the data schema size.
      var partitionAttrPosition = targetDataSchema.size
      targetPartitionSchema.fields.map(f => {
        val partitionAttr = query.logicalPlan.output(partitionAttrPosition)
        partitionAttrPosition = partitionAttrPosition + 1
        val castAttr = castIfNeeded(partitionAttr, f.dataType, conf)
        new Column(Alias(castAttr, f.name)())
      })
    } else { // insert static partitions
      targetPartitionSchema.fields.map(f => {
        // BUGFIX: the original used getOrElse(message), which silently used the
        // error-message string as the partition value. Fail fast instead.
        val staticPartitionValue = staticPartitionValues.getOrElse(f.name,
          throw new IllegalArgumentException(
            s"Missing static partition value for: ${f.name}"))
        val castAttr = Literal.create(staticPartitionValue, f.dataType)
        new Column(Alias(castAttr, f.name)())
      })
    }
    // Remove the hoodie meta fields from the projects as we do not need to write these.
    val withoutMetaFieldDataProjects =
      dataProjects.filter(c => !HoodieSqlUtils.isMetaField(c.named.name))
    val alignedProjects = withoutMetaFieldDataProjects ++ partitionProjects
    query.select(alignedProjects: _*)
  }

  /**
   * Build the default config for insert from the table's sql options.
   * @throws IllegalArgumentException when the statement's partition columns do not
   *                                  match the table definition, or when the table
   *                                  has no location
   * @return the hoodie writer parameter map
   */
  private def buildHoodieInsertConfig(
      table: CatalogTable,
      sparkSession: SparkSession,
      isOverwrite: Boolean,
      insertPartitions: Map[String, Option[String]] = Map.empty): Map[String, String] = {

    // The partition columns named in the statement must exactly match the table's.
    // (Fixed the original message which lacked a separator and a space after "fields".)
    if (insertPartitions.nonEmpty &&
      (insertPartitions.keys.toSet != table.partitionColumnNames.toSet)) {
      throw new IllegalArgumentException(s"Insert partition fields" +
        s" [${insertPartitions.keys.mkString(",")}]" +
        s" not equal to the defined partition in table[${table.partitionColumnNames.mkString(",")}]")
    }
    val parameters = HoodieOptionConfig.mappingSqlOptionToHoodieParam(table.storage.properties)

    val tableType = parameters.getOrElse(TABLE_TYPE_OPT_KEY, DEFAULT_TABLE_TYPE_OPT_VAL)

    val partitionFields = table.partitionColumnNames.mkString(",")
    // BUGFIX: the original used getOrElse(message), which silently used the
    // error-message string as the table path. Fail fast instead.
    val path = getTableLocation(table, sparkSession)
      .getOrElse(throw new IllegalArgumentException(
        s"Missing location for table ${table.identifier}"))

    val tableSchema = table.schema
    val options = table.storage.properties
    val primaryColumns = HoodieOptionConfig.getPrimaryColumns(options)

    // Tables with a primary key use the sql key generator; key-less tables get uuid keys.
    val keyGenClass = if (primaryColumns.nonEmpty) {
      classOf[SqlKeyGenerator].getCanonicalName
    } else {
      classOf[UuidKeyGenerator].getName
    }

    val dropDuplicate = sparkSession.conf
      .getOption(INSERT_DROP_DUPS_OPT_KEY)
      .getOrElse(DEFAULT_INSERT_DROP_DUPS_OPT_VAL)
      .toBoolean

    val operation = if (isOverwrite) {
      if (table.partitionColumnNames.nonEmpty) {
        INSERT_OVERWRITE_OPERATION_OPT_VAL // overwrite partition
      } else {
        INSERT_OPERATION_OPT_VAL
      }
    } else {
      // With a primary key (and no drop-duplicates), insert behaves as upsert.
      if (primaryColumns.nonEmpty && !dropDuplicate) {
        UPSERT_OPERATION_OPT_VAL
      } else {
        INSERT_OPERATION_OPT_VAL
      }
    }

    val payloadClassName = if (primaryColumns.nonEmpty && !dropDuplicate &&
      tableType == COW_TABLE_TYPE_OPT_VAL) {
      // Only validate duplicate key for COW, for MOR it will do the merge with the DefaultHoodieRecordPayload
      // on reading.
      classOf[ValidateDuplicateKeyPayload].getCanonicalName
    } else {
      classOf[DefaultHoodieRecordPayload].getCanonicalName
    }

    withSparkConf(sparkSession, options) {
      Map(
        "path" -> path,
        TABLE_TYPE_OPT_KEY -> tableType,
        TABLE_NAME -> table.identifier.table,
        // NOTE(review): assumes the last schema field is the preCombine field — confirm.
        PRECOMBINE_FIELD_OPT_KEY -> tableSchema.fields.last.name,
        OPERATION_OPT_KEY -> operation,
        KEYGENERATOR_CLASS_OPT_KEY -> keyGenClass,
        RECORDKEY_FIELD_OPT_KEY -> primaryColumns.mkString(","),
        PARTITIONPATH_FIELD_OPT_KEY -> partitionFields,
        PAYLOAD_CLASS_OPT_KEY -> payloadClassName,
        META_SYNC_ENABLED_OPT_KEY -> "true",
        HIVE_USE_JDBC_OPT_KEY -> "false",
        HIVE_DATABASE_OPT_KEY -> table.identifier.database.getOrElse("default"),
        HIVE_TABLE_OPT_KEY -> table.identifier.table,
        HIVE_SUPPORT_TIMESTAMP -> "true",
        HIVE_STYLE_PARTITIONING_OPT_KEY -> "true",
        HIVE_PARTITION_FIELDS_OPT_KEY -> partitionFields,
        HIVE_PARTITION_EXTRACTOR_CLASS_OPT_KEY -> classOf[MultiPartKeysValueExtractor].getCanonicalName,
        URL_ENCODE_PARTITIONING_OPT_KEY -> "true",
        HoodieWriteConfig.INSERT_PARALLELISM -> "200",
        HoodieWriteConfig.UPSERT_PARALLELISM -> "200",
        SqlKeyGenerator.PARTITION_SCHEMA -> table.partitionSchema.toDDL
      )
    }
  }
}
/**
 * Validate the duplicate key for insert statement without enable the INSERT_DROP_DUPS_OPT_KEY
 * config. If an incoming record meets an existing record with the same key, the write fails
 * with a HoodieDuplicateKeyException instead of merging the two records.
 */
class ValidateDuplicateKeyPayload(record: GenericRecord, orderingVal: Comparable[_])
  extends DefaultHoodieRecordPayload(record, orderingVal) {

  // Secondary constructor taking an optional record; an empty Option maps to a
  // null record. NOTE(review): presumably invoked reflectively by Hudi when
  // instantiating payloads — confirm against HoodieRecordPayload usage.
  def this(record: HOption[GenericRecord]) {
    this(if (record.isPresent) record.get else null, 0)
  }

  // Reaching this method means a record with the same key already exists in
  // storage, i.e. a duplicate key: fail the write reporting the offending key.
  override def combineAndGetUpdateValue(currentValue: IndexedRecord,
      schema: Schema, properties: Properties): HOption[IndexedRecord] = {
    val key = currentValue.asInstanceOf[GenericRecord].get(HoodieRecord.RECORD_KEY_METADATA_FIELD).toString
    throw new HoodieDuplicateKeyException(key)
  }
}

View File

@@ -0,0 +1,456 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import java.util.Base64
import org.apache.avro.Schema
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.config.HoodieWriteConfig.TABLE_NAME
import org.apache.hudi.hive.MultiPartKeysValueExtractor
import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, HoodieWriterUtils, SparkAdapterSupport}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Cast, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, MergeIntoTable, SubqueryAlias, UpdateAction}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.hudi.HoodieSqlUtils._
import org.apache.spark.sql.types.{BooleanType, StructType}
import org.apache.spark.sql._
import org.apache.spark.sql.hudi.{HoodieOptionConfig, SerDeUtils}
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload._
/**
* The Command for hoodie MergeIntoTable.
* The match on condition must contain the row key fields currently, so that we can use Hoodie
* Index to speed up the performance.
*
* The main algorithm:
*
* We pushed down all the matched and not matched (condition, assignment) expression pairs to the
* ExpressionPayload. And the matched (condition, assignment) expression pairs will execute in the
* ExpressionPayload#combineAndGetUpdateValue to compute the result record, while the not matched
* expression pairs will execute in the ExpressionPayload#getInsertValue.
*
 * For a MOR table, it is a little more complex than this. The matched record also goes through
 * the getInsertValue and is appended to the log. So the update actions & insert actions should
 * be processed in the same way. We push all the update actions & insert actions together to the
 * ExpressionPayload#getInsertValue.
*
*/
case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends RunnableCommand
with SparkAdapterSupport {
private var sparkSession: SparkSession = _
/**
 * The identifier of the target table, unwrapped from the (possibly nested)
 * SubqueryAlias the analyzer puts around the target relation.
 */
private lazy val targetTableIdentify: TableIdentifier = {
  val aliaId = mergeInto.targetTable match {
    // "merge into t as alias" wraps the relation in two alias layers.
    case SubqueryAlias(_, SubqueryAlias(tableId, _)) => tableId
    case SubqueryAlias(tableId, _) => tableId
    case plan => throw new IllegalArgumentException(s"Illegal plan $plan in target")
  }
  sparkAdapter.toTableIdentify(aliaId)
}
/**
 * The output attributes of the SOURCE table, without hoodie meta fields.
 * (The original comment wrongly said "target table schema".)
 * Mutable on purpose: buildSourceDF appends attributes for the derived
 * key / preCombine columns it adds to the source DataFrame.
 */
private var sourceDFOutput = mergeInto.sourceTable.output.filter(attr => !isMetaField(attr.name))
/**
 * The target table schema without hoodie meta fields.
 */
private lazy val targetTableSchemaWithoutMetaFields =
  removeMetaFields(mergeInto.targetTable.schema).fields

// Catalog metadata of the target table.
private lazy val targetTable =
  sparkSession.sessionState.catalog.getTableMetadata(targetTableIdentify)

// The target table's type (COW or MOR), read from its sql storage options.
private lazy val targetTableType =
  HoodieOptionConfig.getTableType(targetTable.storage.properties)
/**
 *
 * Return a map of target key to the source expression from the Merge-On Condition.
 * e.g. merge on t.id = s.s_id AND t.name = s.s_name, we return
 * Map("id" -> "s_id", "name" ->"s_name")
 * Each conjunct must be an equality with exactly one side being a target column;
 * anything else is rejected.
 * TODO Currently Non-equivalent conditions are not supported.
 */
private lazy val targetKey2SourceExpression: Map[String, Expression] = {
  // Split the ON condition into its AND-ed conjuncts.
  val conditions = splitByAnd(mergeInto.mergeCondition)
  val allEqs = conditions.forall(p => p.isInstanceOf[EqualTo])
  if (!allEqs) {
    throw new IllegalArgumentException("Non-Equal condition is not support for Merge " +
      s"Into Statement: ${mergeInto.mergeCondition.sql}")
  }
  val targetAttrs = mergeInto.targetTable.output

  val target2Source = conditions.map(_.asInstanceOf[EqualTo])
    .map {
      // Whichever side resolves to a target attribute is the key; the other
      // side is the source expression that computes it.
      case EqualTo(left: AttributeReference, right)
        if targetAttrs.indexOf(left) >= 0 => // left is the target field
        left.name -> right
      case EqualTo(left, right: AttributeReference)
        if targetAttrs.indexOf(right) >= 0 => // right is the target field
        right.name -> left
      case eq =>
        throw new AnalysisException(s"Invalidate Merge-On condition: ${eq.sql}." +
          "The validate condition should be 'targetColumn = sourceColumnExpression', e.g." +
          " t.id = s.id and t.dt = from_unixtime(s.ts)")
    }.toMap
  target2Source
}
/**
 * Get the mapping of target preCombineField to the source expression.
 * None when the table declares no preCombine field, or when no source
 * expression can be found for it.
 */
private lazy val target2SourcePreCombineFiled: Option[(String, Expression)] = {
  val updateActions = mergeInto.matchedActions.collect { case u: UpdateAction => u }
  assert(updateActions.size <= 1, s"Only support one updateAction, current is: ${updateActions.size}")

  val updateAction = updateActions.headOption
  HoodieOptionConfig.getPreCombineField(targetTable.storage.properties).map(preCombineField => {
    val sourcePreCombineField =
      // Prefer the expression the UPDATE SET clause assigns to the preCombine field.
      // NOTE(review): `.head` throws if the update action assigns every field except
      // the preCombine field — presumably assignments always cover all fields
      // (see checkUpdateAssignments) — confirm.
      updateAction.map(u => u.assignments.filter {
        case Assignment(key: AttributeReference, _) => key.name.equalsIgnoreCase(preCombineField)
        case _=> false
      }.head.value
      ).getOrElse {
        // If there is no update action, mapping the target column to the source by order.
        val target2Source = mergeInto.targetTable.output
          .filter(attr => !isMetaField(attr.name))
          .map(_.name)
          .zip(mergeInto.sourceTable.output.filter(attr => !isMetaField(attr.name)))
          .toMap
        target2Source.getOrElse(preCombineField, null)
      }
    (preCombineField, sourcePreCombineField)
  }).filter(p => p._2 != null)
}
/**
 * Entry point of the merge command: build the writer parameters and the
 * augmented source DataFrame, then run either an upsert (when there are
 * matched actions) or an insert-only write, and refresh the target table.
 */
override def run(sparkSession: SparkSession): Seq[Row] = {
  this.sparkSession = sparkSession

  // Create the write parameters
  val parameters = buildMergeIntoConfig(mergeInto)

  val sourceDF = buildSourceDF(sparkSession)

  if (mergeInto.matchedActions.nonEmpty) { // Do the upsert
    executeUpsert(sourceDF, parameters)
  } else { // If there is no match actions in the statement, execute insert operation only.
    executeInsertOnly(sourceDF, parameters)
  }
  sparkSession.catalog.refreshTable(targetTableIdentify.unquotedString)
  Seq.empty[Row]
}
/**
 * Build the sourceDF. We will append the source primary key expressions and
 * preCombine field expression to the sourceDF.
 * e.g.
 * <p>
 * merge into h0
 * using (select 1 as id, 'a1' as name, 1000 as ts) s0
 * on h0.id = s0.id + 1
 * when matched then update set id = s0.id, name = s0.name, ts = s0.ts + 1
 * </p>
 * "ts" is the pre-combine field of h0.
 *
 * The targetKey2SourceExpression is: ("id", "s0.id + 1").
 * The target2SourcePreCombineFiled is:("ts", "s0.ts + 1").
 * We will append the "s0.id + 1 as id" and "s0.ts + 1 as ts" to the sourceDF to compute the
 * row key and pre-combine field.
 *
 * Side effect: every appended column is also recorded in sourceDFOutput so that
 * later BoundReference rewriting sees the extended source schema.
 */
private def buildSourceDF(sparkSession: SparkSession): DataFrame = {
  var sourceDF = Dataset.ofRows(sparkSession, mergeInto.sourceTable)
  // Append one derived column per merge key, unless the source expression is
  // already just that column (no-op append would be redundant).
  targetKey2SourceExpression.foreach {
    case (targetColumn, sourceExpression)
      if !isEqualToTarget(targetColumn, sourceExpression) =>
        sourceDF = sourceDF.withColumn(targetColumn, new Column(sourceExpression))
        sourceDFOutput = sourceDFOutput :+ AttributeReference(targetColumn, sourceExpression.dataType)()
    case _=>
  }
  // Same for the preCombine field expression, when one exists.
  target2SourcePreCombineFiled.foreach {
    case (targetPreCombineField, sourceExpression)
      if !isEqualToTarget(targetPreCombineField, sourceExpression) =>
        sourceDF = sourceDF.withColumn(targetPreCombineField, new Column(sourceExpression))
        sourceDFOutput = sourceDFOutput :+ AttributeReference(targetPreCombineField, sourceExpression.dataType)()
    case _=>
  }
  sourceDF
}
/**
 * True when the source expression is simply the target column itself (possibly
 * wrapped in a cast), in which case appending a derived column is unnecessary.
 */
private def isEqualToTarget(targetColumnName: String, sourceExpression: Expression): Boolean =
  sourceExpression match {
    case attr: AttributeReference =>
      attr.name.equalsIgnoreCase(targetColumnName)
    case Cast(attr: AttributeReference, _, _) =>
      attr.name.equalsIgnoreCase(targetColumnName)
    case _ =>
      false
  }
/**
 * Execute the update and delete action. All the matched and not-matched actions will
 * execute in one upsert write operation. We pushed down the matched condition and assignment
 * expressions to the ExpressionPayload#combineAndGetUpdateValue and the not matched
 * expressions to the ExpressionPayload#getInsertValue.
 * The (condition, assignments) pairs are serialized and Base64-encoded into the
 * write parameters, to be deserialized inside ExpressionPayload on the executors.
 */
private def executeUpsert(sourceDF: DataFrame, parameters: Map[String, String]): Unit = {
  val updateActions = mergeInto.matchedActions.filter(_.isInstanceOf[UpdateAction])
    .map(_.asInstanceOf[UpdateAction])
  // Check for the update actions
  checkUpdateAssignments(updateActions)

  val deleteActions = mergeInto.matchedActions.filter(_.isInstanceOf[DeleteAction])
    .map(_.asInstanceOf[DeleteAction])
  assert(deleteActions.size <= 1, "Should be only one delete action in the merge into statement.")
  val deleteAction = deleteActions.headOption

  val insertActions = if (targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
    // For Mor table, the update record goes through the HoodieRecordPayload#getInsertValue
    // We append the update actions to the insert actions, so that we can execute the update
    // actions in the ExpressionPayload#getInsertValue.
    mergeInto.notMatchedActions.map(_.asInstanceOf[InsertAction]) ++
      updateActions.map(update => InsertAction(update.condition, update.assignments))
  } else {
    mergeInto.notMatchedActions.map(_.asInstanceOf[InsertAction])
  }
  // Check for the insert actions
  checkInsertAssignments(insertActions)

  // Append the table schema to the parameters. In the case of merge into, the schema of
  // sourceDF may differ from the target table's, because there can be transform logic in
  // the update or insert actions.
  var writeParams = parameters +
    (OPERATION_OPT_KEY -> UPSERT_OPERATION_OPT_VAL) +
    (HoodieWriteConfig.WRITE_SCHEMA_PROP -> getTableSchema.toString) +
    (DataSourceWriteOptions.TABLE_TYPE_OPT_KEY -> targetTableType)

  // Map of Condition -> Assignments, with attributes rewritten to BoundReferences.
  val updateConditionToAssignments =
    updateActions.map(update => {
      val rewriteCondition = update.condition.map(replaceAttributeInExpression)
        .getOrElse(Literal.create(true, BooleanType))
      val formatAssignments = rewriteAndReOrderAssignments(update.assignments)
      rewriteCondition -> formatAssignments
    }).toMap
  // Serialize the Map[UpdateCondition, UpdateAssignments] to base64 string
  val serializedUpdateConditionAndExpressions = Base64.getEncoder
    .encodeToString(SerDeUtils.toBytes(updateConditionToAssignments))
  writeParams += (PAYLOAD_UPDATE_CONDITION_AND_ASSIGNMENTS ->
    serializedUpdateConditionAndExpressions)

  if (deleteAction.isDefined) {
    val deleteCondition = deleteAction.get.condition
      .map(replaceAttributeInExpression)
      .getOrElse(Literal.create(true, BooleanType))
    // Serialize the Map[DeleteCondition, empty] to base64 string
    val serializedDeleteCondition = Base64.getEncoder
      .encodeToString(SerDeUtils.toBytes(Map(deleteCondition -> Seq.empty[Assignment])))
    writeParams += (PAYLOAD_DELETE_CONDITION -> serializedDeleteCondition)
  }

  // Serialize the Map[InsertCondition, InsertAssignments] to base64 string
  writeParams += (PAYLOAD_INSERT_CONDITION_AND_ASSIGNMENTS ->
    serializedInsertConditionAndExpressions(insertActions))

  // Remove the meta fields from the sourceDF as we do not need these when writing.
  val sourceDFWithoutMetaFields = removeMetaFields(sourceDF)
  HoodieSparkSqlWriter.write(sparkSession.sqlContext, SaveMode.Append, writeParams, sourceDFWithoutMetaFields)
}
/**
 * If there are no matched actions, we only execute the insert operation.
 * @param sourceDF the (augmented) source DataFrame to write
 * @param parameters the base hoodie writer parameters
 */
private def executeInsertOnly(sourceDF: DataFrame, parameters: Map[String, String]): Unit = {
  val insertActions = mergeInto.notMatchedActions.map(_.asInstanceOf[InsertAction])
  checkInsertAssignments(insertActions)

  var writeParams = parameters +
    (OPERATION_OPT_KEY -> INSERT_OPERATION_OPT_VAL) +
    (HoodieWriteConfig.WRITE_SCHEMA_PROP -> getTableSchema.toString)

  // Push the (condition, assignments) pairs down to ExpressionPayload#getInsertValue.
  writeParams += (PAYLOAD_INSERT_CONDITION_AND_ASSIGNMENTS ->
    serializedInsertConditionAndExpressions(insertActions))

  // Remove the meta fields from the sourceDF as we do not need these when writing.
  val sourceDFWithoutMetaFields = removeMetaFields(sourceDF)
  HoodieSparkSqlWriter.write(sparkSession.sqlContext, SaveMode.Append, writeParams, sourceDFWithoutMetaFields)
}
/** Ensure every update action assigns a value to every (non-meta) target table field. */
private def checkUpdateAssignments(updateActions: Seq[UpdateAction]): Unit = {
  val expectedFieldCount = targetTableSchemaWithoutMetaFields.length
  for (update <- updateActions) {
    val assignmentCount = update.assignments.length
    assert(assignmentCount == expectedFieldCount,
      s"The number of update assignments[$assignmentCount] must equal to the " +
        s"targetTable field size[$expectedFieldCount]")
  }
}
/** Ensure every insert action assigns a value to every (non-meta) target table field. */
private def checkInsertAssignments(insertActions: Seq[InsertAction]): Unit = {
  val expectedFieldCount = targetTableSchemaWithoutMetaFields.length
  for (insert <- insertActions) {
    val assignmentCount = insert.assignments.length
    assert(assignmentCount == expectedFieldCount,
      s"The number of insert assignments[$assignmentCount] must equal to the " +
        s"targetTable field size[$expectedFieldCount]")
  }
}
/** Build the Avro write schema from the target table schema, meta fields excluded. */
private def getTableSchema: Schema = {
  val (recordName, recordNamespace) =
    AvroConversionUtils.getAvroRecordNameAndNamespace(targetTableIdentify.identifier)
  val structWithoutMetaFields = new StructType(targetTableSchemaWithoutMetaFields)
  AvroConversionUtils.convertStructTypeToAvroSchema(
    structWithoutMetaFields, recordName, recordNamespace)
}
/**
 * Serialize the Map[InsertCondition, InsertAssignments] to base64 string.
 * A missing condition is treated as the literal `true` (unconditional insert).
 * @param insertActions the not-matched insert actions to serialize
 * @return the Base64-encoded serialized map, for shipping to ExpressionPayload
 */
private def serializedInsertConditionAndExpressions(insertActions: Seq[InsertAction]): String = {
  val insertConditionAndAssignments =
    insertActions.map(insert => {
      val rewriteCondition = insert.condition.map(replaceAttributeInExpression)
        .getOrElse(Literal.create(true, BooleanType))
      val formatAssignments = rewriteAndReOrderAssignments(insert.assignments)
      // Do the check for the insert assignments
      checkInsertExpression(formatAssignments)

      rewriteCondition -> formatAssignments
    }).toMap
  Base64.getEncoder.encodeToString(
    SerDeUtils.toBytes(insertConditionAndAssignments))
}
/**
 * Rewrite and ReOrder the assignments.
 * The Rewrite is to replace the AttributeReference to BoundReference.
 * The ReOrder is to make the assignments' order the same as the target table's fields.
 * @param assignments the raw Assignment expressions from the merge action
 * @return one rewritten expression per non-meta target field, in target field order,
 *         cast to that field's type when needed
 */
private def rewriteAndReOrderAssignments(assignments: Seq[Expression]): Seq[Expression] = {
  val attr2Assignment = assignments.map {
    case Assignment(attr: AttributeReference, value) => {
      val rewriteValue = replaceAttributeInExpression(value)
      attr -> Alias(rewriteValue, attr.name)()
    }
    case assignment => throw new IllegalArgumentException(s"Illegal Assignment: ${assignment.sql}")
  }.toMap[Attribute, Expression]
  // reorder the assignments by the target table field; every non-meta field
  // must have exactly one assignment.
  mergeInto.targetTable.output
    .filterNot(attr => isMetaField(attr.name))
    .map(attr => {
      val assignment = attr2Assignment.getOrElse(attr,
        throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
      castIfNeeded(assignment, attr.dataType, sparkSession.sqlContext.conf)
    })
}
/**
 * Replace the AttributeReference to BoundReference. This is for the convenience of CodeGen
 * in ExpressionCodeGen which use the field index to generate the code. So we must replace
 * the AttributeReference to BoundReference here.
 * Ordinals are resolved against the concatenation of the (augmented) source
 * output followed by the non-meta target attributes.
 * @param exp the expression to rewrite
 * @return the same expression with every attribute bound by ordinal
 */
private def replaceAttributeInExpression(exp: Expression): Expression = {
  // The joint field space the ordinals index into: source fields first, then target fields.
  val sourceJoinTargetFields = sourceDFOutput ++
    mergeInto.targetTable.output.filterNot(attr => isMetaField(attr.name))

  exp transform {
    case attr: AttributeReference =>
      val index = sourceJoinTargetFields.indexWhere(p => p.semanticEquals(attr))
      if (index == -1) {
          throw new IllegalArgumentException(s"cannot find ${attr.qualifiedName} in source or " +
            s"target at the merge into statement")
        }
        BoundReference(index, attr.dataType, attr.nullable)
    case other => other
  }
}
/**
 * Validate the insert action expressions: an insert clause may only reference
 * source-table fields. A BoundReference whose ordinal falls beyond the source
 * output points at a target-table field and is rejected.
 */
private def checkInsertExpression(expressions: Seq[Expression]): Unit = {
  for {
    exp <- expressions
    ref <- exp.collect { case r: BoundReference => r }
    if ref.ordinal >= sourceDFOutput.size
  } {
    // Ordinals past the source output index into the target-table schema.
    val targetColumn = targetTableSchemaWithoutMetaFields(ref.ordinal - sourceDFOutput.size)
    throw new IllegalArgumentException(s"Insert clause cannot contain target table field: $targetColumn" +
      s" in ${exp.sql}")
  }
}
/**
 * Create the config for hoodie writer.
 *
 * @param mergeInto the merge-into plan to build the writer config for
 * @return the writer options, with the hoodie write defaults filled in
 */
private def buildMergeIntoConfig(mergeInto: MergeIntoTable): Map[String, String] = {
  val targetTableDb = targetTableIdentify.database.getOrElse("default")
  val targetTableName = targetTableIdentify.identifier
  // Fail fast when the table has no location. The previous code passed the
  // error message itself to getOrElse, silently using it as the write path.
  val path = getTableLocation(targetTable, sparkSession)
    .getOrElse(throw new IllegalArgumentException(s"missing location for $targetTableIdentify"))
  val options = targetTable.storage.properties
  val definedPk = HoodieOptionConfig.getPrimaryColumns(options)
  // TODO Currently the mergeEqualConditionKeys must be the same the primary key.
  if (targetKey2SourceExpression.keySet != definedPk.toSet) {
    throw new IllegalArgumentException(s"Merge Key[${targetKey2SourceExpression.keySet.mkString(",")}] is not" +
      s" Equal to the defined primary key[${definedPk.mkString(",")}] in table $targetTableName")
  }
  HoodieWriterUtils.parametersWithWriteDefaults(
    withSparkConf(sparkSession, options) {
      Map(
        "path" -> path,
        RECORDKEY_FIELD_OPT_KEY -> targetKey2SourceExpression.keySet.mkString(","),
        KEYGENERATOR_CLASS_OPT_KEY -> classOf[SqlKeyGenerator].getCanonicalName,
        PRECOMBINE_FIELD_OPT_KEY -> targetKey2SourceExpression.keySet.head, // set a default preCombine field
        TABLE_NAME -> targetTableName,
        PARTITIONPATH_FIELD_OPT_KEY -> targetTable.partitionColumnNames.mkString(","),
        PAYLOAD_CLASS_OPT_KEY -> classOf[ExpressionPayload].getCanonicalName,
        META_SYNC_ENABLED_OPT_KEY -> "true",
        HIVE_USE_JDBC_OPT_KEY -> "false",
        HIVE_DATABASE_OPT_KEY -> targetTableDb,
        HIVE_TABLE_OPT_KEY -> targetTableName,
        HIVE_SUPPORT_TIMESTAMP -> "true",
        HIVE_STYLE_PARTITIONING_OPT_KEY -> "true",
        HIVE_PARTITION_FIELDS_OPT_KEY -> targetTable.partitionColumnNames.mkString(","),
        HIVE_PARTITION_EXTRACTOR_CLASS_OPT_KEY -> classOf[MultiPartKeysValueExtractor].getCanonicalName,
        URL_ENCODE_PARTITIONING_OPT_KEY -> "true", // enable the url decode for sql.
        HoodieWriteConfig.INSERT_PARALLELISM -> "200", // set the default parallelism to 200 for sql
        HoodieWriteConfig.UPSERT_PARALLELISM -> "200",
        HoodieWriteConfig.DELETE_PARALLELISM -> "200",
        SqlKeyGenerator.PARTITION_SCHEMA -> targetTable.partitionSchema.toDDL
      )
    })
}
}

View File

@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import java.util.concurrent.TimeUnit.{MICROSECONDS, MILLISECONDS}
import org.apache.avro.generic.GenericRecord
import org.apache.hudi.common.config.TypedProperties
import org.apache.hudi.common.util.PartitionPathEncodeUtils
import org.apache.hudi.keygen.{ComplexKeyGenerator, KeyGenUtils}
import org.apache.spark.sql.types.{StructType, TimestampType}
import org.joda.time.format.DateTimeFormat
/**
 * A complex key generator for the sql commands, which additionally renders
 * timestamp-typed partition fields as "yyyy-MM-dd HH:mm:ss" path segments.
 */
class SqlKeyGenerator(props: TypedProperties) extends ComplexKeyGenerator(props) {

  // The partition schema is handed over through the writer config as a Spark DDL
  // string, so that we can tell which partition fields are timestamp typed.
  private lazy val partitionSchema = {
    val schemaDDL = props.getString(SqlKeyGenerator.PARTITION_SCHEMA, "")
    if (schemaDDL != null && schemaDDL.nonEmpty) Some(StructType.fromDDL(schemaDDL)) else None
  }

  override def getPartitionPath(record: GenericRecord): String = {
    val basePath = super.getPartitionPath(record)
    partitionSchema match {
      case None => basePath
      case Some(schema) =>
        // Splitting on the separator is safe because URL_ENCODE_PARTITIONING_OPT_KEY
        // is enabled by default for sql, so field values never contain it.
        val fragments = basePath.split(KeyGenUtils.DEFAULT_PARTITION_PATH_SEPARATOR)
        assert(fragments.size == schema.size)
        fragments.zip(schema.fields).map { case (rawValue, field) =>
          val hiveStylePrefix = s"${field.name}="
          val isHiveStyle = rawValue.startsWith(hiveStylePrefix)
          val value = if (isHiveStyle) rawValue.substring(hiveStylePrefix.length) else rawValue
          field.dataType match {
            case TimestampType =>
              // The raw value is epoch microseconds; convert to millis and render
              // as an escaped "yyyy-MM-dd HH:mm:ss" segment.
              val timeMs = MILLISECONDS.convert(value.toLong, MICROSECONDS)
              val formatted = PartitionPathEncodeUtils.escapePathName(
                SqlKeyGenerator.timestampTimeFormat.print(timeMs))
              if (isHiveStyle) s"$hiveStylePrefix$formatted" else formatted
            case _ => rawValue
          }
        }.mkString(KeyGenUtils.DEFAULT_PARTITION_PATH_SEPARATOR)
    }
  }
}
object SqlKeyGenerator {
  // Config key carrying the partition schema (as a Spark DDL string) to the key generator.
  val PARTITION_SCHEMA = "hoodie.sql.partition.schema"
  // Formatter used to render timestamp-typed partition values in the partition path.
  private val timestampTimeFormat = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss")
}

View File

@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import org.apache.hudi.{DataSourceWriteOptions, SparkAdapterSupport}
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.common.model.HoodieRecord
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.config.HoodieWriteConfig.TABLE_NAME
import org.apache.hudi.hive.MultiPartKeysValueExtractor
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, SubqueryAlias, UpdateTable}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.hudi.HoodieOptionConfig
import org.apache.spark.sql.hudi.HoodieSqlUtils._
import org.apache.spark.sql.types.StructField
import scala.collection.JavaConverters._
/**
 * Command for executing an "UPDATE" statement on a hoodie table.
 * The update is translated to an upsert write of the (filtered, re-assigned) rows.
 */
case class UpdateHoodieTableCommand(updateTable: UpdateTable) extends RunnableCommand
  with SparkAdapterSupport {

  private val table = updateTable.table
  private val tableId = table match {
    case SubqueryAlias(name, _) => sparkAdapter.toTableIdentify(name)
    case _ => throw new IllegalArgumentException(s"Illegal table: $table")
  }

  override def run(sparkSession: SparkSession): Seq[Row] = {
    logInfo(s"start execute update command for $tableId")

    // Cast an update value to the target field's data type when needed.
    def cast(exp: Expression, field: StructField): Expression = {
      castIfNeeded(exp, field.dataType, sparkSession.sqlContext.conf)
    }
    val name2UpdateValue = updateTable.assignments.map {
      case Assignment(attr: AttributeReference, value) =>
        attr.name -> value
    }.toMap

    // For each output column take the assigned value when present, otherwise keep
    // the original column; drop the hoodie meta columns as they are re-generated
    // by the upsert write.
    val updateExpressions = table.output
      .map(attr => name2UpdateValue.getOrElse(attr.name, attr))
      .filter { // filter the meta columns
        case attr: AttributeReference =>
          !HoodieRecord.HOODIE_META_COLUMNS.asScala.toSet.contains(attr.name)
        case _ => true
      }

    val projects = updateExpressions.zip(removeMetaFields(table.schema).fields).map {
      case (attr: AttributeReference, field) =>
        Column(cast(attr, field))
      case (exp, field) =>
        Column(Alias(cast(exp, field), field.name)())
    }

    var df = Dataset.ofRows(sparkSession, table)
    if (updateTable.condition.isDefined) {
      // Only rows matching the update condition are rewritten.
      df = df.filter(Column(updateTable.condition.get))
    }
    df = df.select(projects: _*)
    val config = buildHoodieConfig(sparkSession)
    df.write
      .format("hudi")
      .mode(SaveMode.Append)
      .options(config)
      .save()
    sparkSession.catalog.refreshTable(tableId.unquotedString)
    logInfo(s"Finish execute update command for $tableId")
    Seq.empty[Row]
  }

  /**
   * Build the hoodie writer config for the upsert triggered by this update.
   */
  private def buildHoodieConfig(sparkSession: SparkSession): Map[String, String] = {
    val targetTable = sparkSession.sessionState.catalog
      .getTableMetadata(tableId)
    // Fail fast when the table has no location. The previous code passed the
    // error message itself to getOrElse, silently using it as the write path.
    val path = getTableLocation(targetTable, sparkSession)
      .getOrElse(throw new IllegalArgumentException(s"missing location for $tableId"))
    val primaryColumns = HoodieOptionConfig.getPrimaryColumns(targetTable.storage.properties)
    assert(primaryColumns.nonEmpty,
      s"There are no primary key in table $tableId, cannot execute update operator")
    withSparkConf(sparkSession, targetTable.storage.properties) {
      Map(
        "path" -> path.toString,
        RECORDKEY_FIELD_OPT_KEY -> primaryColumns.mkString(","),
        KEYGENERATOR_CLASS_OPT_KEY -> classOf[SqlKeyGenerator].getCanonicalName,
        PRECOMBINE_FIELD_OPT_KEY -> primaryColumns.head, // set the default preCombine field.
        TABLE_NAME -> tableId.table,
        OPERATION_OPT_KEY -> DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL,
        PARTITIONPATH_FIELD_OPT_KEY -> targetTable.partitionColumnNames.mkString(","),
        META_SYNC_ENABLED_OPT_KEY -> "false", // TODO make the meta sync enable by default.
        HIVE_USE_JDBC_OPT_KEY -> "false",
        HIVE_DATABASE_OPT_KEY -> tableId.database.getOrElse("default"),
        HIVE_TABLE_OPT_KEY -> tableId.table,
        HIVE_PARTITION_FIELDS_OPT_KEY -> targetTable.partitionColumnNames.mkString(","),
        HIVE_PARTITION_EXTRACTOR_CLASS_OPT_KEY -> classOf[MultiPartKeysValueExtractor].getCanonicalName,
        URL_ENCODE_PARTITIONING_OPT_KEY -> "true",
        HIVE_SUPPORT_TIMESTAMP -> "true",
        HIVE_STYLE_PARTITIONING_OPT_KEY -> "true",
        HoodieWriteConfig.UPSERT_PARALLELISM -> "200",
        SqlKeyGenerator.PARTITION_SCHEMA -> targetTable.partitionSchema.toDDL
      )
    }
  }
}

View File

@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command
import java.util.UUID
import org.apache.avro.generic.GenericRecord
import org.apache.hudi.common.config.TypedProperties
/**
 * A KeyGenerator which use the uuid as the record key.
 * Used for insert-only tables with no declared primary key: every record gets
 * a fresh random UUID, so records are never deduplicated by key.
 */
class UuidKeyGenerator(props: TypedProperties) extends SqlKeyGenerator(props) {
  // Note: a new UUID is generated on every call, so the key is not stable
  // across invocations for the same record.
  override def getRecordKey(record: GenericRecord): String = UUID.randomUUID.toString
}

View File

@@ -0,0 +1,185 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.payload
import java.util.UUID
import org.apache.avro.generic.IndexedRecord
import org.apache.hudi.sql.IExpressionEvaluator
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.hudi.command.payload.ExpressionCodeGen.RECORD_NAME
import org.apache.spark.sql.types.{DataType, Decimal}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.ParentClassLoader
import org.apache.spark.{TaskContext, TaskKilledException}
import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.{ClassBodyEvaluator, InternalCompilerException}
/**
 * Do CodeGen for expression based on IndexedRecord.
 * The mainly difference with the spark's CodeGen for expression is that
 * the expression's input is a IndexedRecord but not a Row.
 *
 */
object ExpressionCodeGen extends Logging {
  // Name of the IndexedRecord parameter in the generated eval method; the code
  // produced by RecordBoundReference reads fields from this variable.
  val RECORD_NAME = "record"

  /**
   * CodeGen for expressions.
   * @param exprs The expression list to CodeGen.
   * @return An IExpressionEvaluator generate by CodeGen which take a IndexedRecord as input
   *         param and return a Array of results for each expression.
   */
  def doCodeGen(exprs: Seq[Expression]): IExpressionEvaluator = {
    val ctx = new CodegenContext()
    // Set the input_row to null as we do not use row as the input object but Record.
    ctx.INPUT_ROW = null
    // Swap every BoundReference for a RecordBoundReference so the generated code
    // reads fields via IndexedRecord.get(ordinal) instead of row accessors.
    val replacedExprs = exprs.map(replaceBoundReference)
    val resultVars = replacedExprs.map(_.genCode(ctx))
    // Unique class name per call: each compiled evaluator gets its own class.
    val className = s"ExpressionPayloadEvaluator_${UUID.randomUUID().toString.replace("-", "_")}"
    // Java class body: a constructor taking (references, code) plus an eval method
    // that evaluates each expression and boxes the results into an Object[].
    val codeBody =
      s"""
         |private Object[] references;
         |private String code;
         |
         |public $className(Object references, String code) {
         | this.references = (Object[])references;
         | this.code = code;
         |}
         |
         |public Object[] eval(IndexedRecord $RECORD_NAME) {
         |  ${resultVars.map(_.code).mkString("\n")}
         |  Object[] results = new Object[${resultVars.length}];
         |  ${
              (for (i <- resultVars.indices) yield {
                s"""
                   |if (${resultVars(i).isNull}) {
                   |  results[$i] = null;
                   |} else {
                   |  results[$i] = ${resultVars(i).value.code};
                   |}
                 """.stripMargin
              }).mkString("\n")
            }
            return results;
         | }
         |
         |public String getCode() {
         |  return code;
         |}
       """.stripMargin

    val evaluator = new ClassBodyEvaluator()
    // Resolve classes through the current context class loader (falling back to
    // this class's loader) so Spark-shipped classes are visible to Janino.
    val parentClassLoader = new ParentClassLoader(
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    evaluator.setParentClassLoader(parentClassLoader)
    // Cannot be under package codegen, or fail with java.lang.InstantiationException
    evaluator.setClassName(s"org.apache.hudi.sql.payload.$className")
    // Default imports for the types the generated Java code may reference.
    evaluator.setDefaultImports(
      classOf[Platform].getName,
      classOf[InternalRow].getName,
      classOf[UnsafeRow].getName,
      classOf[UTF8String].getName,
      classOf[Decimal].getName,
      classOf[CalendarInterval].getName,
      classOf[ArrayData].getName,
      classOf[UnsafeArrayData].getName,
      classOf[MapData].getName,
      classOf[UnsafeMapData].getName,
      classOf[Expression].getName,
      classOf[TaskContext].getName,
      classOf[TaskKilledException].getName,
      classOf[InputMetrics].getName,
      classOf[IndexedRecord].getName
    )
    evaluator.setImplementedInterfaces(Array(classOf[IExpressionEvaluator]))

    try {
      evaluator.cook(codeBody)
    } catch {
      case e: InternalCompilerException =>
        val msg = s"failed to compile: $e"
        logError(msg, e)
        throw new InternalCompilerException(msg, e)
      case e: CompileException =>
        val msg = s"failed to compile: $e"
        logError(msg, e)
        throw new CompileException(msg, e.getLocation)
    }
    // Instantiate the compiled class with the codegen references and a debug
    // string (expression SQL + generated body) exposed through getCode.
    val referenceArray = ctx.references.toArray.map(_.asInstanceOf[Object])
    val expressionSql = exprs.map(_.sql).mkString(" ")
    evaluator.getClazz.getConstructor(classOf[Object], classOf[String])
      .newInstance(referenceArray, s"Expressions is: [$expressionSql]\nCodeBody is: {\n$codeBody\n}")
      .asInstanceOf[IExpressionEvaluator]
  }

  /**
   * Replace the BoundReference to the Record implement which will override the
   * doGenCode method.
   */
  private def replaceBoundReference(expression: Expression): Expression = {
    expression transformDown {
      case BoundReference(ordinal, dataType, nullable) =>
        RecordBoundReference(ordinal, dataType, nullable)
      case other =>
        other
    }
  }
}
/**
 * A BoundReference variant whose generated code reads its field from an
 * IndexedRecord (by ordinal) instead of from a Spark InternalRow.
 *
 * @param ordinal  index of the field in the record
 * @param dataType Spark SQL data type of the field
 * @param nullable whether the field may be null
 */
case class RecordBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  /**
   * Do the CodeGen for RecordBoundReference.
   * Use "IndexedRecord" as the input object but not a "Row"
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = JavaCode.javaType(dataType)
    val boxType = JavaCode.boxedType(dataType)
    // Read the field directly from the record variable emitted by ExpressionCodeGen.
    val value = s"($boxType)$RECORD_NAME.get($ordinal)"
    if (nullable) {
      ev.copy(code =
        code"""
           | boolean ${ev.isNull} = $RECORD_NAME.get($ordinal) == null;
           | $javaType ${ev.value} = ${ev.isNull} ?
           |  ${CodeGenerator.defaultValue(dataType)} : ($value);
         """
      )
    } else {
      ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
    }
  }

  // This expression only exists for code generation; interpreted eval is unsupported.
  override def eval(input: InternalRow): Any = {
    throw new IllegalArgumentException(s"Should not call eval method for " +
      s"${getClass.getCanonicalName}")
  }
}

View File

@@ -0,0 +1,309 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.payload
import java.util.{Base64, Properties}
import java.util.concurrent.Callable
import scala.collection.JavaConverters._
import com.google.common.cache.CacheBuilder
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.avro.HoodieAvroUtils
import org.apache.hudi.avro.HoodieAvroUtils.bytesToAvro
import org.apache.hudi.common.model.{DefaultHoodieRecordPayload, HoodieRecord}
import org.apache.hudi.common.util.{ValidationUtils, Option => HOption}
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.io.HoodieWriteHandle
import org.apache.hudi.sql.IExpressionEvaluator
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.hudi.SerDeUtils
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
import org.apache.spark.unsafe.types.UTF8String
import scala.collection.mutable.ArrayBuffer
/**
 * A HoodieRecordPayload for MergeIntoHoodieTableCommand.
 * It will execute the condition and assignments expression in the
 * match and not-match actions and compute the final record to write.
 *
 * If there is no condition match the record, ExpressionPayload will return
 * a HoodieWriteHandle.IGNORE_RECORD, and the write handles will ignore this record.
 */
class ExpressionPayload(record: GenericRecord,
                        orderingVal: Comparable[_])
  extends DefaultHoodieRecordPayload(record, orderingVal) {

  def this(recordOpt: HOption[GenericRecord]) {
    this(recordOpt.orElse(null), 0)
  }

  /**
   * The schema of this table.
   */
  private var writeSchema: Schema = _

  // The properties-less overloads are unsupported: this payload needs the writer
  // properties (which carry the serialized merge expressions) to do its work.
  override def combineAndGetUpdateValue(currentValue: IndexedRecord,
                                        schema: Schema): HOption[IndexedRecord] = {
    throw new IllegalStateException(s"Should not call this method for ${getClass.getCanonicalName}")
  }

  override def getInsertValue(schema: Schema): HOption[IndexedRecord] = {
    throw new IllegalStateException(s"Should not call this method for ${getClass.getCanonicalName}")
  }

  /**
   * Evaluate the MATCHED clauses (update first, then delete) against the joined
   * source/target record and return the record to persist, an empty Option for a
   * delete, or HoodieWriteHandle.IGNORE_RECORD when no clause condition matches.
   */
  override def combineAndGetUpdateValue(targetRecord: IndexedRecord,
                                        schema: Schema, properties: Properties): HOption[IndexedRecord] = {
    val sourceRecord = bytesToAvro(recordBytes, schema)
    val joinSqlRecord = new SqlTypedRecord(joinRecord(sourceRecord, targetRecord))

    // Process update
    val updateConditionAndAssignmentsText =
      properties.get(ExpressionPayload.PAYLOAD_UPDATE_CONDITION_AND_ASSIGNMENTS)
    assert(updateConditionAndAssignmentsText != null,
      s"${ExpressionPayload.PAYLOAD_UPDATE_CONDITION_AND_ASSIGNMENTS} have not set")

    var resultRecordOpt: HOption[IndexedRecord] = null

    // Get the Evaluator for each condition and update assignments.
    val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString)
    for ((conditionEvaluator, assignmentEvaluator) <- updateConditionAndAssignments
         if resultRecordOpt == null) {
      val conditionVal = evaluate(conditionEvaluator, joinSqlRecord).head.asInstanceOf[Boolean]
      // If the update condition matched then execute assignment expression
      // to compute final record to update. We will return the first matched record.
      if (conditionVal) {
        val results = evaluate(assignmentEvaluator, joinSqlRecord)
        initWriteSchemaIfNeed(properties)
        val resultRecord = convertToRecord(results, writeSchema)
        if (needUpdatingPersistedRecord(targetRecord, resultRecord, properties)) {
          resultRecordOpt = HOption.of(resultRecord)
        } else {
          // if the PreCombine field value of targetRecord is greater
          // than the new incoming record, just keep the old record value.
          resultRecordOpt = HOption.of(targetRecord)
        }
      }
    }
    if (resultRecordOpt == null) {
      // Process delete
      val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
      if (deleteConditionText != null) {
        val deleteCondition = getEvaluator(deleteConditionText.toString).head._1
        val deleteConditionVal = evaluate(deleteCondition, joinSqlRecord).head.asInstanceOf[Boolean]
        if (deleteConditionVal) {
          resultRecordOpt = HOption.empty()
        }
      }
    }
    if (resultRecordOpt == null) {
      // If there is no condition matched, just filter this record.
      // here we return a IGNORE_RECORD, HoodieMergeHandle will not handle it.
      HOption.of(HoodieWriteHandle.IGNORE_RECORD)
    } else {
      resultRecordOpt
    }
  }

  /**
   * Evaluate the NOT MATCHED insert clauses (and, for MOR tables, the delete
   * clause) against the incoming record and return the record to write, an empty
   * Option for a delete, or HoodieWriteHandle.IGNORE_RECORD when nothing matches.
   */
  override def getInsertValue(schema: Schema, properties: Properties): HOption[IndexedRecord] = {
    val incomingRecord = bytesToAvro(recordBytes, schema)
    if (isDeleteRecord(incomingRecord)) {
      HOption.empty[IndexedRecord]()
    } else {
      val insertConditionAndAssignmentsText =
        properties.get(ExpressionPayload.PAYLOAD_INSERT_CONDITION_AND_ASSIGNMENTS)
      // Process insert
      val sqlTypedRecord = new SqlTypedRecord(incomingRecord)
      // Get the evaluator for each condition and insert assignment.
      val insertConditionAndAssignments =
        ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText.toString)
      var resultRecordOpt: HOption[IndexedRecord] = null
      for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
           if resultRecordOpt == null) {
        val conditionVal = evaluate(conditionEvaluator, sqlTypedRecord).head.asInstanceOf[Boolean]
        // If matched the insert condition then execute the assignment expressions to compute the
        // result record. We will return the first matched record.
        if (conditionVal) {
          val results = evaluate(assignmentEvaluator, sqlTypedRecord)
          initWriteSchemaIfNeed(properties)
          resultRecordOpt = HOption.of(convertToRecord(results, writeSchema))
        }
      }
      // Process delete for MOR
      if (resultRecordOpt == null && isMORTable(properties)) {
        val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
        if (deleteConditionText != null) {
          val deleteCondition = getEvaluator(deleteConditionText.toString).head._1
          val deleteConditionVal = evaluate(deleteCondition, sqlTypedRecord).head.asInstanceOf[Boolean]
          if (deleteConditionVal) {
            resultRecordOpt = HOption.empty()
          }
        }
      }
      if (resultRecordOpt != null) {
        resultRecordOpt
      } else {
        // If there is no condition matched, just filter this record.
        // Here we return a IGNORE_RECORD, HoodieCreateHandle will not handle it.
        HOption.of(HoodieWriteHandle.IGNORE_RECORD)
      }
    }
  }

  // Whether the write targets a merge-on-read table, read from the writer properties.
  private def isMORTable(properties: Properties): Boolean = {
    properties.getProperty(TABLE_TYPE_OPT_KEY, null) == MOR_TABLE_TYPE_OPT_VAL
  }

  // Pack the evaluator results (one value per schema field, in field order)
  // into an avro record of the given schema.
  private def convertToRecord(values: Array[AnyRef], schema: Schema): IndexedRecord = {
    assert(values.length == schema.getFields.size())
    val writeRecord = new GenericData.Record(schema)
    for (i <- values.indices) {
      writeRecord.put(i, values(i))
    }
    writeRecord
  }

  /**
   * Init the table schema.
   */
  private def initWriteSchemaIfNeed(properties: Properties): Unit = {
    if (writeSchema == null) {
      ValidationUtils.checkArgument(properties.containsKey(HoodieWriteConfig.WRITE_SCHEMA_PROP),
        s"Missing ${HoodieWriteConfig.WRITE_SCHEMA_PROP}")
      writeSchema = new Schema.Parser().parse(properties.getProperty(HoodieWriteConfig.WRITE_SCHEMA_PROP))
    }
  }

  /**
   * Join the source record with the target record.
   *
   * @return
   */
  private def joinRecord(sourceRecord: IndexedRecord, targetRecord: IndexedRecord): IndexedRecord = {
    val leftSchema = sourceRecord.getSchema
    // the targetRecord is load from the disk, it contains the meta fields, so we remove it here
    val rightSchema = HoodieAvroUtils.removeMetadataFields(targetRecord.getSchema)
    val joinSchema = mergeSchema(leftSchema, rightSchema)

    val values = new ArrayBuffer[AnyRef]()
    for (i <- 0 until joinSchema.getFields.size()) {
      val value = if (i < leftSchema.getFields.size()) {
        sourceRecord.get(i)
      } else { // skip meta field
        targetRecord.get(i - leftSchema.getFields.size() + HoodieRecord.HOODIE_META_COLUMNS.size())
      }
      values += value
    }
    convertToRecord(values.toArray, joinSchema)
  }

  // Merge the two schemas into one, prefixing fields with "a_"/"b_" to avoid
  // name collisions between source and target fields.
  private def mergeSchema(a: Schema, b: Schema): Schema = {
    val mergedFields =
      a.getFields.asScala.map(field =>
        new Schema.Field("a_" + field.name,
          field.schema, field.doc, field.defaultVal, field.order)) ++
        b.getFields.asScala.map(field =>
          new Schema.Field("b_" + field.name,
            field.schema, field.doc, field.defaultVal, field.order))
    Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava)
  }

  // Run a compiled evaluator, wrapping failures with the generated code for debugging.
  private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): Array[Object] = {
    try evaluator.eval(sqlTypedRecord) catch {
      case e: Throwable =>
        throw new RuntimeException(s"Error in execute expression: ${e.getMessage}.\n${evaluator.getCode}", e)
    }
  }
}
object ExpressionPayload {

  /**
   * Property for pass the merge-into delete clause condition expression.
   */
  val PAYLOAD_DELETE_CONDITION = "hoodie.payload.delete.condition"

  /**
   * Property for pass the merge-into update clauses's condition and assignments.
   */
  val PAYLOAD_UPDATE_CONDITION_AND_ASSIGNMENTS = "hoodie.payload.update.condition.assignments"

  /**
   * Property for pass the merge-into insert clauses's condition and assignments.
   */
  val PAYLOAD_INSERT_CONDITION_AND_ASSIGNMENTS = "hoodie.payload.insert.condition.assignments"

  /**
   * A cache for the serializedConditionAssignments to the compiled class after CodeGen.
   * The Map[IExpressionEvaluator, IExpressionEvaluator] is the map of the condition expression
   * to the assignments expression.
   */
  private val cache = CacheBuilder.newBuilder()
    .maximumSize(1024)
    .build[String, Map[IExpressionEvaluator, IExpressionEvaluator]]()

  /**
   * Do the CodeGen for each condition and assignment expressions.We will cache it to reduce
   * the compile time for each method call.
   * @param serializedConditionAssignments the Base64 string produced by the driver,
   *        containing the serialized condition-to-assignments map
   * @return map of compiled condition evaluator to compiled assignments evaluator
   */
  def getEvaluator(
    serializedConditionAssignments: String): Map[IExpressionEvaluator, IExpressionEvaluator] = {
    cache.get(serializedConditionAssignments,
      new Callable[Map[IExpressionEvaluator, IExpressionEvaluator]] {

        override def call(): Map[IExpressionEvaluator, IExpressionEvaluator] = {
          // Decode and deserialize the expression map shipped from the driver.
          val serializedBytes = Base64.getDecoder.decode(serializedConditionAssignments)
          val conditionAssignments = SerDeUtils.toObject(serializedBytes)
            .asInstanceOf[Map[Expression, Seq[Assignment]]]
          // Do the CodeGen for condition expression and assignment expression
          conditionAssignments.map {
            case (condition, assignments) =>
              val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition))
              val assignmentEvaluator = StringConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments))
              conditionEvaluator -> assignmentEvaluator
          }
        }
      })
  }

  /**
   * As the "baseEvaluator" return "UTF8String" for the string type which cannot be process by
   * the Avro, The StringConvertEvaluator will convert the "UTF8String" to "String".
   */
  case class StringConvertEvaluator(baseEvaluator: IExpressionEvaluator) extends IExpressionEvaluator {
    /**
     * Convert the UTF8String to String
     */
    override def eval(record: IndexedRecord): Array[AnyRef] = {
      baseEvaluator.eval(record).map {
        case s: UTF8String => s.toString
        case o => o
      }
    }

    override def getCode: String = baseEvaluator.getCode
  }
}

View File

@@ -0,0 +1,177 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.command.payload
import java.math.BigDecimal
import java.nio.ByteBuffer
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._
import org.apache.avro.generic.{GenericFixed, IndexedRecord}
import org.apache.avro.util.Utf8
import org.apache.avro.{LogicalTypes, Schema}
import org.apache.spark.sql.avro.{IncompatibleSchemaException, SchemaConverters}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import scala.collection.JavaConverters._
/**
* A sql typed record which will convert the avro field to sql typed value.
* This is referred to the org.apache.spark.sql.avro.AvroDeserializer#newWriter in spark project.
* @param record
*/
class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord {

  private lazy val decimalConversions = new DecimalConversion()
  // Catalyst struct type derived from the avro schema; gives the target sql
  // type for each field position.
  private lazy val sqlType = SchemaConverters.toSqlType(getSchema).dataType.asInstanceOf[StructType]

  override def put(i: Int, v: Any): Unit = {
    record.put(i, v)
  }

  /** Returns field i converted from its avro form to the catalyst internal value. */
  override def get(i: Int): AnyRef = {
    val value = record.get(i)
    val avroFieldType = getSchema.getFields.get(i).schema()
    val sqlFieldType = sqlType.fields(i).dataType
    convert(avroFieldType, sqlFieldType, value)
  }

  /**
   * Convert an avro value to the catalyst internal representation of `sqlFieldType`.
   * Mirrors org.apache.spark.sql.avro.AvroDeserializer#newWriter.
   *
   * @param avroFieldType avro schema of the field.
   * @param sqlFieldType  catalyst type the value should be converted to.
   * @param value         the raw avro value (may be null for nullable fields).
   */
  private def convert(avroFieldType: Schema, sqlFieldType: DataType, value: AnyRef): AnyRef = {
    if (value == null) {
      // A null can legally appear for any nullable (union) field; short-circuit
      // here so the per-type casts and matches below never see a null, which
      // previously could end in a MatchError (e.g. STRING) or NPE (primitives).
      null
    } else {
      (avroFieldType.getType, sqlFieldType) match {
        case (NULL, NullType) => null
        case (BOOLEAN, BooleanType) => value.asInstanceOf[Boolean].asInstanceOf[java.lang.Boolean]
        case (INT, IntegerType) => value.asInstanceOf[Int].asInstanceOf[java.lang.Integer]
        case (INT, DateType) => value.asInstanceOf[Int].asInstanceOf[java.lang.Integer]
        case (LONG, LongType) => value.asInstanceOf[Long].asInstanceOf[java.lang.Long]
        case (LONG, TimestampType) => avroFieldType.getLogicalType match {
          // Catalyst timestamps are epoch micros; scale millis up by 1000.
          case _: TimestampMillis => (value.asInstanceOf[Long] * 1000).asInstanceOf[java.lang.Long]
          case _: TimestampMicros => value.asInstanceOf[Long].asInstanceOf[java.lang.Long]
          case null =>
            // For backward compatibility, if the Avro type is Long and it is not logical type,
            // the value is processed as timestamp type with millisecond precision.
            java.lang.Long.valueOf(value.asInstanceOf[Long] * 1000)
          case other => throw new IncompatibleSchemaException(
            s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.")
        }
        // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date.
        // For backward compatibility, we still keep this conversion.
        case (LONG, DateType) =>
          java.lang.Integer.valueOf((value.asInstanceOf[Long] / SqlTypedRecord.MILLIS_PER_DAY).toInt)
        case (FLOAT, FloatType) => value.asInstanceOf[Float].asInstanceOf[java.lang.Float]
        case (DOUBLE, DoubleType) => value.asInstanceOf[Double].asInstanceOf[java.lang.Double]
        case (STRING, StringType) => value match {
          case s: String => UTF8String.fromString(s)
          case s: Utf8 => UTF8String.fromString(s.toString)
        }
        // FIX: catalyst represents StringType values as UTF8String. Returning the
        // enum symbol as a plain java String (the previous behavior) is
        // inconsistent with the STRING branch above and breaks downstream
        // expressions that expect UTF8String for StringType fields.
        case (ENUM, StringType) => UTF8String.fromString(value.toString)
        case (FIXED, BinaryType) => value.asInstanceOf[GenericFixed].bytes().clone()
        case (BYTES, BinaryType) => value match {
          case b: ByteBuffer =>
            val bytes = new Array[Byte](b.remaining)
            b.get(bytes)
            bytes
          case b: Array[Byte] => b
          case other => throw new RuntimeException(s"$other is not a valid avro binary.")
        }
        case (FIXED, d: DecimalType) =>
          val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroFieldType,
            LogicalTypes.decimal(d.precision, d.scale))
          createDecimal(bigDecimal, d.precision, d.scale)
        case (BYTES, d: DecimalType) =>
          val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroFieldType,
            LogicalTypes.decimal(d.precision, d.scale))
          createDecimal(bigDecimal, d.precision, d.scale)
        case (RECORD, _: StructType) =>
          throw new IllegalArgumentException(s"UnSupport StructType yet")
        case (ARRAY, ArrayType(_, _)) =>
          throw new IllegalArgumentException(s"UnSupport ARRAY type yet")
        case (MAP, MapType(keyType, _, _)) if keyType == StringType =>
          throw new IllegalArgumentException(s"UnSupport MAP type yet")
        case (UNION, _) =>
          // Unwrap the union: a single non-null member converts directly; the
          // common [int,long] and [float,double] unions are widened.
          val allTypes = avroFieldType.getTypes.asScala
          val nonNullTypes = allTypes.filter(_.getType != NULL)
          if (nonNullTypes.nonEmpty) {
            if (nonNullTypes.length == 1) {
              convert(nonNullTypes.head, sqlFieldType, value)
            } else {
              nonNullTypes.map(_.getType) match {
                case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlFieldType == LongType =>
                  value match {
                    case null => null
                    case l: java.lang.Long => l
                    case i: java.lang.Integer => i.longValue().asInstanceOf[java.lang.Long]
                  }
                case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlFieldType == DoubleType =>
                  value match {
                    case null => null
                    case d: java.lang.Double => d
                    case f: java.lang.Float => f.doubleValue().asInstanceOf[java.lang.Double]
                  }
                case _ =>
                  throw new IllegalArgumentException(s"UnSupport UNION type: ${sqlFieldType}")
              }
            }
          } else {
            null
          }
        case _ =>
          throw new IncompatibleSchemaException(
            s"Cannot convert Avro to catalyst because schema " +
              s"is not compatible (avroType = $avroFieldType, sqlType = $sqlFieldType).\n")
      }
    }
  }

  /** Build a catalyst Decimal, using the compact unscaled-long form when precision allows. */
  private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
    if (precision <= Decimal.MAX_LONG_DIGITS) {
      // Constructs a `Decimal` with an unscaled `Long` value if possible.
      Decimal(decimal.unscaledValue().longValue(), precision, scale)
    } else {
      // Otherwise, resorts to an unscaled `BigInteger` instead.
      Decimal(decimal, precision, scale)
    }
  }

  override def getSchema: Schema = record.getSchema
}
object SqlTypedRecord {
  /** Number of milliseconds in one day, used to reduce epoch-millis longs to epoch-day ints. */
  val MILLIS_PER_DAY: Long = 24L * 60 * 60 * 1000
}

View File

@@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets
import java.util.Date
import org.apache.hadoop.fs.Path
import org.apache.hudi.{DataSourceReadOptions, HoodieSparkUtils, IncrementalRelation, MergeOnReadIncrementalRelation}
import org.apache.hudi.{DataSourceReadOptions, IncrementalRelation, MergeOnReadIncrementalRelation, SparkAdapterSupport}
import org.apache.hudi.common.model.HoodieTableType
import org.apache.hudi.common.table.timeline.HoodieActiveTimeline
import org.apache.hudi.common.table.{HoodieTableMetaClient, TableSchemaResolver}
@@ -51,7 +51,7 @@ class HoodieStreamSource(
metadataPath: String,
schemaOption: Option[StructType],
parameters: Map[String, String])
extends Source with Logging with Serializable {
extends Source with Logging with Serializable with SparkAdapterSupport {
@transient private val hadoopConf = sqlContext.sparkSession.sessionState.newHadoopConf()
private lazy val tablePath: Path = {
@@ -160,7 +160,7 @@ class HoodieStreamSource(
val rdd = tableType match {
case HoodieTableType.COPY_ON_WRITE =>
val serDe = HoodieSparkUtils.createRowSerDe(RowEncoder(schema))
val serDe = sparkAdapter.createSparkRowSerDe(RowEncoder(schema))
new IncrementalRelation(sqlContext, incParams, schema, metaClient)
.buildScan()
.map(serDe.serializeRow)

View File

@@ -57,6 +57,7 @@ import org.apache.hudi.hadoop.HoodieParquetInputFormat;
import org.apache.hudi.hadoop.realtime.HoodieParquetRealtimeInputFormat;
import org.apache.hudi.index.HoodieIndex.IndexType;
import org.apache.hudi.keygen.NonpartitionedKeyGenerator;
import org.apache.hudi.common.util.PartitionPathEncodeUtils;
import org.apache.hudi.keygen.SimpleKeyGenerator;
import org.apache.hudi.table.action.bootstrap.BootstrapUtils;
import org.apache.hudi.testutils.HoodieClientTestBase;
@@ -82,8 +83,6 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.io.IOException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
@@ -568,8 +567,8 @@ public class TestBootstrap extends HoodieClientTestBase {
});
if (isPartitioned) {
sqlContext.udf().register("partgen",
(UDF1<String, String>) (val) -> URLEncoder.encode(partitionPaths.get(
Integer.parseInt(val.split("_")[1]) % partitionPaths.size()), StandardCharsets.UTF_8.toString()),
(UDF1<String, String>) (val) -> PartitionPathEncodeUtils.escapePathName(partitionPaths.get(
Integer.parseInt(val.split("_")[1]) % partitionPaths.size())),
DataTypes.StringType);
}
JavaRDD rdd = jsc.parallelize(records);

View File

@@ -0,0 +1,255 @@
# SET OPTION
set hoodie.insert.shuffle.parallelism = 1;
+----------+
| ok |
+----------+
set hoodie.upsert.shuffle.parallelism = 1;
+----------+
| ok |
+----------+
set hoodie.delete.shuffle.parallelism = 1;
+----------+
| ok |
+----------+
# CTAS
create table h0 using hudi options(type = '${tableType}')
as select 1 as id, 'a1' as name, 10 as price;
+----------+
| ok |
+----------+
select id, name, price from h0;
+-----------+
| 1 a1 10 |
+-----------+
create table h0_p using hudi partitioned by(dt)
options(type = '${tableType}')
as select cast('2021-05-07 00:00:00' as timestamp) as dt,
1 as id, 'a1' as name, 10 as price;
+----------+
| ok |
+----------+
select id, name, price, cast(dt as string) from h0_p;
+--------------------------------+
| 1 a1 10 2021-05-07 00:00:00 |
+--------------------------------+
# CREATE TABLE
create table h1 (
id bigint,
name string,
price double,
ts bigint
) using hudi
options (
type = '${tableType}',
primaryKey = 'id',
preCombineField = 'ts'
)
location '${tmpDir}/h1';
+----------+
| ok |
+----------+
create table h1_p (
id bigint,
name string,
price double,
ts bigint,
dt string
) using hudi
partitioned by (dt)
options (
type = '${tableType}',
primaryKey = 'id',
preCombineField = 'ts'
)
location '${tmpDir}/h1_p';
+----------+
| ok |
+----------+
# INSERT/UPDATE/MERGE/DELETE
insert into h1 values(1, 'a1', 10, 1000);
+----------+
| ok |
+----------+
insert into h1 values(2, 'a2', 11, 1000);
+----------+
| ok |
+----------+
# insert static partition
insert into h1_p partition(dt = '2021-05-07') select * from h1;
+----------+
| ok |
+----------+
select id, name, price, ts, dt from h1_p order by id;
+---------------------------+
| 1 a1 10.0 1000 2021-05-07 |
| 2 a2 11.0 1000 2021-05-07 |
+---------------------------+
# insert overwrite table
insert overwrite table h1_p partition(dt = '2021-05-07') select * from h1 limit 10;
+----------+
| ok |
+----------+
select id, name, price, ts, dt from h1_p order by id;
+---------------------------+
| 1 a1 10.0 1000 2021-05-07 |
| 2 a2 11.0 1000 2021-05-07 |
+---------------------------+
# insert dynamic partition
insert into h1_p
select id, concat('a', id) as name, price, ts, dt
from (
select id + 2 as id, price + 2 as price, ts, '2021-05-08' as dt from h1
)
union all
select 5 as id, 'a5' as name, 10 as price, 1000 as ts, '2021-05-08' as dt;
+----------+
| ok |
+----------+
select id, name, price, ts, dt from h1_p order by id;
+---------------------------+
| 1 a1 10.0 1000 2021-05-07 |
| 2 a2 11.0 1000 2021-05-07 |
| 3 a3 12.0 1000 2021-05-08 |
| 4 a4 13.0 1000 2021-05-08 |
| 5 a5 10.0 1000 2021-05-08 |
+---------------------------+
# update table
update h1_p set price = price * 2 where id % 2 = 1;
+----------+
| ok |
+----------+
select id, name, price, ts, dt from h1_p order by id;
+---------------------------+
| 1 a1 20.0 1000 2021-05-07 |
| 2 a2 11.0 1000 2021-05-07 |
| 3 a3 24.0 1000 2021-05-08 |
| 4 a4 13.0 1000 2021-05-08 |
| 5 a5 20.0 1000 2021-05-08 |
+---------------------------+
update h1 set price = if (id %2 = 1, price * 2, price);
+----------+
| ok |
+----------+
select id, name, price, ts from h1;
+----------------+
| 1 a1 20.0 1000 |
| 2 a2 11.0 1000 |
+----------------+
# delete table
delete from h1_p where id = 5;
+----------+
| ok |
+----------+
select count(1) from h1_p;
+----------+
| 4 |
+----------+
# merge into
merge into h1_p t0
using (
select *, '2021-05-07' as dt from h1
) s0
on t0.id = s0.id
when matched then update set id = s0.id, name = s0.name, price = s0.price *2, ts = s0.ts, dt = s0.dt
when not matched then insert *;
+----------+
| ok |
+----------+
select id, name, price, ts, dt from h1_p order by id;
+---------------------------+
| 1 a1 40.0 1000 2021-05-07 |
| 2 a2 22.0 1000 2021-05-07 |
| 3 a3 24.0 1000 2021-05-08 |
| 4 a4 13.0 1000 2021-05-08 |
+---------------------------+
merge into h1_p t0
using (
select 5 as _id, 'a5' as _name, 10 as _price, 1000 as _ts, '2021-05-08' as dt
) s0
on s0._id = t0.id
when matched then update set *
when not matched then insert (id, name, price, ts, dt) values(_id, _name, _price, _ts, s0.dt);
+----------+
| ok |
+----------+
select id, name, price, ts, dt from h1_p order by id;
+---------------------------+
| 1 a1 40.0 1000 2021-05-07 |
| 2 a2 22.0 1000 2021-05-07 |
| 3 a3 24.0 1000 2021-05-08 |
| 4 a4 13.0 1000 2021-05-08 |
| 5 a5 10.0 1000 2021-05-08 |
+---------------------------+
merge into h1_p t0
using (
select 1 as id, '_delete' as name, 10 as price, 1000 as ts, '2021-05-07' as dt
union
select 2 as id, '_update' as name, 12 as price, 1001 as ts, '2021-05-07' as dt
union
select 6 as id, '_insert' as name, 10 as price, 1000 as ts, '2021-05-08' as dt
) s0
on s0.id = t0.id
when matched and name = '_update'
then update set id = s0.id, name = s0.name, price = s0.price, ts = s0.ts, dt = s0.dt
when matched and name = '_delete' then delete
when not matched and name = '_insert' then insert *;
+----------+
| ok |
+----------+
select id, name, price, ts, dt from h1_p order by id;
+--------------------------------+
| 2 _update 12.0 1001 2021-05-07 |
| 3 a3 24.0 1000 2021-05-08 |
| 4 a4 13.0 1000 2021-05-08 |
| 5 a5 10.0 1000 2021-05-08 |
| 6 _insert 10.0 1000 2021-05-08 |
+--------------------------------+
# DROP TABLE
drop table h0;
+----------+
| ok |
+----------+
drop table h0_p;
+----------+
| ok |
+----------+
drop table h1;
+----------+
| ok |
+----------+
drop table h1_p;
+----------+
| ok |
+----------+

View File

@@ -67,14 +67,6 @@ class TestConvertFilterToCatalystExpression {
"((`ts` < 10) AND (`ts` > 1))")
}
@Test
def testUnSupportConvert(): Unit = {
checkConvertFilters(Array(unsupport()), null)
checkConvertFilters(Array(and(unsupport(), eq("id", 1))), null)
checkConvertFilters(Array(or(unsupport(), eq("id", 1))), null)
checkConvertFilters(Array(and(eq("id", 1), not(unsupport()))), null)
}
private def checkConvertFilter(filter: Filter, expectExpression: String): Unit = {
val exp = convertToCatalystExpression(filter, tableSchema)
if (expectExpression == null) {
@@ -154,12 +146,4 @@ class TestConvertFilterToCatalystExpression {
private def contains(attribute: String, value: String): Filter = {
StringContains(attribute, value)
}
private def unsupport(): Filter = {
UnSupportFilter("")
}
case class UnSupportFilter(value: Any) extends Filter {
override def references: Array[String] = Array.empty
}
}

View File

@@ -17,14 +17,13 @@
package org.apache.hudi
import java.net.URLEncoder
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.common.config.HoodieMetadataConfig
import org.apache.hudi.common.table.HoodieTableMetaClient
import org.apache.hudi.common.table.view.HoodieTableFileSystemView
import org.apache.hudi.common.testutils.HoodieTestDataGenerator
import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings
import org.apache.hudi.common.util.PartitionPathEncodeUtils
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.hudi.keygen.ComplexKeyGenerator
import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.{Config, TimestampType}
@@ -137,7 +136,8 @@ class TestHoodieFileIndex extends HoodieClientTestBase {
val fileIndex = HoodieFileIndex(spark, metaClient, None, Map("path" -> basePath))
val partitionFilter1 = EqualTo(attribute("partition"), literal("2021/03/08"))
val partitionName = if (partitionEncode) URLEncoder.encode("2021/03/08") else "2021/03/08"
val partitionName = if (partitionEncode) PartitionPathEncodeUtils.escapePathName("2021/03/08")
else "2021/03/08"
val partitionAndFilesAfterPrune = fileIndex.listFiles(Seq(partitionFilter1), Seq.empty)
assertEquals(1, partitionAndFilesAfterPrune.size)

View File

@@ -528,7 +528,13 @@ class HoodieSparkSqlWriterSuite extends FunSuite with Matchers {
"spark.sql.sources.schema.numParts=1\n" +
"spark.sql.sources.schema.numPartCols=1\n" +
"spark.sql.sources.schema.part.0=" +
"{\"type\":\"struct\",\"fields\":[{\"name\":\"_row_key\",\"type\":\"string\",\"nullable\":false,\"metadata\":{}}," +
"{\"type\":\"struct\",\"fields\":[{\"name\":\"_hoodie_commit_time\"," +
"\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":" +
"\"_hoodie_commit_seqno\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," +
"{\"name\":\"_hoodie_record_key\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," +
"{\"name\":\"_hoodie_partition_path\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," +
"{\"name\":\"_hoodie_file_name\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}," +
"{\"name\":\"_row_key\",\"type\":\"string\",\"nullable\":false,\"metadata\":{}}," +
"{\"name\":\"ts\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}}," +
"{\"name\":\"partition\",\"type\":\"string\",\"nullable\":false,\"metadata\":{}}]}")(hiveSyncConfig.tableProperties)

View File

@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi.functional
import org.apache.hudi.common.util.FileIOUtils
import org.apache.spark.sql.hudi.TestHoodieSqlBase
class TestSqlStatement extends TestHoodieSqlBase {
  // States of the hand-rolled character-level parser in execSqlFile/changeState.
  // The statement file interleaves '#' comment lines, ';'-terminated sql
  // statements and expected results rendered as ascii tables ('+---+' frames
  // with '| .. |' rows).
  val STATE_INIT = 0
  val STATE_SKIP_COMMENT = 1
  val STATE_FINISH_COMMENT = 2
  val STATE_READ_SQL = 3
  val STATE_FINISH_READ_SQL = 4
  val STATE_START_FIRST_RESULT_LINE = 5
  val STATE_END_FIRST_RESULT_LINE = 6
  val STATE_READ_RESULT_LINE = 7
  val STATE_FINISH_READ_RESULT_LINE = 8
  val STATE_AFTER_FINISH_READ_RESULT_LINE = 9
  val STATE_START_LAST_RESULT_LINE = 10
  val STATE_END_LAST_RESULT_LINE = 11
  val STATE_FINISH_ALL = 12

  test("Test Sql Statements") {
    // Replay the full ddl/dml statement file once per table type.
    Seq("cow", "mor").foreach { tableType =>
      withTempDir { tmp =>
        val params = Map(
          "tableType" -> tableType,
          "tmpDir" -> tmp.getCanonicalPath
        )
        execSqlFile("/sql-statements.sql", params)
      }
    }
  }

  /**
   * Parse the statement file, execute each sql statement and compare its
   * actual result with the expected ascii result table that follows it.
   *
   * @param sqlFile resource path of the statement file.
   * @param params  substitutions applied to the file text: each "${key}" is
   *                replaced by its value before parsing.
   */
  private def execSqlFile(sqlFile: String, params: Map[String, String]): Unit = {
    val inputStream = getClass.getResourceAsStream(sqlFile)
    var sqlText = FileIOUtils.readAsUTFString(inputStream)
    // replace parameters in the sql file
    params.foreach { case (k, v) =>
      sqlText = sqlText.replace("${" + k + "}", v)
    }
    var pos = 0
    var state = STATE_INIT
    val sqlBuffer = new StringBuilder            // text of the current sql statement
    var sqlResult: String = null                 // actual result of the last executed statement
    val sqlExpectResult = new StringBuilder      // expected result, all rows so far
    val sqlExpectLineResult = new StringBuilder  // expected result, current '| .. |' row
    while (pos < sqlText.length) {
      var c = sqlText.charAt(pos)
      val (changedState, needFetchNext) = changeState(c, state)
      state = changedState
      pos = pos + 1
      if (needFetchNext) {
        // Consume the char right after a '|' as row content.
        // NOTE(review): assumes the file never ends immediately after a '|',
        // otherwise charAt(pos) would throw StringIndexOutOfBoundsException
        // -- confirm against the fixture format.
        c = sqlText.charAt(pos)
      }
      state match {
        case STATE_READ_SQL =>
          sqlBuffer.append(c)
        case STATE_FINISH_READ_SQL =>
          // A ';' terminated the statement: run it now. Non-select statements
          // are expected to report the single row "ok".
          val sql = sqlBuffer.toString().trim
          try {
            if (sql.startsWith("select")) {
              sqlResult = spark.sql(sql).collect()
                .map(row => row.toSeq.mkString(" ")).mkString("\n")
            } else {
              spark.sql(sql)
              sqlResult = "ok"
            }
          } catch {
            case e: Throwable =>
              // Re-wrap so the failing statement text is visible in the report.
              throw new RuntimeException(s"Error in execute: $sql", e)
          }
        case STATE_READ_RESULT_LINE =>
          sqlExpectLineResult.append(c)
        case STATE_FINISH_READ_RESULT_LINE =>
          // One '| .. |' row finished; fold it into the expected result.
          if (sqlExpectResult.nonEmpty) {
            sqlExpectResult.append("\n")
          }
          sqlExpectResult.append(sqlExpectLineResult.toString().trim)
          sqlExpectLineResult.clear()
        case STATE_END_LAST_RESULT_LINE =>
          // Closing '+---+' frame reached: collapse runs of whitespace in the
          // expected rows, then compare against the actual result.
          val expectResult = sqlExpectResult.toString()
            .split("\n").map(line => line.split("\\s+").mkString(" "))
            .mkString("\n")
          if (expectResult != sqlResult) {
            throw new IllegalArgumentException(s"UnExpect result for: $sqlBuffer\n" +
              s"Expect:\n $expectResult, Actual:\n $sqlResult")
          }
          sqlBuffer.clear()
          sqlExpectResult.clear()
          sqlResult = null
        case _=>
      }
    }
    state = STATE_FINISH_ALL
  }

  /**
   * Change current state.
   * @param c Current char.
   * @param state Current state.
   * @return (changedState, needFetchNext); needFetchNext asks the caller to
   *         consume the character after a '|' as result-row content.
   */
  private def changeState(c: Char, state: Int): (Int, Boolean) = {
    state match {
      case STATE_INIT | STATE_FINISH_COMMENT |
           STATE_FINISH_READ_SQL | STATE_END_LAST_RESULT_LINE =>
        // Between items: '#' opens a comment, '+' opens a result frame,
        // any other non-whitespace char starts a sql statement.
        if (c == '#') {
          (STATE_SKIP_COMMENT, false)
        } else if (c == '+') {
          (STATE_START_FIRST_RESULT_LINE, false)
        } else if (!Character.isWhitespace(c)) {
          (STATE_READ_SQL, false)
        } else {
          (STATE_INIT, false)
        }
      case STATE_SKIP_COMMENT =>
        if (c == '\n' || c == '\r') {
          (STATE_FINISH_COMMENT, false)
        } else {
          (state, false)
        }
      case STATE_READ_SQL =>
        if (c == ';') {
          (STATE_FINISH_READ_SQL, false)
        } else {
          (state, false)
        }
      case STATE_START_FIRST_RESULT_LINE =>
        // Skipping over the top '+---+' frame line until its closing '+'.
        if (c == '+') {
          (STATE_END_FIRST_RESULT_LINE, false)
        } else {
          (state, false)
        }
      case STATE_END_FIRST_RESULT_LINE =>
        if (c == '|') {
          (STATE_READ_RESULT_LINE, true)
        } else {
          (state, false)
        }
      case STATE_READ_RESULT_LINE =>
        if (c == '|') {
          (STATE_FINISH_READ_RESULT_LINE, false)
        } else {
          (state, false)
        }
      case STATE_FINISH_READ_RESULT_LINE | STATE_AFTER_FINISH_READ_RESULT_LINE =>
        // After a row: '+' begins the closing frame, '|' begins another row.
        if (c == '+') {
          (STATE_START_LAST_RESULT_LINE, false)
        } else if (c == '|') {
          (STATE_READ_RESULT_LINE, true)
        } else {
          (STATE_AFTER_FINISH_READ_RESULT_LINE, false)
        }
      case STATE_START_LAST_RESULT_LINE =>
        if (c == '+') {
          (STATE_END_LAST_RESULT_LINE, false)
        } else {
          (state, false)
        }
      case _ =>
        throw new IllegalArgumentException(s"Illegal State: $state meet '$c'")
    }
  }
}

View File

@@ -0,0 +1,275 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import scala.collection.JavaConverters._
import org.apache.hudi.common.model.HoodieRecord
import org.apache.hudi.hadoop.realtime.HoodieParquetRealtimeInputFormat
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField}
class TestCreateTable extends TestHoodieSqlBase {

  test("Test Create Managed Hoodie Table") {
    val tableName = generateTableName
    // Create a managed table
    spark.sql(
      s"""
         | create table $tableName (
         |  id int,
         |  name string,
         |  price double,
         |  ts long
         | ) using hudi
         | options (
         |  primaryKey = 'id',
         |  preCombineField = 'ts'
         | )
       """.stripMargin)
    val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
    assertResult(tableName)(table.identifier.table)
    assertResult("hudi")(table.provider.get)
    assertResult(CatalogTableType.MANAGED)(table.tableType)
    // Hudi prepends its meta columns (all strings) to the user-declared columns.
    assertResult(
      HoodieRecord.HOODIE_META_COLUMNS.asScala.map(StructField(_, StringType))
        ++ Seq(
        StructField("id", IntegerType),
        StructField("name", StringType),
        StructField("price", DoubleType),
        StructField("ts", LongType))
    )(table.schema.fields)
  }

  test("Test Create External Hoodie Table") {
    withTempDir { tmp =>
      // Test create cow table.
      val tableName = generateTableName
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           |) using hudi
           | options (
           |  primaryKey = 'id,name',
           |  type = 'cow'
           | )
           | location '${tmp.getCanonicalPath}'
       """.stripMargin)
      val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
      assertResult(tableName)(table.identifier.table)
      assertResult("hudi")(table.provider.get)
      // A table created with an explicit location is EXTERNAL.
      assertResult(CatalogTableType.EXTERNAL)(table.tableType)
      assertResult(
        HoodieRecord.HOODIE_META_COLUMNS.asScala.map(StructField(_, StringType))
          ++ Seq(
          StructField("id", IntegerType),
          StructField("name", StringType),
          StructField("price", DoubleType),
          StructField("ts", LongType))
      )(table.schema.fields)
      // Table options (type, primaryKey) are persisted as storage properties.
      assertResult(table.storage.properties("type"))("cow")
      assertResult(table.storage.properties("primaryKey"))("id,name")
      spark.sql(s"drop table $tableName")
      // Test create mor partitioned table
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long,
           |  dt string
           |) using hudi
           | partitioned by (dt)
           | options (
           |  primaryKey = 'id',
           |  type = 'mor'
           | )
           | location '${tmp.getCanonicalPath}/h0'
       """.stripMargin)
      val table2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
      assertResult(table2.storage.properties("type"))("mor")
      assertResult(table2.storage.properties("primaryKey"))("id")
      assertResult(Seq("dt"))(table2.partitionColumnNames)
      // MOR tables must be wired to the realtime input format.
      assertResult(classOf[HoodieParquetRealtimeInputFormat].getCanonicalName)(table2.storage.inputFormat.get)
      // Test create a external table with an exist table in the path
      // (no column list: schema and options are inferred from the hoodie table
      // already present at the location).
      val tableName3 = generateTableName
      spark.sql(
        s"""
           |create table $tableName3
           |using hudi
           |location '${tmp.getCanonicalPath}/h0'
       """.stripMargin)
      val table3 = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName3))
      assertResult(table3.storage.properties("type"))("mor")
      assertResult(table3.storage.properties("primaryKey"))("id")
      assertResult(
        HoodieRecord.HOODIE_META_COLUMNS.asScala.map(StructField(_, StringType))
          ++ Seq(
          StructField("id", IntegerType),
          StructField("name", StringType),
          StructField("price", DoubleType),
          StructField("ts", LongType),
          StructField("dt", StringType)
        )
      )(table3.schema.fields)
    }
  }

  test("Test Table Column Validate") {
    withTempDir {tmp =>
      val tableName = generateTableName
      // primaryKey must reference an existing column ('id1' does not exist).
      assertThrows[IllegalArgumentException] {
        spark.sql(
          s"""
             |create table $tableName (
             |  id int,
             |  name string,
             |  price double,
             |  ts long
             |) using hudi
             | options (
             |  primaryKey = 'id1',
             |  type = 'cow'
             | )
             | location '${tmp.getCanonicalPath}'
       """.stripMargin)
      }
      // preCombineField must reference an existing column ('ts1' does not exist).
      assertThrows[IllegalArgumentException] {
        spark.sql(
          s"""
             |create table $tableName (
             |  id int,
             |  name string,
             |  price double,
             |  ts long
             |) using hudi
             | options (
             |  primaryKey = 'id',
             |  preCombineField = 'ts1',
             |  type = 'cow'
             | )
             | location '${tmp.getCanonicalPath}'
       """.stripMargin)
      }
      // type must be a valid table type ('cow1' is not).
      assertThrows[IllegalArgumentException] {
        spark.sql(
          s"""
             |create table $tableName (
             |  id int,
             |  name string,
             |  price double,
             |  ts long
             |) using hudi
             | options (
             |  primaryKey = 'id',
             |  preCombineField = 'ts',
             |  type = 'cow1'
             | )
             | location '${tmp.getCanonicalPath}'
       """.stripMargin)
      }
    }
  }

  test("Test Create Table As Select") {
    withTempDir { tmp =>
      // Create Non-Partitioned table
      val tableName1 = generateTableName
      spark.sql(
        s"""
           |create table $tableName1 using hudi
           | location '${tmp.getCanonicalPath}/$tableName1'
           | AS
           | select 1 as id, 'a1' as name, 10 as price, 1000 as ts
       """.stripMargin)
      checkAnswer(s"select id, name, price, ts from $tableName1")(
        Seq(1, "a1", 10.0, 1000)
      )
      // Create Partitioned table
      val tableName2 = generateTableName
      spark.sql(
        s"""
           | create table $tableName2 using hudi
           | partitioned by (dt)
           | location '${tmp.getCanonicalPath}/$tableName2'
           | AS
           | select 1 as id, 'a1' as name, 10 as price, '2021-04-01' as dt
       """.stripMargin
      )
      checkAnswer(s"select id, name, price, dt from $tableName2") (
        Seq(1, "a1", 10, "2021-04-01")
      )
      // Create Partitioned table with timestamp data type
      val tableName3 = generateTableName
      // CTAS failed with null primaryKey
      assertThrows[Exception] {
        spark.sql(
          s"""
             | create table $tableName3 using hudi
             | partitioned by (dt)
             | options(primaryKey = 'id')
             | location '${tmp.getCanonicalPath}/$tableName3'
             | AS
             | select null as id, 'a1' as name, 10 as price, '2021-05-07' as dt
             |
       """.stripMargin
        )}
      // Create table with timestamp type partition
      // (reuses tableName3: the failed CTAS above must not have created it).
      spark.sql(
        s"""
           | create table $tableName3 using hudi
           | partitioned by (dt)
           | location '${tmp.getCanonicalPath}/$tableName3'
           | AS
           | select cast('2021-05-06 00:00:00' as timestamp) as dt, 1 as id, 'a1' as name, 10 as
           | price
       """.stripMargin
      )
      checkAnswer(s"select id, name, price, cast(dt as string) from $tableName3")(
        Seq(1, "a1", 10, "2021-05-06 00:00:00")
      )
      // Create table with date type partition
      val tableName4 = generateTableName
      spark.sql(
        s"""
           | create table $tableName4 using hudi
           | partitioned by (dt)
           | location '${tmp.getCanonicalPath}/$tableName4'
           | AS
           | select cast('2021-05-06' as date) as dt, 1 as id, 'a1' as name, 10 as
           | price
       """.stripMargin
      )
      checkAnswer(s"select id, name, price, cast(dt as string) from $tableName4")(
        Seq(1, "a1", 10, "2021-05-06")
      )
    }
  }
}

View File

@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
class TestDeleteTable extends TestHoodieSqlBase {

  test("Test Delete Table") {
    withTempDir { tmp =>
      // Run the same delete scenario for both copy-on-write and merge-on-read tables.
      Seq("cow", "mor").foreach {tableType =>
        val tableName = generateTableName
        // create table
        spark.sql(
          s"""
             |create table $tableName (
             |  id int,
             |  name string,
             |  price double,
             |  ts long
             |) using hudi
             | location '${tmp.getCanonicalPath}/$tableName'
             | options (
             |  type = '$tableType',
             |  primaryKey = 'id',
             |  preCombineField = 'ts'
             | )
       """.stripMargin)
        // insert data to table
        spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
        checkAnswer(s"select id, name, price, ts from $tableName")(
          Seq(1, "a1", 10.0, 1000)
        )
        // delete data from table
        spark.sql(s"delete from $tableName where id = 1")
        checkAnswer(s"select count(1) from $tableName") (
          Seq(0)
        )
        // Deleting a key that no longer exists must be a no-op and leave other
        // records untouched.
        spark.sql(s"insert into $tableName select 2, 'a2', 10, 1000")
        spark.sql(s"delete from $tableName where id = 1")
        checkAnswer(s"select id, name, price, ts from $tableName")(
          Seq(2, "a2", 10.0, 1000)
        )
        // Delete without a predicate removes every record.
        spark.sql(s"delete from $tableName")
        checkAnswer(s"select count(1) from $tableName")(
          Seq(0)
        )
      }
    }
  }
}

View File

@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import java.io.File
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.util.Utils
import org.scalactic.source
import org.scalatest.{BeforeAndAfterAll, FunSuite, Tag}
class TestHoodieSqlBase extends FunSuite with BeforeAndAfterAll {
  // Scratch warehouse location: create a temp dir, then delete it so that only
  // a fresh, non-existing path is handed to spark.
  private lazy val warehousePath = {
    val dir = Utils.createTempDir()
    Utils.deleteRecursively(dir)
    dir
  }

  // One local session shared by every sql test, with the hudi extension
  // installed and small shuffle parallelism for speed.
  protected lazy val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("hoodie sql test")
    .withExtensions(new HoodieSparkSessionExtension)
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("hoodie.datasource.meta.sync.enable", "false")
    .config("hoodie.insert.shuffle.parallelism", "4")
    .config("hoodie.upsert.shuffle.parallelism", "4")
    .config("hoodie.delete.shuffle.parallelism", "4")
    .config("spark.sql.warehouse.dir", warehousePath.getCanonicalPath)
    .getOrCreate()

  // Counter backing generateTableName; produces h0, h1, h2, ...
  private var nextTableId = 0

  /** Runs `f` against a fresh temporary directory and always cleans it up. */
  protected def withTempDir(f: File => Unit): Unit = {
    val dir = Utils.createTempDir()
    try {
      f(dir)
    } finally {
      Utils.deleteRecursively(dir)
    }
  }

  /** Wraps each test body so every table left behind is dropped afterwards. */
  override protected def test(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */)(implicit pos: source.Position): Unit = {
    super.test(testName, testTags: _*) {
      try {
        testFun
      } finally {
        val catalog = spark.sessionState.catalog
        for {
          db <- catalog.listDatabases()
          table <- catalog.listTables(db)
        } catalog.dropTable(table, ignoreIfNotExists = true, purge = true)
      }
    }
  }

  /** Returns a table name that is unique within the current test run. */
  protected def generateTableName: String = {
    val id = nextTableId
    nextTableId = nextTableId + 1
    s"h$id"
  }

  override protected def afterAll(): Unit = {
    Utils.deleteRecursively(warehousePath)
    spark.stop()
  }

  /** Executes `sql` and asserts it returns exactly the expected rows, in order. */
  protected def checkAnswer(sql: String)(expects: Seq[Any]*): Unit = {
    val expectedRows = expects.map(values => Row(values: _*)).toArray
    assertResult(expectedRows)(spark.sql(sql).collect())
  }
}

View File

@@ -0,0 +1,223 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import org.apache.hudi.exception.HoodieDuplicateKeyException
/**
 * Tests for INSERT INTO / INSERT OVERWRITE on Hoodie tables, covering
 * dynamic and static partition inserts, non-partitioned tables,
 * duplicate-key handling and insert-overwrite from another table.
 */
class TestInsertTable extends TestHoodieSqlBase {

  test("Test Insert Into") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // Create a partitioned table
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long,
           |  dt string
           |) using hudi
           | partitioned by (dt)
           | location '${tmp.getCanonicalPath}'
       """.stripMargin)
      // Insert into dynamic partition
      spark.sql(
        s"""
           | insert into $tableName
           | select 1 as id, 'a1' as name, 10 as price, 1000 as ts, '2021-01-05' as dt
        """.stripMargin)
      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
        Seq(1, "a1", 10.0, 1000, "2021-01-05")
      )
      // Insert into static partition
      spark.sql(
        s"""
           | insert into $tableName partition(dt = '2021-01-05')
           | select 2 as id, 'a2' as name, 10 as price, 1000 as ts
        """.stripMargin)
      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
        Seq(1, "a1", 10.0, 1000, "2021-01-05"),
        Seq(2, "a2", 10.0, 1000, "2021-01-05")
      )
    }
  }

  test("Test Insert Into None Partitioned Table") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // Create none partitioned cow table
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           |) using hudi
           | location '${tmp.getCanonicalPath}/$tableName'
           | options (
           |  type = 'cow',
           |  primaryKey = 'id',
           |  preCombineField = 'ts'
           | )
       """.stripMargin)
      spark.sql(s"insert into $tableName values(1, 'a1', 10, 1000)")
      checkAnswer(s"select id, name, price, ts from $tableName")(
        Seq(1, "a1", 10.0, 1000)
      )
      spark.sql(s"insert into $tableName select 2, 'a2', 12, 1000")
      checkAnswer(s"select id, name, price, ts from $tableName")(
        Seq(1, "a1", 10.0, 1000),
        Seq(2, "a2", 12.0, 1000)
      )
      // Inserting an existing key again must fail. The duplicate-key error is
      // wrapped by Spark, so unwrap to the root cause before re-throwing.
      assertThrows[HoodieDuplicateKeyException] {
        try {
          spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
        } catch {
          case e: Exception =>
            var root: Throwable = e
            while (root.getCause != null) {
              root = root.getCause
            }
            throw root
        }
      }
      // Create table with dropDup is true
      val tableName2 = generateTableName
      spark.sql("set hoodie.datasource.write.insert.drop.duplicates = true")
      try {
        spark.sql(
          s"""
             |create table $tableName2 (
             |  id int,
             |  name string,
             |  price double,
             |  ts long
             |) using hudi
             | location '${tmp.getCanonicalPath}/$tableName2'
             | options (
             |  type = 'mor',
             |  primaryKey = 'id',
             |  preCombineField = 'ts'
             | )
         """.stripMargin)
        spark.sql(s"insert into $tableName2 select 1, 'a1', 10, 1000")
        // This record will be drop when dropDup is true
        spark.sql(s"insert into $tableName2 select 1, 'a1', 12, 1000")
        checkAnswer(s"select id, name, price, ts from $tableName2")(
          Seq(1, "a1", 10.0, 1000)
        )
      } finally {
        // Reset the session conf so it does not leak into subsequent tests
        // (the SparkSession is shared across the whole suite).
        spark.sql("set hoodie.datasource.write.insert.drop.duplicates = false")
      }
    }
  }

  test("Test Insert Overwrite") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // Create a partitioned table
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long,
           |  dt string
           |) using hudi
           | partitioned by (dt)
           | location '${tmp.getCanonicalPath}/$tableName'
       """.stripMargin)
      // Insert overwrite dynamic partition
      spark.sql(
        s"""
           | insert overwrite table $tableName
           | select 1 as id, 'a1' as name, 10 as price, 1000 as ts, '2021-01-05' as dt
        """.stripMargin)
      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
        Seq(1, "a1", 10.0, 1000, "2021-01-05")
      )
      // Insert overwrite dynamic partition
      spark.sql(
        s"""
           | insert overwrite table $tableName
           | select 2 as id, 'a2' as name, 10 as price, 1000 as ts, '2021-01-06' as dt
        """.stripMargin)
      checkAnswer(s"select id, name, price, ts, dt from $tableName order by id")(
        Seq(1, "a1", 10.0, 1000, "2021-01-05"),
        Seq(2, "a2", 10.0, 1000, "2021-01-06")
      )
      // Insert overwrite static partition
      spark.sql(
        s"""
           | insert overwrite table $tableName partition(dt = '2021-01-05')
           | select * from (select 2 , 'a2', 12, 1000) limit 10
        """.stripMargin)
      checkAnswer(s"select id, name, price, ts, dt from $tableName order by dt")(
        Seq(2, "a2", 12.0, 1000, "2021-01-05"),
        Seq(2, "a2", 10.0, 1000, "2021-01-06")
      )
      // Insert data from another table
      val tblNonPartition = generateTableName
      spark.sql(
        s"""
           | create table $tblNonPartition (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           | ) using hudi
           | location '${tmp.getCanonicalPath}/$tblNonPartition'
       """.stripMargin)
      spark.sql(s"insert into $tblNonPartition select 1, 'a1', 10, 1000")
      spark.sql(
        s"""
           | insert overwrite table $tableName partition(dt ='2021-01-04')
           | select * from $tblNonPartition limit 10
        """.stripMargin)
      checkAnswer(s"select id, name, price, ts, dt from $tableName order by id,dt")(
        Seq(1, "a1", 10.0, 1000, "2021-01-04"),
        Seq(2, "a2", 12.0, 1000, "2021-01-05"),
        Seq(2, "a2", 10.0, 1000, "2021-01-06")
      )
      spark.sql(
        s"""
           | insert overwrite table $tableName
           | select id + 2, name, price, ts , '2021-01-04' from $tblNonPartition limit 10
        """.stripMargin)
      checkAnswer(s"select id, name, price, ts, dt from $tableName " +
        s"where dt >='2021-01-04' and dt <= '2021-01-06' order by id,dt")(
        Seq(2, "a2", 12.0, 1000, "2021-01-05"),
        Seq(2, "a2", 10.0, 1000, "2021-01-06"),
        Seq(3, "a1", 10.0, 1000, "2021-01-04")
      )
      // test insert overwrite non-partitioned table
      spark.sql(s"insert overwrite table $tblNonPartition select 2, 'a2', 10, 1000")
      checkAnswer(s"select id, name, price, ts from $tblNonPartition")(
        Seq(2, "a2", 10.0, 1000)
      )
    }
  }
}

View File

@@ -0,0 +1,535 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers}
import org.apache.hudi.common.fs.FSUtils
/**
 * Tests for the MERGE INTO statement on Hoodie tables, covering insert,
 * update and delete actions (with and without conditions), COW and MOR
 * table types, preCombineField semantics, merges driven by another Hoodie
 * table, and incremental queries over the resulting commits.
 */
class TestMergeIntoTable extends TestHoodieSqlBase {

  test("Test MergeInto Basic") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // Create table
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           |) using hudi
           | location '${tmp.getCanonicalPath}'
           | options (
           |  primaryKey ='id',
           |  preCombineField = 'ts'
           | )
       """.stripMargin)
      // First merge with a extra input field 'flag' (insert a new record)
      spark.sql(
        s"""
           | merge into $tableName
           | using (
           |  select 1 as id, 'a1' as name, 10 as price, 1000 as ts, '1' as flag
           | ) s0
           | on s0.id = $tableName.id
           | when matched and flag = '1' then update set
           | id = s0.id, name = s0.name, price = s0.price, ts = s0.ts
           | when not matched and flag = '1' then insert *
       """.stripMargin)
      checkAnswer(s"select id, name, price, ts from $tableName")(
        Seq(1, "a1", 10.0, 1000)
      )
      // Second merge (update the record)
      spark.sql(
        s"""
           | merge into $tableName
           | using (
           |  select 1 as id, 'a1' as name, 10 as price, 1001 as ts
           | ) s0
           | on s0.id = $tableName.id
           | when matched then update set
           | id = s0.id, name = s0.name, price = s0.price + $tableName.price, ts = s0.ts
           | when not matched then insert *
       """.stripMargin)
      checkAnswer(s"select id, name, price, ts from $tableName")(
        Seq(1, "a1", 20.0, 1001)
      )
      // the third time merge (update & insert the record)
      spark.sql(
        s"""
           | merge into $tableName
           | using (
           |  select * from (
           |  select 1 as id, 'a1' as name, 10 as price, 1002 as ts
           |  union all
           |  select 2 as id, 'a2' as name, 12 as price, 1001 as ts
           |  )
           | ) s0
           | on s0.id = $tableName.id
           | when matched then update set
           | id = s0.id, name = s0.name, price = s0.price + $tableName.price, ts = s0.ts
           | when not matched and id % 2 = 0 then insert *
       """.stripMargin)
      checkAnswer(s"select id, name, price, ts from $tableName")(
        Seq(1, "a1", 30.0, 1002),
        Seq(2, "a2", 12.0, 1001)
      )
      // the fourth merge (delete the record)
      spark.sql(
        s"""
           | merge into $tableName
           | using (
           |  select 1 as id, 'a1' as name, 12 as price, 1003 as ts
           | ) s0
           | on s0.id = $tableName.id
           | when matched and id != 1 then update set
           |  id = s0.id, name = s0.name, price = s0.price, ts = s0.ts
           | when matched and id = 1 then delete
           | when not matched then insert *
       """.stripMargin)
      // Row with id = 1 must be gone after the delete action above.
      val cnt = spark.sql(s"select * from $tableName where id = 1").count()
      assertResult(0)(cnt)
    }
  }

  test("Test MergeInto with ignored record") {
    withTempDir { tmp =>
      val sourceTable = generateTableName
      val targetTable = generateTableName
      // Create source table (a plain parquet table used only as merge input)
      spark.sql(
        s"""
           | create table $sourceTable (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           | ) using parquet
           | location '${tmp.getCanonicalPath}/$sourceTable'
       """.stripMargin)
      // Create target table
      spark.sql(
        s"""
           |create table $targetTable (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           |) using hudi
           | location '${tmp.getCanonicalPath}/$targetTable'
           | options (
           |  primaryKey ='id',
           |  preCombineField = 'ts'
           | )
       """.stripMargin)
      // Insert data to source table
      spark.sql(s"insert into $sourceTable values(1, 'a1', 10, 1000)")
      spark.sql(s"insert into $sourceTable values(2, 'a2', 11, 1000)")
      spark.sql(
        s"""
           | merge into $targetTable as t0
           | using (select * from $sourceTable) as s0
           | on t0.id = s0.id
           | when matched then update set *
           | when not matched and s0.name = 'a1' then insert *
       """.stripMargin)
      // The record of "name = 'a2'" will be filter
      checkAnswer(s"select id, name, price, ts from $targetTable")(
        Seq(1, "a1", 10.0, 1000)
      )
      spark.sql(s"insert into $targetTable select 3, 'a3', 12, 1000")
      checkAnswer(s"select id, name, price, ts from $targetTable")(
        Seq(1, "a1", 10.0, 1000),
        Seq(3, "a3", 12, 1000)
      )
      // Merge with a renamed join key (s_id) and an update condition on ts.
      spark.sql(
        s"""
           | merge into $targetTable as t0
           | using (
           |  select * from (
           |  select 1 as s_id, 'a1' as name, 20 as price, 1001 as ts
           |  union all
           |  select 3 as s_id, 'a3' as name, 20 as price, 1003 as ts
           |  union all
           |  select 4 as s_id, 'a4' as name, 10 as price, 1004 as ts
           |  )
           | ) s0
           | on s0.s_id = t0.id
           | when matched and ts = 1001 then update set id = s0.s_id, name = t0.name, price =
           |  s0.price, ts = s0.ts
       """.stripMargin
      )
      // Ignore the update for id = 3
      checkAnswer(s"select id, name, price, ts from $targetTable")(
        Seq(1, "a1", 20.0, 1001),
        Seq(3, "a3", 12.0, 1000)
      )
    }
  }

  test("Test MergeInto for MOR table ") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // Create a mor partitioned table.
      spark.sql(
        s"""
           | create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long,
           |  dt string
           | ) using hudi
           | options (
           |  type = 'mor',
           |  primaryKey = 'id',
           |  preCombineField = 'ts'
           | )
           | partitioned by(dt)
           | location '${tmp.getCanonicalPath}'
       """.stripMargin)
      // Insert data
      spark.sql(
        s"""
           | merge into $tableName as t0
           | using (
           |  select 1 as id, 'a1' as name, 10 as price, 1000 as ts, '2021-03-21' as dt
           | ) as s0
           | on t0.id = s0.id
           | when not matched and s0.id % 2 = 1 then insert *
       """.stripMargin
      )
      checkAnswer(s"select id,name,price,dt from $tableName")(
        Seq(1, "a1", 10, "2021-03-21")
      )
      // Update data when matched-condition not matched.
      spark.sql(
        s"""
           | merge into $tableName as t0
           | using (
           |  select 1 as id, 'a1' as name, 12 as price, 1001 as ts, '2021-03-21' as dt
           | ) as s0
           | on t0.id = s0.id
           | when matched and id % 2 = 0 then update set *
       """.stripMargin
      )
      checkAnswer(s"select id,name,price,dt from $tableName")(
        Seq(1, "a1", 10, "2021-03-21")
      )
      // Update data when matched-condition matched.
      spark.sql(
        s"""
           | merge into $tableName as t0
           | using (
           |  select 1 as id, 'a1' as name, 12 as price, 1001 as ts, '2021-03-21' as dt
           | ) as s0
           | on t0.id = s0.id
           | when matched and s0.id % 2 = 1 then update set *
       """.stripMargin
      )
      checkAnswer(s"select id,name,price,dt from $tableName")(
        Seq(1, "a1", 12, "2021-03-21")
      )
      // Insert a new data.
      spark.sql(
        s"""
           | merge into $tableName as t0
           | using (
           |  select 2 as id, 'a2' as name, 10 as price, 1000 as ts, '2021-03-21' as dt
           | ) as s0
           | on t0.id = s0.id
           | when not matched and s0.id % 2 = 0 then insert *
       """.stripMargin
      )
      checkAnswer(s"select id,name,price,dt from $tableName order by id")(
        Seq(1, "a1", 12, "2021-03-21"),
        Seq(2, "a2", 10, "2021-03-21")
      )
      // Update with different source column names.
      spark.sql(
        s"""
           | merge into $tableName t0
           | using (
           |  select 2 as s_id, 'a2' as s_name, 15 as s_price, 1001 as s_ts, '2021-03-21' as dt
           | ) s0
           | on t0.id = s0.s_id
           | when matched and s_ts = 1001 then update set *
       """.stripMargin
      )
      checkAnswer(s"select id,name,price,dt from $tableName order by id")(
        Seq(1, "a1", 12, "2021-03-21"),
        Seq(2, "a2", 15, "2021-03-21")
      )
      // Delete with condition expression.
      spark.sql(
        s"""
           | merge into $tableName t0
           | using (
           |  select 1 as s_id, 'a2' as s_name, 15 as s_price, 1001 as s_ts, '2021-03-21' as dt
           | ) s0
           | on t0.id = s0.s_id + 1
           | when matched and s_ts = 1001 then delete
       """.stripMargin
      )
      checkAnswer(s"select id,name,price,dt from $tableName order by id")(
        Seq(1, "a1", 12, "2021-03-21")
      )
    }
  }

  test("Test MergeInto with insert only") {
    withTempDir { tmp =>
      // Create a partitioned mor table
      val tableName = generateTableName
      spark.sql(
        s"""
           | create table $tableName (
           |  id bigint,
           |  name string,
           |  price double,
           |  dt string
           | ) using hudi
           | options (
           |  type = 'mor',
           |  primaryKey = 'id'
           | )
           | partitioned by(dt)
           | location '${tmp.getCanonicalPath}'
       """.stripMargin)
      spark.sql(s"insert into $tableName select 1, 'a1', 10, '2021-03-21'")
      // Insert-only merge with an explicit column list in the insert action.
      spark.sql(
        s"""
           | merge into $tableName as t0
           | using (
           |  select 2 as id, 'a2' as name, 10 as price, 1000 as ts, '2021-03-20' as dt
           | ) s0
           | on s0.id = t0.id
           | when not matched and s0.id % 2 = 0 then insert (id,name,price,dt)
           | values(s0.id,s0.name,s0.price,s0.dt)
       """.stripMargin)
      checkAnswer(s"select id,name,price,dt from $tableName order by id")(
        Seq(1, "a1", 10, "2021-03-21"),
        Seq(2, "a2", 10, "2021-03-20")
      )
      spark.sql(
        s"""
           | merge into $tableName as t0
           | using (
           |  select 3 as id, 'a3' as name, 10 as price, 1000 as ts, '2021-03-20' as dt
           | ) s0
           | on s0.id = t0.id
           | when not matched and s0.id % 2 = 0 then insert (id,name,price,dt)
           | values(s0.id,s0.name,s0.price,s0.dt)
       """.stripMargin)
      // id = 3 should not write to the table as it has filtered by id % 2 = 0
      checkAnswer(s"select id,name,price,dt from $tableName order by id")(
        Seq(1, "a1", 10, "2021-03-21"),
        Seq(2, "a2", 10, "2021-03-20")
      )
    }
  }

  test("Test MergeInto For PreCombineField") {
    withTempDir { tmp =>
      // Run the same scenario for both table types.
      Seq("cow", "mor").foreach { tableType =>
        val tableName1 = generateTableName
        // Create a mor partitioned table.
        spark.sql(
          s"""
             | create table $tableName1 (
             |  id int,
             |  name string,
             |  price double,
             |  v long,
             |  dt string
             | ) using hudi
             | options (
             |  type = '$tableType',
             |  primaryKey = 'id',
             |  preCombineField = 'v'
             | )
             | partitioned by(dt)
             | location '${tmp.getCanonicalPath}/$tableName1'
         """.stripMargin)
        // Insert data
        spark.sql(
          s"""
             | merge into $tableName1 as t0
             | using (
             |  select 1 as id, 'a1' as name, 10 as price, 1001 as v, '2021-03-21' as dt
             | ) as s0
             | on t0.id = s0.id
             | when not matched and s0.id % 2 = 1 then insert *
         """.stripMargin
        )
        checkAnswer(s"select id,name,price,dt,v from $tableName1")(
          Seq(1, "a1", 10, "2021-03-21", 1001)
        )
        // Update data with a smaller version value
        spark.sql(
          s"""
             | merge into $tableName1 as t0
             | using (
             |  select 1 as id, 'a1' as name, 11 as price, 1000 as v, '2021-03-21' as dt
             | ) as s0
             | on t0.id = s0.id
             | when matched and s0.id % 2 = 1 then update set *
         """.stripMargin
        )
        // Update failed as v = 1000 < 1001
        checkAnswer(s"select id,name,price,dt,v from $tableName1")(
          Seq(1, "a1", 10, "2021-03-21", 1001)
        )
        // Update data with a bigger version value
        spark.sql(
          s"""
             | merge into $tableName1 as t0
             | using (
             |  select 1 as id, 'a1' as name, 12 as price, 1002 as v, '2021-03-21' as dt
             | ) as s0
             | on t0.id = s0.id
             | when matched and s0.id % 2 = 1 then update set *
         """.stripMargin
        )
        // Update success
        checkAnswer(s"select id,name,price,dt,v from $tableName1")(
          Seq(1, "a1", 12, "2021-03-21", 1002)
        )
      }
    }
  }

  test("Merge Hudi to Hudi") {
    withTempDir { tmp =>
      Seq("cow", "mor").foreach { tableType =>
        val sourceTable = generateTableName
        spark.sql(
          s"""
             |create table $sourceTable (
             | id int,
             | name string,
             | price double,
             | _ts long
             |) using hudi
             |options(
             | type ='$tableType',
             | primaryKey = 'id',
             | preCombineField = '_ts'
             |)
             |location '${tmp.getCanonicalPath}/$sourceTable'
         """.stripMargin)
        val targetTable = generateTableName
        val targetBasePath = s"${tmp.getCanonicalPath}/$targetTable"
        spark.sql(
          s"""
             |create table $targetTable (
             | id int,
             | name string,
             | price double,
             | _ts long
             |) using hudi
             |options(
             | type ='$tableType',
             | primaryKey = 'id',
             | preCombineField = '_ts'
             |)
             |location '$targetBasePath'
         """.stripMargin)
        // First merge
        spark.sql(s"insert into $sourceTable values(1, 'a1', 10, 1000)")
        spark.sql(
          s"""
             |merge into $targetTable t0
             |using $sourceTable s0
             |on t0.id = s0.id
             |when not matched then insert *
         """.stripMargin)
        checkAnswer(s"select id, name, price, _ts from $targetTable")(
          Seq(1, "a1", 10, 1000)
        )
        val fs = FSUtils.getFs(targetBasePath, spark.sessionState.newHadoopConf())
        val firstCommitTime = HoodieDataSourceHelpers.latestCommit(fs, targetBasePath)
        // Second merge
        spark.sql(s"update $sourceTable set price = 12, _ts = 1001 where id = 1")
        spark.sql(
          s"""
             |merge into $targetTable t0
             |using $sourceTable s0
             |on t0.id = s0.id
             |when matched and cast(_ts as string) > '1000' then update set *
         """.stripMargin)
        checkAnswer(s"select id, name, price, _ts from $targetTable")(
          Seq(1, "a1", 12, 1001)
        )
        // Test incremental query
        val hudiIncDF1 = spark.read.format("org.apache.hudi")
          .option(DataSourceReadOptions.QUERY_TYPE_OPT_KEY, DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL)
          .option(DataSourceReadOptions.BEGIN_INSTANTTIME_OPT_KEY, "000")
          .option(DataSourceReadOptions.END_INSTANTTIME_OPT_KEY, firstCommitTime)
          .load(targetBasePath)
        hudiIncDF1.createOrReplaceTempView("inc1")
        // Only the first commit is visible in this incremental window.
        checkAnswer(s"select id, name, price, _ts from inc1")(
          Seq(1, "a1", 10, 1000)
        )
        val secondCommitTime = HoodieDataSourceHelpers.latestCommit(fs, targetBasePath)
        // Third merge
        spark.sql(s"insert into $sourceTable values(2, 'a2', 10, 1001)")
        spark.sql(
          s"""
             |merge into $targetTable t0
             |using $sourceTable s0
             |on t0.id = s0.id
             |when matched then update set *
             |when not matched and name = 'a2' then insert *
         """.stripMargin)
        checkAnswer(s"select id, name, price, _ts from $targetTable order by id")(
          Seq(1, "a1", 12, 1001),
          Seq(2, "a2", 10, 1001)
        )
        // Test incremental query
        val hudiIncDF2 = spark.read.format("org.apache.hudi")
          .option(DataSourceReadOptions.QUERY_TYPE_OPT_KEY, DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL)
          .option(DataSourceReadOptions.BEGIN_INSTANTTIME_OPT_KEY, secondCommitTime)
          .load(targetBasePath)
        hudiIncDF2.createOrReplaceTempView("inc2")
        checkAnswer(s"select id, name, price, _ts from inc2 order by id")(
          Seq(1, "a1", 12, 1001),
          Seq(2, "a2", 10, 1001)
        )
      }
    }
  }
}

View File

@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi
/**
 * Tests for the UPDATE statement, run against both COW and MOR Hoodie
 * tables: a constant-value update and a self-referencing expression update.
 */
class TestUpdateTable extends TestHoodieSqlBase {

  test("Test Update Table") {
    withTempDir { tmp =>
      // Run the identical scenario for both table types.
      Seq("cow", "mor").foreach { tableType =>
        val tableName = generateTableName
        // create table
        spark.sql(
          s"""
             |create table $tableName (
             |  id int,
             |  name string,
             |  price double,
             |  ts long
             |) using hudi
             | location '${tmp.getCanonicalPath}/$tableName'
             | options (
             |  type = '$tableType',
             |  primaryKey = 'id',
             |  preCombineField = 'ts'
             | )
         """.stripMargin)
        // insert data to table
        spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
        checkAnswer(s"select id, name, price, ts from $tableName")(
          Seq(1, "a1", 10.0, 1000)
        )
        // update data (set to a constant)
        spark.sql(s"update $tableName set price = 20 where id = 1")
        checkAnswer(s"select id, name, price, ts from $tableName")(
          Seq(1, "a1", 20.0, 1000)
        )
        // update data (expression referencing the current column value)
        spark.sql(s"update $tableName set price = price * 2 where id = 1")
        checkAnswer(s"select id, name, price, ts from $tableName")(
          Seq(1, "a1", 40.0, 1000)
        )
      }
    }
  }
}

View File

@@ -143,6 +143,24 @@
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.antlr</groupId>
<artifactId>antlr4-maven-plugin</artifactId>
<version>4.7</version>
<executions>
<execution>
<goals>
<goal>antlr4</goal>
</goals>
</execution>
</executions>
<configuration>
<visitor>true</visitor>
<listener>true</listener>
<sourceDirectory>../hudi-spark2/src/main/antlr4/</sourceDirectory>
<libDirectory>../hudi-spark2/src/main/antlr4/imports</libDirectory>
</configuration>
</plugin>
</plugins>
</build>
@@ -172,7 +190,7 @@
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-common</artifactId>
<artifactId>hudi-spark-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
@@ -223,7 +241,7 @@
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-common</artifactId>
<artifactId>hudi-spark-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
grammar HoodieSqlBase;

// Reuse the lexer/parser rules of Spark's SqlBase grammar. Statements that
// are not matched by the rules below fall through to the delegate Spark
// parser via the `passThrough` alternative.
import SqlBase;

singleStatement
    : statement EOF
    ;

// The statements handled by the Hoodie sql parser: MERGE INTO, UPDATE and
// DELETE. Everything else (`.*?`) is passed through unchanged.
statement
    : mergeInto        #mergeIntoTable
    | updateTableStmt  #updateTable
    | deleteTableStmt  #deleteTable
    | .*?              #passThrough
    ;

// MERGE INTO <target> USING (<table> | (<subquery>)) ON <cond>
// followed by any number of matched (update/delete) and
// not-matched (insert) clauses.
mergeInto
    : MERGE INTO target=tableIdentifier tableAlias
      USING (source=tableIdentifier | '(' subquery = query ')') tableAlias
      mergeCondition
      matchedClauses*
      notMatchedClause*
    ;

mergeCondition
    : ON condition=booleanExpression
    ;

matchedClauses
    : deleteClause
    | updateClause
    ;

notMatchedClause
    : insertClause
    ;

// WHEN MATCHED [AND <cond>] THEN DELETE
// (the second alternative allows omitting the MATCHED keyword)
deleteClause
    : WHEN MATCHED (AND deleteCond=booleanExpression)? THEN deleteAction
    | WHEN deleteCond=booleanExpression THEN deleteAction
    ;

updateClause
    : WHEN MATCHED (AND updateCond=booleanExpression)? THEN updateAction
    | WHEN updateCond=booleanExpression THEN updateAction
    ;

insertClause
    : WHEN NOT MATCHED (AND insertCond=booleanExpression)? THEN insertAction
    | WHEN insertCond=booleanExpression THEN insertAction
    ;

deleteAction
    : DELETE
    ;

// UPDATE SET * (star-assignment) or an explicit assignment list.
updateAction
    : UPDATE SET ASTERISK
    | UPDATE SET assignmentList
    ;

// INSERT * or INSERT (<columns>) VALUES (<expressions>).
insertAction
    : INSERT ASTERISK
    | INSERT '(' columns=qualifiedNameList ')' VALUES '(' expression (',' expression)* ')'
    ;

assignmentList
    : assignment (',' assignment)*
    ;

assignment
    : key=qualifiedName EQ value=expression
    ;

qualifiedNameList
    : qualifiedName (',' qualifiedName)*
    ;

updateTableStmt
    : UPDATE tableIdentifier SET assignmentList (WHERE where=booleanExpression)?
    ;

deleteTableStmt
    : DELETE FROM tableIdentifier (WHERE where=booleanExpression)?
    ;

// Keywords used by the rules above.
// NOTE(review): assumed not to be defined in the imported SqlBase grammar
// (redefining an imported token would shadow it) — confirm against the
// Spark 2 SqlBase.g4 this imports.
PRIMARY: 'PRIMARY';
KEY: 'KEY';
MERGE: 'MERGE';
MATCHED: 'MATCHED';
UPDATE: 'UPDATE';

View File

@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.adapter
import org.apache.hudi.Spark2RowSerDe
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Join, LogicalPlan}
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.execution.datasources.{Spark2ParsePartitionUtil, SparkParsePartitionUtil}
import org.apache.spark.sql.hudi.SparkAdapter
import org.apache.spark.sql.hudi.parser.HoodieSqlParser
import org.apache.spark.sql.internal.SQLConf
/**
* A sql adapter for spark2.
*/
class Spark2Adapter extends SparkAdapter {

  /** Creates the row SerDe backed by spark2's [[ExpressionEncoder]]. */
  override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
    new Spark2RowSerDe(encoder)
  }

  // NOTE(review): "toTableIdentify" looks like a typo of "toTableIdentifier",
  // but the name is part of the SparkAdapter interface and cannot change here.
  override def toTableIdentify(aliasId: AliasIdentifier): TableIdentifier = {
    TableIdentifier(aliasId.identifier, aliasId.database)
  }

  override def toTableIdentify(relation: UnresolvedRelation): TableIdentifier = {
    relation.tableIdentifier
  }

  /** Builds a spark2 Join node with no join condition (condition = None). */
  override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = {
    Join(left, right, joinType, None)
  }

  /** Returns true when the plan is spark2's InsertIntoTable node. */
  override def isInsertInto(plan: LogicalPlan): Boolean = {
    plan.isInstanceOf[InsertIntoTable]
  }

  /**
   * Destructures an InsertIntoTable plan into
   * (table, partition, query, overwrite, ifPartitionNotExists); None for
   * any other plan node.
   */
  override def getInsertIntoChildren(plan: LogicalPlan):
    Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = {
    plan match {
      case InsertIntoTable(table, partition, query, overwrite, ifPartitionNotExists) =>
        Some((table, partition, query, overwrite, ifPartitionNotExists))
      case _=> None
    }
  }

  /** Builds spark2's InsertIntoTable node from its components. */
  override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]],
      query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = {
    InsertIntoTable(table, partition, query, overwrite, ifPartitionNotExists)
  }

  /** Supplies the Hoodie sql parser to be injected in front of Spark's parser. */
  override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = {
    Some(
      (spark: SparkSession, delegate: ParserInterface) => new HoodieSqlParser(spark, delegate)
    )
  }

  override def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil = new Spark2ParsePartitionUtil

  /** Builds a LIKE expression (spark2's Like takes no escape-char argument). */
  override def createLike(left: Expression, right: Expression): Expression = {
    Like(left, right)
  }
}

View File

@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.Expression
// This code is just copy from v2Commands.scala in spark 3.0
/**
 * The logical plan of the DELETE FROM command: delete the rows of `table`
 * matching the optional `condition` (delete all rows when None).
 * Copied from v2Commands.scala in Spark 3.0 so it is available on Spark 2.x.
 */
case class DeleteFromTable(
    table: LogicalPlan,
    condition: Option[Expression]) extends Command {
  override def children: Seq[LogicalPlan] = Seq(table)
}

View File

@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
import org.apache.spark.sql.types.DataType
// This code is just copy from v2Commands.scala in spark 3.0
/**
* The logical plan of the MERGE INTO command that works for v2 tables.
*/
case class MergeIntoTable(
targetTable: LogicalPlan,
sourceTable: LogicalPlan,
mergeCondition: Expression,
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction]) extends Command {
override def children: Seq[LogicalPlan] = Seq(targetTable, sourceTable)
}
/**
 * Base class of the WHEN [NOT] MATCHED actions of a MERGE INTO statement.
 * An action is an unevaluable expression: it only carries the optional
 * AND-condition plus (in subclasses) the SET/VALUES assignments, and is
 * rewritten away during analysis rather than executed directly.
 */
sealed abstract class MergeAction extends Expression with Unevaluable {
  /** Optional extra predicate attached to the WHEN clause. */
  def condition: Option[Expression]
  override def foldable: Boolean = false
  override def nullable: Boolean = false
  // Fix: the UnresolvedException message must name the member actually being
  // accessed ("dataType"); the original said "nullable" (copy-paste slip).
  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
  // Only the optional condition is a child here; subclasses append assignments.
  override def children: Seq[Expression] = condition.toSeq
}
/** WHEN MATCHED [AND condition] THEN DELETE — carries only the optional condition. */
case class DeleteAction(condition: Option[Expression]) extends MergeAction
/**
 * WHEN MATCHED [AND condition] THEN UPDATE SET &lt;assignments&gt;.
 * An empty assignment list stands for `UPDATE SET *`.
 */
case class UpdateAction(
    condition: Option[Expression],
    assignments: Seq[Assignment]) extends MergeAction {

  // Condition (when present) comes first, followed by every assignment.
  override def children: Seq[Expression] =
    condition.map(Seq(_)).getOrElse(Seq.empty) ++ assignments
}
/**
 * WHEN NOT MATCHED [AND condition] THEN INSERT ... VALUES ...
 * An empty assignment list stands for `INSERT *`.
 */
case class InsertAction(
    condition: Option[Expression],
    assignments: Seq[Assignment]) extends MergeAction {

  // Optional condition first, then the column/value assignments, in order.
  override def children: Seq[Expression] = {
    val conditionChildren = condition.toList
    conditionChildren ++ assignments
  }
}
/**
 * A single `key = value` assignment used by UPDATE SET and MERGE actions.
 * Unevaluable: it is only a syntactic carrier resolved during analysis.
 *
 * @param key   the target column (typically an UnresolvedAttribute at parse time)
 * @param value the expression assigned to the column
 */
case class Assignment(key: Expression, value: Expression) extends Expression with Unevaluable {
  override def foldable: Boolean = false
  override def nullable: Boolean = false
  // Fix: report the member actually accessed ("dataType") instead of the
  // original's copy-pasted "nullable".
  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
  override def children: Seq[Expression] = key :: value :: Nil
}

View File

@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.Expression
// This code is just copy from v2Commands.scala in spark 3.0
/**
 * The logical plan of the UPDATE command, mirrored from Spark 3.0's
 * v2Commands.scala.
 *
 * @param table       the target relation being updated
 * @param assignments the SET `column = expression` pairs
 * @param condition   optional WHERE predicate; `None` updates every row
 */
case class UpdateTable(
    table: LogicalPlan,
    assignments: Seq[Assignment],
    condition: Option[Expression]
  ) extends Command {

  // The only child plan is the target table relation.
  override def children: Seq[LogicalPlan] = table :: Nil
}

View File

@@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.parser
import org.antlr.v4.runtime.tree.ParseTree
import org.apache.hudi.spark.sql.parser.HoodieSqlBaseBaseVisitor
import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface, ParserUtils}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.internal.SQLConf
import scala.collection.JavaConverters._
/**
* The AstBuilder for HoodieSqlParser to parser the AST tree to Logical Plan.
* Here we only do the parser for the extended sql syntax. e.g MergeInto. For
* other sql syntax we use the delegate sql parser which is the SparkSqlParser.
*/
class HoodieSqlAstBuilder(conf: SQLConf, delegate: ParserInterface) extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging {

  import ParserUtils._

  /** Entry point: visit the single statement and expect a LogicalPlan back. */
  override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
    ctx.statement().accept(this).asInstanceOf[LogicalPlan]
  }

  /** Thin wrapper: a MERGE INTO statement delegates to [[visitMergeInto]]. */
  override def visitMergeIntoTable (ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
    visitMergeInto(ctx.mergeInto())
  }

  /**
   * Build a [[MergeIntoTable]] plan from the MERGE INTO parse tree:
   * resolves target/source relations (the source may be a subquery parsed by
   * the delegate Spark parser), applies optional table aliases, and collects
   * the WHEN MATCHED (update/delete) and WHEN NOT MATCHED (insert) actions.
   */
  override def visitMergeInto(ctx: MergeIntoContext): LogicalPlan = withOrigin(ctx) {
    val target = UnresolvedRelation(visitTableIdentifier(ctx.target))
    // Source is either a plain table reference or a subquery; subquery text is
    // re-stringified from the tree and handed to the delegate parser.
    val source = if (ctx.source != null) {
      UnresolvedRelation(visitTableIdentifier(ctx.source))
    } else {
      val queryText = treeToString(ctx.subquery)
      delegate.parsePlan(queryText)
    }
    // tableAlias(0) belongs to the target, tableAlias(1) to the source.
    val aliasedTarget =
      if (ctx.tableAlias(0) != null) mayApplyAliasPlan(ctx.tableAlias(0), target) else target
    val aliasedSource =
      if (ctx.tableAlias(1) != null) mayApplyAliasPlan(ctx.tableAlias(1), source) else source
    val mergeCondition = expression(ctx.mergeCondition().condition)
    // At most one delete and one update clause are allowed; report the error
    // at the position of the third clause.
    if (ctx.matchedClauses().size() > 2) {
      throw new ParseException("There should be at most 2 'WHEN MATCHED' clauses.",
        ctx.matchedClauses.get(2))
    }
    // For each WHEN MATCHED clause, collect an optional DeleteAction and an
    // optional UpdateAction; delete comes before update in the result order.
    val matchedClauses: Seq[MergeAction] = ctx.matchedClauses().asScala.flatMap {
      c =>
        val deleteCtx = c.deleteClause()
        val deleteClause = if (deleteCtx != null) {
          // Optional AND-condition on the delete clause.
          val deleteCond = if (deleteCtx.deleteCond != null) {
            Some(expression(deleteCtx.deleteCond))
          } else {
            None
          }
          Some(DeleteAction(deleteCond))
        } else {
          None
        }
        val updateCtx = c.updateClause()
        val updateClause = if (updateCtx != null) {
          val updateAction = updateCtx.updateAction()
          // Optional AND-condition on the update clause.
          val updateCond = if (updateCtx.updateCond != null) {
            Some(expression(updateCtx.updateCond))
          } else {
            None
          }
          if (updateAction.ASTERISK() != null) {
            // UPDATE SET * — empty assignment list signals "all columns".
            Some(UpdateAction(updateCond, Seq.empty))
          } else {
            val assignments = withAssignments(updateAction.assignmentList())
            Some(UpdateAction(updateCond, assignments))
          }
        } else {
          None
        }
        (deleteClause ++ updateClause).toSeq
    }
    // WHEN NOT MATCHED clauses each become an InsertAction; column lists are
    // checked for duplicates and for matching arity with the VALUES list.
    val notMatchedClauses: Seq[InsertAction] = ctx.notMatchedClause().asScala.map {
      notMatchedClause =>
        val insertCtx = notMatchedClause.insertClause()
        val insertAction = insertCtx.insertAction()
        val insertCond = if (insertCtx.insertCond != null) {
          Some(expression(insertCtx.insertCond))
        } else {
          None
        }
        if (insertAction.ASTERISK() != null) {
          // INSERT * — empty assignment list signals "all columns".
          InsertAction(insertCond, Seq.empty)
        } else {
          val attrList = insertAction.columns.qualifiedName().asScala
            .map(attr => UnresolvedAttribute(visitQualifiedName(attr)))
          // Reject a column that appears twice in the INSERT column list.
          val attrSet = scala.collection.mutable.Set[UnresolvedAttribute]()
          attrList.foreach(attr => {
            if (attrSet.contains(attr)) {
              throw new ParseException(s"find duplicate field :'${attr.name}'",
                insertAction.columns)
            }
            attrSet += attr
          })
          val valueList = insertAction.expression().asScala.map(expression)
          if (attrList.size != valueList.size) {
            throw new ParseException("The columns of source and target tables are not equal: " +
              s"target: $attrList, source: $valueList", insertAction)
          }
          // Pair each column with its value expression, in declared order.
          val assignments = attrList.zip(valueList).map(kv => Assignment(kv._1, kv._2))
          InsertAction(insertCond, assignments)
        }
    }
    MergeIntoTable(aliasedTarget, aliasedSource, mergeCondition,
      matchedClauses, notMatchedClauses)
  }

  /** Convert an assignment list (`a = expr, b = expr, ...`) to [[Assignment]]s. */
  private def withAssignments(assignCtx: AssignmentListContext): Seq[Assignment] =
    withOrigin(assignCtx) {
      assignCtx.assignment().asScala.map { assign =>
        Assignment(UnresolvedAttribute(visitQualifiedName(assign.key)),
          expression(assign.value))
      }
    }

  /** Build an [[UpdateTable]] plan from an UPDATE statement. */
  override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
    val updateStmt = ctx.updateTableStmt()
    val table = UnresolvedRelation(visitTableIdentifier(updateStmt.tableIdentifier()))
    val condition = if (updateStmt.where != null) Some(expression(updateStmt.where)) else None
    val assignments = withAssignments(updateStmt.assignmentList())
    UpdateTable(table, assignments, condition)
  }

  /** Build a [[DeleteFromTable]] plan from a DELETE statement. */
  override def visitDeleteTable (ctx: DeleteTableContext): LogicalPlan = withOrigin(ctx) {
    val deleteStmt = ctx.deleteTableStmt()
    val table = UnresolvedRelation(visitTableIdentifier(deleteStmt.tableIdentifier()))
    val condition = if (deleteStmt.where != null) Some(expression(deleteStmt.where)) else None
    DeleteFromTable(table, condition)
  }

  /**
   * Convert the tree to string.
   * Leaf tokens are concatenated with single spaces, so the original spacing
   * of the input SQL is not preserved exactly.
   */
  private def treeToString(tree: ParseTree): String = {
    if (tree.getChildCount == 0) {
      tree.getText
    } else {
      (for (i <- 0 until tree.getChildCount) yield {
        treeToString(tree.getChild(i))
      }).mkString(" ")
    }
  }

  /**
   * Parse the expression tree to spark sql Expression.
   * Here we use the SparkSqlParser to do the parse.
   * NOTE(review): tokens are re-joined with single spaces before re-parsing —
   * relies on the delegate parser accepting the re-spaced text; confirm for
   * expressions containing string literals.
   */
  private def expression(tree: ParseTree): Expression = {
    val expressionText = treeToString(tree)
    delegate.parseExpression(expressionText)
  }

  // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder

  /**
   * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and
   * column aliases for a [[LogicalPlan]].
   */
  protected def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = {
    if (tableAlias.strictIdentifier != null) {
      val subquery = SubqueryAlias(tableAlias.strictIdentifier.getText, plan)
      if (tableAlias.identifierList != null) {
        // Alias also renames columns: t AS a(c1, c2, ...).
        val columnNames = visitIdentifierList(tableAlias.identifierList)
        UnresolvedSubqueryColumnAliases(columnNames, subquery)
      } else {
        subquery
      }
    } else {
      plan
    }
  }

  /**
   * Parse a qualified name to a multipart name.
   */
  override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) {
    ctx.identifier.asScala.map(_.getText)
  }

  /**
   * Create a Sequence of Strings for a parenthesis enclosed alias list.
   */
  override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
    visitIdentifierSeq(ctx.identifierSeq)
  }

  /**
   * Create a Sequence of Strings for an identifier list.
   */
  override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
    ctx.identifier.asScala.map(_.getText)
  }

  /* ********************************************************************************************
   * Table Identifier parsing
   * ******************************************************************************************** */

  /**
   * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
   */
  override def visitTableIdentifier(ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
    TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
  }
}

View File

@@ -0,0 +1,171 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hudi.parser
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
import org.antlr.v4.runtime.tree.TerminalNodeImpl
import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext}
import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, SparkSession}
/**
 * SQL parser that handles the Hudi-extended statements (MERGE INTO, UPDATE,
 * DELETE) itself and forwards everything else to the delegate Spark parser.
 */
class HoodieSqlParser(session: SparkSession, delegate: ParserInterface)
  extends ParserInterface with Logging {

  private lazy val conf = session.sqlContext.conf
  private lazy val builder = new HoodieSqlAstBuilder(conf, delegate)

  /**
   * Try the Hoodie grammar first; when the visitor does not produce a
   * LogicalPlan (statement not handled by the extension grammar), fall back
   * to the delegate parser with the original text.
   */
  override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
    builder.visit(parser.singleStatement()) match {
      case plan: LogicalPlan => plan
      case _ => delegate.parsePlan(sqlText)
    }
  }

  // All remaining ParserInterface methods are delegated unchanged.
  override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)

  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)

  override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)

  override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)

  /**
   * Run `toResult` over a freshly configured ANTLR parser for `command`.
   * Mirrors Spark's AbstractSqlParser.parse: attempt the faster SLL
   * prediction mode first and retry with full LL on ParseCancellationException,
   * rewinding the token stream before the retry.
   */
  protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = {
    logDebug(s"Parsing command: $command")
    val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
    lexer.removeErrorListeners()
    lexer.addErrorListener(ParseErrorListener)
    lexer.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced
    val tokenStream = new CommonTokenStream(lexer)
    val parser = new HoodieSqlBaseParser(tokenStream)
    parser.addParseListener(PostProcessor)
    parser.removeErrorListeners()
    parser.addErrorListener(ParseErrorListener)
    parser.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced
    try {
      try {
        // first, try parsing with potentially faster SLL mode
        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
        toResult(parser)
      }
      catch {
        case e: ParseCancellationException =>
          // if we fail, parse with LL mode
          tokenStream.seek(0) // rewind input stream
          parser.reset()
          // Try Again.
          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
          toResult(parser)
      }
    }
    catch {
      // A ParseException that already carries the command text is rethrown
      // as-is; otherwise attach the command for a better error message.
      case e: ParseException if e.command.isDefined =>
        throw e
      case e: ParseException =>
        throw e.withCommand(command)
      // Translate analysis errors raised during parsing into ParseExceptions
      // positioned at the reported line/column.
      case e: AnalysisException =>
        val position = Origin(e.line, e.startPosition)
        throw new ParseException(Option(command), e.message, position, position)
    }
  }
}
/**
* Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
*/
/**
 * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
 * Wraps a CodePointCharStream and upper-cases the characters the lexer sees
 * (via [[LA]]) so keyword matching is case-insensitive, while [[getText]]
 * still returns the original-cased input.
 */
class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
  // Pure delegation to the wrapped stream.
  override def consume(): Unit = wrapped.consume
  override def getSourceName(): String = wrapped.getSourceName
  override def index(): Int = wrapped.index
  override def mark(): Int = wrapped.mark
  override def release(marker: Int): Unit = wrapped.release(marker)
  override def seek(where: Int): Unit = wrapped.seek(where)
  override def size(): Int = wrapped.size

  override def getText(interval: Interval): String = {
    // ANTLR 4.7's CodePointCharStream implementations have bugs when
    // getText() is called with an empty stream, or intervals where
    // the start > end. See
    // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
    // that is not yet in a released ANTLR artifact.
    if (size() > 0 && (interval.b - interval.a >= 0)) {
      wrapped.getText(interval)
    } else {
      ""
    }
  }

  // scalastyle:off
  override def LA(i: Int): Int = {
  // scalastyle:on
    // Upper-case lookahead characters, but pass 0 and EOF through untouched.
    val la = wrapped.LA(i)
    if (la == 0 || la == IntStream.EOF) la
    else Character.toUpperCase(la)
  }
}
/**
* Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
*/
/**
 * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
 * Parse listener that rewrites the parse tree in place as rules exit:
 * back-quoted identifiers are unquoted and non-reserved keywords are
 * re-typed as IDENTIFIER tokens.
 */
case object PostProcessor extends HoodieSqlBaseBaseListener {

  /** Remove the back ticks from an Identifier. */
  override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
    // stripMargins = 1 drops the surrounding back ticks from the token span.
    replaceTokenByIdentifier(ctx, 1) { token =>
      // Remove the double back ticks in the string.
      token.setText(token.getText.replace("``", "`"))
      token
    }
  }

  /** Treat non-reserved keywords as Identifiers. */
  override def exitNonReserved(ctx: NonReservedContext): Unit = {
    replaceTokenByIdentifier(ctx, 0)(identity)
  }

  /**
   * Replace the last child of the parent rule with a new IDENTIFIER terminal
   * built from this context's single token, shrinking the token span by
   * `stripMargins` characters on each side and applying `f` to the new token.
   */
  private def replaceTokenByIdentifier(
      ctx: ParserRuleContext,
      stripMargins: Int)(
      f: CommonToken => CommonToken = identity): Unit = {
    val parent = ctx.getParent
    // Mutates the tree: drop the old node, append the rewritten terminal.
    parent.removeLastChild()
    val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
    val newToken = new CommonToken(
      new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
      HoodieSqlBaseParser.IDENTIFIER,
      token.getChannel,
      token.getStartIndex + stripMargins,
      token.getStopIndex - stripMargins)
    parent.addChild(new TerminalNodeImpl(f(newToken)))
  }
}

View File

@@ -183,7 +183,7 @@
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-common</artifactId>
<artifactId>hudi-spark-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
@@ -214,7 +214,7 @@
</dependency>
<dependency>
<groupId>org.apache.hudi</groupId>
<artifactId>hudi-spark-common</artifactId>
<artifactId>hudi-spark-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>

View File

@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.adapter
import org.apache.hudi.Spark3RowSerDe
import org.apache.hudi.client.utils.SparkRowSerDe
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Like}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, JoinHint, LogicalPlan}
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.execution.datasources.{Spark3ParsePartitionUtil, SparkParsePartitionUtil}
import org.apache.spark.sql.hudi.SparkAdapter
import org.apache.spark.sql.internal.SQLConf
/**
* A sql adapter for spark3.
*/
/**
 * A sql adapter for spark3.
 * Bridges version-specific Spark 3 APIs (InsertIntoStatement, JoinHint,
 * multipart identifiers, ...) behind the version-neutral SparkAdapter trait.
 */
class Spark3Adapter extends SparkAdapter {

  /** Row serde backed by the Spark 3 ExpressionEncoder API. */
  override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
    new Spark3RowSerDe(encoder)
  }

  /**
   * Convert a Spark 3 AliasIdentifier (name + multipart qualifier) to a
   * two-part TableIdentifier.
   */
  override def toTableIdentify(aliasId: AliasIdentifier): TableIdentifier = {
    aliasId match {
      case AliasIdentifier(name, Seq(database)) =>
        TableIdentifier(name, Some(database))
      // Two-part qualifier: the last element is taken as the database
      // (the first part — presumably a catalog name — is dropped; TODO confirm).
      case AliasIdentifier(name, Seq(_, database)) =>
        TableIdentifier(name, Some(database))
      case AliasIdentifier(name, Seq()) =>
        TableIdentifier(name, None)
      case _ => throw new IllegalArgumentException(s"Cannot cast $aliasId to TableIdentifier")
    }
  }

  /** Convert an UnresolvedRelation's multipart identifier to a TableIdentifier. */
  override def toTableIdentify(relation: UnresolvedRelation): TableIdentifier = {
    relation.multipartIdentifier.asTableIdentifier
  }

  /** Build a Join node; Spark 3 requires an explicit JoinHint (none here). */
  override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = {
    Join(left, right, joinType, None, JoinHint.NONE)
  }

  /** In Spark 3, INSERT INTO is represented by InsertIntoStatement. */
  override def isInsertInto(plan: LogicalPlan): Boolean = {
    plan.isInstanceOf[InsertIntoStatement]
  }

  /**
   * Destructure an InsertIntoStatement into
   * (table, partitionSpec, query, overwrite, ifPartitionNotExists);
   * returns None for any other plan.
   */
  override def getInsertIntoChildren(plan: LogicalPlan):
    Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = {
    plan match {
      case InsertIntoStatement(table, partitionSpec, query, overwrite, ifPartitionNotExists) =>
        Some((table, partitionSpec, query, overwrite, ifPartitionNotExists))
      case _ => None
    }
  }

  /** Inverse of [[getInsertIntoChildren]]: build an InsertIntoStatement. */
  override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]],
      query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = {
    InsertIntoStatement(table, partition, query, overwrite, ifPartitionNotExists)
  }

  /** Partition-path parsing helper specific to Spark 3. */
  override def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil = {
    new Spark3ParsePartitionUtil(conf)
  }

  /** Build a LIKE expression via the Spark 3 constructor. */
  override def createLike(left: Expression, right: Expression): Expression = {
    new Like(left, right)
  }
}