[HUDI-3083] Support component data types for flink bulk_insert (#4470)
* [HUDI-3083] Support component data types for flink bulk_insert * add nested row type test
This commit is contained in:
@@ -18,19 +18,22 @@
|
||||
|
||||
package org.apache.hudi.io.storage.row.parquet;
|
||||
|
||||
import org.apache.flink.table.data.ArrayData;
|
||||
import org.apache.flink.table.data.DecimalDataUtils;
|
||||
import org.apache.flink.table.data.MapData;
|
||||
import org.apache.flink.table.data.RowData;
|
||||
import org.apache.flink.table.data.TimestampData;
|
||||
import org.apache.flink.table.types.logical.ArrayType;
|
||||
import org.apache.flink.table.types.logical.DecimalType;
|
||||
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
|
||||
import org.apache.flink.table.types.logical.LogicalType;
|
||||
import org.apache.flink.table.types.logical.MapType;
|
||||
import org.apache.flink.table.types.logical.RowType;
|
||||
import org.apache.flink.table.types.logical.TimestampType;
|
||||
import org.apache.flink.util.Preconditions;
|
||||
import org.apache.parquet.io.api.Binary;
|
||||
import org.apache.parquet.io.api.RecordConsumer;
|
||||
import org.apache.parquet.schema.GroupType;
|
||||
import org.apache.parquet.schema.Type;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
@@ -46,7 +49,8 @@ import static org.apache.flink.formats.parquet.vector.reader.TimestampColumnRead
|
||||
/**
|
||||
* Writes a record to the Parquet API with the expected schema in order to be written to a file.
|
||||
*
|
||||
* <p>Reference org.apache.flink.formats.parquet.row.ParquetRowDataWriter to support timestamp of INT64 8 bytes.
|
||||
* <p>Reference {@code org.apache.flink.formats.parquet.row.ParquetRowDataWriter}
|
||||
* to support timestamp of INT64 8 bytes and complex data types.
|
||||
*/
|
||||
public class ParquetRowDataWriter {
|
||||
|
||||
@@ -67,7 +71,7 @@ public class ParquetRowDataWriter {
|
||||
this.filedWriters = new FieldWriter[rowType.getFieldCount()];
|
||||
this.fieldNames = rowType.getFieldNames().toArray(new String[0]);
|
||||
for (int i = 0; i < rowType.getFieldCount(); i++) {
|
||||
this.filedWriters[i] = createWriter(rowType.getTypeAt(i), schema.getType(i));
|
||||
this.filedWriters[i] = createWriter(rowType.getTypeAt(i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,59 +95,75 @@ public class ParquetRowDataWriter {
|
||||
recordConsumer.endMessage();
|
||||
}
|
||||
|
||||
private FieldWriter createWriter(LogicalType t, Type type) {
|
||||
if (type.isPrimitive()) {
|
||||
switch (t.getTypeRoot()) {
|
||||
case CHAR:
|
||||
case VARCHAR:
|
||||
return new StringWriter();
|
||||
case BOOLEAN:
|
||||
return new BooleanWriter();
|
||||
case BINARY:
|
||||
case VARBINARY:
|
||||
return new BinaryWriter();
|
||||
case DECIMAL:
|
||||
DecimalType decimalType = (DecimalType) t;
|
||||
return createDecimalWriter(decimalType.getPrecision(), decimalType.getScale());
|
||||
case TINYINT:
|
||||
return new ByteWriter();
|
||||
case SMALLINT:
|
||||
return new ShortWriter();
|
||||
case DATE:
|
||||
case TIME_WITHOUT_TIME_ZONE:
|
||||
case INTEGER:
|
||||
return new IntWriter();
|
||||
case BIGINT:
|
||||
return new LongWriter();
|
||||
case FLOAT:
|
||||
return new FloatWriter();
|
||||
case DOUBLE:
|
||||
return new DoubleWriter();
|
||||
case TIMESTAMP_WITHOUT_TIME_ZONE:
|
||||
TimestampType timestampType = (TimestampType) t;
|
||||
if (timestampType.getPrecision() == 3) {
|
||||
return new Timestamp64Writer();
|
||||
} else {
|
||||
return new Timestamp96Writer(timestampType.getPrecision());
|
||||
}
|
||||
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
|
||||
LocalZonedTimestampType localZonedTimestampType = (LocalZonedTimestampType) t;
|
||||
if (localZonedTimestampType.getPrecision() == 3) {
|
||||
return new Timestamp64Writer();
|
||||
} else {
|
||||
return new Timestamp96Writer(localZonedTimestampType.getPrecision());
|
||||
}
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unsupported type: " + type);
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unsupported data type: " + t);
|
||||
private FieldWriter createWriter(LogicalType t) {
|
||||
switch (t.getTypeRoot()) {
|
||||
case CHAR:
|
||||
case VARCHAR:
|
||||
return new StringWriter();
|
||||
case BOOLEAN:
|
||||
return new BooleanWriter();
|
||||
case BINARY:
|
||||
case VARBINARY:
|
||||
return new BinaryWriter();
|
||||
case DECIMAL:
|
||||
DecimalType decimalType = (DecimalType) t;
|
||||
return createDecimalWriter(decimalType.getPrecision(), decimalType.getScale());
|
||||
case TINYINT:
|
||||
return new ByteWriter();
|
||||
case SMALLINT:
|
||||
return new ShortWriter();
|
||||
case DATE:
|
||||
case TIME_WITHOUT_TIME_ZONE:
|
||||
case INTEGER:
|
||||
return new IntWriter();
|
||||
case BIGINT:
|
||||
return new LongWriter();
|
||||
case FLOAT:
|
||||
return new FloatWriter();
|
||||
case DOUBLE:
|
||||
return new DoubleWriter();
|
||||
case TIMESTAMP_WITHOUT_TIME_ZONE:
|
||||
TimestampType timestampType = (TimestampType) t;
|
||||
if (timestampType.getPrecision() == 3) {
|
||||
return new Timestamp64Writer();
|
||||
} else {
|
||||
return new Timestamp96Writer(timestampType.getPrecision());
|
||||
}
|
||||
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
|
||||
LocalZonedTimestampType localZonedTimestampType = (LocalZonedTimestampType) t;
|
||||
if (localZonedTimestampType.getPrecision() == 3) {
|
||||
return new Timestamp64Writer();
|
||||
} else {
|
||||
return new Timestamp96Writer(localZonedTimestampType.getPrecision());
|
||||
}
|
||||
case ARRAY:
|
||||
ArrayType arrayType = (ArrayType) t;
|
||||
LogicalType elementType = arrayType.getElementType();
|
||||
FieldWriter elementWriter = createWriter(elementType);
|
||||
return new ArrayWriter(elementWriter);
|
||||
case MAP:
|
||||
MapType mapType = (MapType) t;
|
||||
LogicalType keyType = mapType.getKeyType();
|
||||
LogicalType valueType = mapType.getValueType();
|
||||
FieldWriter keyWriter = createWriter(keyType);
|
||||
FieldWriter valueWriter = createWriter(valueType);
|
||||
return new MapWriter(keyWriter, valueWriter);
|
||||
case ROW:
|
||||
RowType rowType = (RowType) t;
|
||||
FieldWriter[] fieldWriters = rowType.getFields().stream()
|
||||
.map(RowType.RowField::getType).map(this::createWriter).toArray(FieldWriter[]::new);
|
||||
String[] fieldNames = rowType.getFields().stream()
|
||||
.map(RowType.RowField::getName).toArray(String[]::new);
|
||||
return new RowWriter(fieldNames, fieldWriters);
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unsupported type: " + t);
|
||||
}
|
||||
}
|
||||
|
||||
private interface FieldWriter {
|
||||
|
||||
void write(RowData row, int ordinal);
|
||||
|
||||
void write(ArrayData array, int ordinal);
|
||||
}
|
||||
|
||||
private class BooleanWriter implements FieldWriter {
|
||||
@@ -152,6 +172,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addBoolean(row.getBoolean(ordinal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addBoolean(array.getBoolean(ordinal));
|
||||
}
|
||||
}
|
||||
|
||||
private class ByteWriter implements FieldWriter {
|
||||
@@ -160,6 +185,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addInteger(row.getByte(ordinal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addInteger(array.getByte(ordinal));
|
||||
}
|
||||
}
|
||||
|
||||
private class ShortWriter implements FieldWriter {
|
||||
@@ -168,6 +198,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addInteger(row.getShort(ordinal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addInteger(array.getShort(ordinal));
|
||||
}
|
||||
}
|
||||
|
||||
private class LongWriter implements FieldWriter {
|
||||
@@ -176,6 +211,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addLong(row.getLong(ordinal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addLong(array.getLong(ordinal));
|
||||
}
|
||||
}
|
||||
|
||||
private class FloatWriter implements FieldWriter {
|
||||
@@ -184,6 +224,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addFloat(row.getFloat(ordinal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addFloat(array.getFloat(ordinal));
|
||||
}
|
||||
}
|
||||
|
||||
private class DoubleWriter implements FieldWriter {
|
||||
@@ -192,6 +237,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addDouble(row.getDouble(ordinal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addDouble(array.getDouble(ordinal));
|
||||
}
|
||||
}
|
||||
|
||||
private class StringWriter implements FieldWriter {
|
||||
@@ -200,6 +250,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addBinary(Binary.fromReusedByteArray(row.getString(ordinal).toBytes()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addBinary(Binary.fromReusedByteArray(array.getString(ordinal).toBytes()));
|
||||
}
|
||||
}
|
||||
|
||||
private class BinaryWriter implements FieldWriter {
|
||||
@@ -208,6 +263,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addBinary(Binary.fromReusedByteArray(array.getBinary(ordinal)));
|
||||
}
|
||||
}
|
||||
|
||||
private class IntWriter implements FieldWriter {
|
||||
@@ -216,6 +276,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addInteger(row.getInt(ordinal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addInteger(array.getInt(ordinal));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -231,6 +296,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addLong(timestampToInt64(row.getTimestamp(ordinal, 3)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addLong(timestampToInt64(array.getTimestamp(ordinal, 3)));
|
||||
}
|
||||
}
|
||||
|
||||
private long timestampToInt64(TimestampData timestampData) {
|
||||
@@ -254,6 +324,11 @@ public class ParquetRowDataWriter {
|
||||
public void write(RowData row, int ordinal) {
|
||||
recordConsumer.addBinary(timestampToInt96(row.getTimestamp(ordinal, precision)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
recordConsumer.addBinary(timestampToInt96(array.getTimestamp(ordinal, precision)));
|
||||
}
|
||||
}
|
||||
|
||||
private Binary timestampToInt96(TimestampData timestampData) {
|
||||
@@ -304,10 +379,20 @@ public class ParquetRowDataWriter {
|
||||
@Override
|
||||
public void write(RowData row, int ordinal) {
|
||||
long unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong();
|
||||
doWrite(unscaledLong);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
long unscaledLong = array.getDecimal(ordinal, precision, scale).toUnscaledLong();
|
||||
doWrite(unscaledLong);
|
||||
}
|
||||
|
||||
private void doWrite(long unscaled) {
|
||||
int i = 0;
|
||||
int shift = initShift;
|
||||
while (i < numBytes) {
|
||||
decimalBuffer[i] = (byte) (unscaledLong >> shift);
|
||||
decimalBuffer[i] = (byte) (unscaled >> shift);
|
||||
i += 1;
|
||||
shift -= 8;
|
||||
}
|
||||
@@ -328,6 +413,16 @@ public class ParquetRowDataWriter {
|
||||
@Override
|
||||
public void write(RowData row, int ordinal) {
|
||||
byte[] bytes = row.getDecimal(ordinal, precision, scale).toUnscaledBytes();
|
||||
doWrite(bytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
byte[] bytes = array.getDecimal(ordinal, precision, scale).toUnscaledBytes();
|
||||
doWrite(bytes);
|
||||
}
|
||||
|
||||
private void doWrite(byte[] bytes) {
|
||||
byte[] writtenBytes;
|
||||
if (bytes.length == numBytes) {
|
||||
// Avoid copy.
|
||||
@@ -353,5 +448,132 @@ public class ParquetRowDataWriter {
|
||||
// 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY
|
||||
return new UnscaledBytesWriter();
|
||||
}
|
||||
|
||||
private class ArrayWriter implements FieldWriter {
|
||||
private final FieldWriter elementWriter;
|
||||
|
||||
private ArrayWriter(FieldWriter elementWriter) {
|
||||
this.elementWriter = elementWriter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(RowData row, int ordinal) {
|
||||
ArrayData arrayData = row.getArray(ordinal);
|
||||
doWrite(arrayData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
ArrayData arrayData = array.getArray(ordinal);
|
||||
doWrite(arrayData);
|
||||
}
|
||||
|
||||
private void doWrite(ArrayData arrayData) {
|
||||
recordConsumer.startGroup();
|
||||
if (arrayData.size() > 0) {
|
||||
final String repeatedGroup = "list";
|
||||
final String elementField = "element";
|
||||
recordConsumer.startField(repeatedGroup, 0);
|
||||
for (int i = 0; i < arrayData.size(); i++) {
|
||||
recordConsumer.startGroup();
|
||||
if (!arrayData.isNullAt(i)) {
|
||||
// Only creates the element field if the current array element is not null.
|
||||
recordConsumer.startField(elementField, 0);
|
||||
elementWriter.write(arrayData, i);
|
||||
recordConsumer.endField(elementField, 0);
|
||||
}
|
||||
recordConsumer.endGroup();
|
||||
}
|
||||
recordConsumer.endField(repeatedGroup, 0);
|
||||
}
|
||||
recordConsumer.endGroup();
|
||||
}
|
||||
}
|
||||
|
||||
private class MapWriter implements FieldWriter {
|
||||
private final FieldWriter keyWriter;
|
||||
private final FieldWriter valueWriter;
|
||||
|
||||
private MapWriter(FieldWriter keyWriter, FieldWriter valueWriter) {
|
||||
this.keyWriter = keyWriter;
|
||||
this.valueWriter = valueWriter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(RowData row, int ordinal) {
|
||||
MapData map = row.getMap(ordinal);
|
||||
doWrite(map);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
MapData map = array.getMap(ordinal);
|
||||
doWrite(map);
|
||||
}
|
||||
|
||||
private void doWrite(MapData mapData) {
|
||||
ArrayData keyArray = mapData.keyArray();
|
||||
ArrayData valueArray = mapData.valueArray();
|
||||
recordConsumer.startGroup();
|
||||
if (mapData.size() > 0) {
|
||||
final String repeatedGroup = "key_value";
|
||||
final String kField = "key";
|
||||
final String vField = "value";
|
||||
recordConsumer.startField(repeatedGroup, 0);
|
||||
for (int i = 0; i < mapData.size(); i++) {
|
||||
recordConsumer.startGroup();
|
||||
// key
|
||||
recordConsumer.startField(kField, 0);
|
||||
this.keyWriter.write(keyArray, i);
|
||||
recordConsumer.endField(kField, 0);
|
||||
// value
|
||||
if (!valueArray.isNullAt(i)) {
|
||||
// Only creates the "value" field if the value if non-empty
|
||||
recordConsumer.startField(vField, 1);
|
||||
this.valueWriter.write(valueArray, i);
|
||||
recordConsumer.endField(vField, 1);
|
||||
}
|
||||
recordConsumer.endGroup();
|
||||
}
|
||||
recordConsumer.endField(repeatedGroup, 0);
|
||||
}
|
||||
recordConsumer.endGroup();
|
||||
}
|
||||
}
|
||||
|
||||
private class RowWriter implements FieldWriter {
|
||||
private final String[] fieldNames;
|
||||
private final FieldWriter[] fieldWriters;
|
||||
|
||||
private RowWriter(String[] fieldNames, FieldWriter[] fieldWriters) {
|
||||
this.fieldNames = fieldNames;
|
||||
this.fieldWriters = fieldWriters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(RowData row, int ordinal) {
|
||||
RowData nested = row.getRow(ordinal, fieldWriters.length);
|
||||
doWrite(nested);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ArrayData array, int ordinal) {
|
||||
RowData nested = array.getRow(ordinal, fieldWriters.length);
|
||||
doWrite(nested);
|
||||
}
|
||||
|
||||
private void doWrite(RowData row) {
|
||||
recordConsumer.startGroup();
|
||||
for (int i = 0; i < row.getArity(); i++) {
|
||||
if (!row.isNullAt(i)) {
|
||||
String fieldName = fieldNames[i];
|
||||
recordConsumer.startField(fieldName, i);
|
||||
fieldWriters[i].write(row, i);
|
||||
recordConsumer.endField(fieldName, i);
|
||||
}
|
||||
}
|
||||
recordConsumer.endGroup();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,9 +25,11 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
|
||||
import org.apache.flink.api.java.typeutils.MapTypeInfo;
|
||||
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
|
||||
import org.apache.flink.api.java.typeutils.RowTypeInfo;
|
||||
import org.apache.flink.table.types.logical.ArrayType;
|
||||
import org.apache.flink.table.types.logical.DecimalType;
|
||||
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
|
||||
import org.apache.flink.table.types.logical.LogicalType;
|
||||
import org.apache.flink.table.types.logical.MapType;
|
||||
import org.apache.flink.table.types.logical.RowType;
|
||||
|
||||
import org.apache.flink.table.types.logical.TimestampType;
|
||||
@@ -616,6 +618,45 @@ public class ParquetSchemaConverter {
|
||||
return Types.primitive(PrimitiveType.PrimitiveTypeName.INT96, repetition)
|
||||
.named(name);
|
||||
}
|
||||
case ARRAY:
|
||||
// <list-repetition> group <name> (LIST) {
|
||||
// repeated group list {
|
||||
// <element-repetition> <element-type> element;
|
||||
// }
|
||||
// }
|
||||
ArrayType arrayType = (ArrayType) type;
|
||||
LogicalType elementType = arrayType.getElementType();
|
||||
return Types
|
||||
.buildGroup(repetition).as(OriginalType.LIST)
|
||||
.addField(
|
||||
Types.repeatedGroup()
|
||||
.addField(convertToParquetType("element", elementType, repetition))
|
||||
.named("list"))
|
||||
.named(name);
|
||||
case MAP:
|
||||
// <map-repetition> group <name> (MAP) {
|
||||
// repeated group key_value {
|
||||
// required <key-type> key;
|
||||
// <value-repetition> <value-type> value;
|
||||
// }
|
||||
// }
|
||||
MapType mapType = (MapType) type;
|
||||
LogicalType keyType = mapType.getKeyType();
|
||||
LogicalType valueType = mapType.getValueType();
|
||||
return Types
|
||||
.buildGroup(repetition).as(OriginalType.MAP)
|
||||
.addField(
|
||||
Types
|
||||
.repeatedGroup()
|
||||
.addField(convertToParquetType("key", keyType, repetition))
|
||||
.addField(convertToParquetType("value", valueType, repetition))
|
||||
.named("key_value"))
|
||||
.named(name);
|
||||
case ROW:
|
||||
RowType rowType = (RowType) type;
|
||||
Types.GroupBuilder<GroupType> builder = Types.buildGroup(repetition);
|
||||
rowType.getFields().forEach(field -> builder.addField(convertToParquetType(field.getName(), field.getType(), repetition)));
|
||||
return builder.named(name);
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unsupported type: " + type);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user