From ec6267c30328521621e6e803dfd6aaa65ae489f7 Mon Sep 17 00:00:00 2001 From: lw0090 Date: Sun, 18 Oct 2020 17:18:50 +0800 Subject: [PATCH] [HUDI-307] add test to check timestamp date decimal type write and read consistent (#2177) --- .../hudi/functional/TestCOWDataSource.scala | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala b/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala index c1e45c359..5b746deb0 100644 --- a/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala +++ b/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala @@ -17,13 +17,15 @@ package org.apache.hudi.functional +import java.sql.{Date, Timestamp} + import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings import org.apache.hudi.config.HoodieWriteConfig import org.apache.hudi.testutils.HoodieClientTestBase import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers} -import org.apache.log4j.LogManager import org.apache.spark.sql._ import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{DataTypes, DateType, IntegerType, StringType, StructField, StructType, TimestampType} import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} @@ -32,8 +34,8 @@ import scala.collection.JavaConversions._ /** * Basic tests on the spark datasource for COW table. */ + class TestCOWDataSource extends HoodieClientTestBase { - private val log = LogManager.getLogger(getClass) var spark: SparkSession = null val commonOpts = Map( "hoodie.insert.shuffle.parallelism" -> "4", @@ -194,4 +196,36 @@ class TestCOWDataSource extends HoodieClientTestBase { .load(basePath) assertEquals(hoodieIncViewDF2.count(), insert2NewKeyCnt) } + + @Test def testComplexDataTypeWriteAndReadConsistency(): Unit = { + val schema = StructType(StructField("_row_key", StringType, true) :: StructField("name", StringType, true) + :: StructField("timeStampValue", TimestampType, true) :: StructField("dateValue", DateType, true) + :: StructField("decimalValue", DataTypes.createDecimalType(15, 10), true) :: StructField("timestamp", IntegerType, true) + :: StructField("partition", IntegerType, true) :: Nil) + + val records = Seq(Row("11", "Andy", Timestamp.valueOf("1970-01-01 13:31:24"), Date.valueOf("1991-11-07"), BigDecimal.valueOf(1.0), 11, 1), + Row("22", "lisi", Timestamp.valueOf("1970-01-02 13:31:24"), Date.valueOf("1991-11-08"), BigDecimal.valueOf(2.0), 11, 1), + Row("33", "zhangsan", Timestamp.valueOf("1970-01-03 13:31:24"), Date.valueOf("1991-11-09"), BigDecimal.valueOf(3.0), 11, 1)) + val rdd = jsc.parallelize(records) + val recordsDF = spark.createDataFrame(rdd, schema) + recordsDF.write.format("org.apache.hudi") + .options(commonOpts) + .mode(SaveMode.Append) + .save(basePath) + + val recordsReadDF = spark.read.format("org.apache.hudi") + .load(basePath + "/*/*") + recordsReadDF.printSchema() + recordsReadDF.schema.foreach(f => { + f.name match { + case "timeStampValue" => + assertEquals(f.dataType, org.apache.spark.sql.types.TimestampType) + case "dateValue" => + assertEquals(f.dataType, org.apache.spark.sql.types.DateType) + case "decimalValue" => + assertEquals(f.dataType, org.apache.spark.sql.types.DecimalType(15, 10)) + case _ => + } + }) + } }