[HUDI-3083] Support complex data types for flink bulk_insert (#4470)

* [HUDI-3083] Support complex data types for flink bulk_insert

* add nested row type test
Author: Ron
Date: 2021-12-30 11:15:54 +08:00
Committed by: GitHub
Parent: 5c0e4ce005
Commit: 674c149234
24 changed files with 3031 additions and 75 deletions
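
As a rough illustration of what this change enables (the example is not part of the patch; all field names are hypothetical), the Flink bulk_insert write path can now serialize row types that nest ARRAY, MAP and ROW fields:

import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.MapType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.VarCharType;

// A schema mixing primitives with complex fields; before this patch the
// bulk_insert Parquet writer rejected the non-primitive columns.
RowType rowType = RowType.of(
    new LogicalType[] {
        new IntType(),                                    // id
        new ArrayType(new VarCharType(10)),               // tags: ARRAY<VARCHAR(10)>
        new MapType(new VarCharType(20), new IntType()),  // attrs: MAP<VARCHAR(20), INT>
        RowType.of(new IntType(), new VarCharType(10))},  // nested: ROW<INT, VARCHAR(10)>
    new String[] {"id", "tags", "attrs", "nested"});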

File: ParquetRowDataWriter.java

@@ -18,19 +18,22 @@
package org.apache.hudi.io.storage.row.parquet;
+import org.apache.flink.table.data.ArrayData;
import org.apache.flink.table.data.DecimalDataUtils;
+import org.apache.flink.table.data.MapData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.TimestampData;
+import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
+import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.TimestampType;
import org.apache.flink.util.Preconditions;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
-import org.apache.parquet.schema.GroupType;
-import org.apache.parquet.schema.Type;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -46,7 +49,8 @@ import static org.apache.flink.formats.parquet.vector.reader.TimestampColumnReader
/**
* Writes a record to the Parquet API with the expected schema in order to be written to a file.
*
-* <p>Reference org.apache.flink.formats.parquet.row.ParquetRowDataWriter to support timestamp of INT64 8 bytes.
+* <p>Reference {@code org.apache.flink.formats.parquet.row.ParquetRowDataWriter}
+* to support timestamp of INT64 8 bytes and complex data types.
*/
public class ParquetRowDataWriter {
@@ -67,7 +71,7 @@ public class ParquetRowDataWriter {
this.filedWriters = new FieldWriter[rowType.getFieldCount()];
this.fieldNames = rowType.getFieldNames().toArray(new String[0]);
for (int i = 0; i < rowType.getFieldCount(); i++) {
-this.filedWriters[i] = createWriter(rowType.getTypeAt(i), schema.getType(i));
+this.filedWriters[i] = createWriter(rowType.getTypeAt(i));
}
}
@@ -91,59 +95,75 @@ public class ParquetRowDataWriter {
recordConsumer.endMessage();
}
-private FieldWriter createWriter(LogicalType t, Type type) {
-if (type.isPrimitive()) {
-switch (t.getTypeRoot()) {
-case CHAR:
-case VARCHAR:
-return new StringWriter();
-case BOOLEAN:
-return new BooleanWriter();
-case BINARY:
-case VARBINARY:
-return new BinaryWriter();
-case DECIMAL:
-DecimalType decimalType = (DecimalType) t;
-return createDecimalWriter(decimalType.getPrecision(), decimalType.getScale());
-case TINYINT:
-return new ByteWriter();
-case SMALLINT:
-return new ShortWriter();
-case DATE:
-case TIME_WITHOUT_TIME_ZONE:
-case INTEGER:
-return new IntWriter();
-case BIGINT:
-return new LongWriter();
-case FLOAT:
-return new FloatWriter();
-case DOUBLE:
-return new DoubleWriter();
-case TIMESTAMP_WITHOUT_TIME_ZONE:
-TimestampType timestampType = (TimestampType) t;
-if (timestampType.getPrecision() == 3) {
-return new Timestamp64Writer();
-} else {
-return new Timestamp96Writer(timestampType.getPrecision());
-}
-case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
-LocalZonedTimestampType localZonedTimestampType = (LocalZonedTimestampType) t;
-if (localZonedTimestampType.getPrecision() == 3) {
-return new Timestamp64Writer();
-} else {
-return new Timestamp96Writer(localZonedTimestampType.getPrecision());
-}
-default:
-throw new UnsupportedOperationException("Unsupported type: " + type);
-}
-} else {
-throw new IllegalArgumentException("Unsupported data type: " + t);
-}
-}
+private FieldWriter createWriter(LogicalType t) {
+switch (t.getTypeRoot()) {
+case CHAR:
+case VARCHAR:
+return new StringWriter();
+case BOOLEAN:
+return new BooleanWriter();
+case BINARY:
+case VARBINARY:
+return new BinaryWriter();
+case DECIMAL:
+DecimalType decimalType = (DecimalType) t;
+return createDecimalWriter(decimalType.getPrecision(), decimalType.getScale());
+case TINYINT:
+return new ByteWriter();
+case SMALLINT:
+return new ShortWriter();
+case DATE:
+case TIME_WITHOUT_TIME_ZONE:
+case INTEGER:
+return new IntWriter();
+case BIGINT:
+return new LongWriter();
+case FLOAT:
+return new FloatWriter();
+case DOUBLE:
+return new DoubleWriter();
+case TIMESTAMP_WITHOUT_TIME_ZONE:
+TimestampType timestampType = (TimestampType) t;
+if (timestampType.getPrecision() == 3) {
+return new Timestamp64Writer();
+} else {
+return new Timestamp96Writer(timestampType.getPrecision());
+}
+case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
+LocalZonedTimestampType localZonedTimestampType = (LocalZonedTimestampType) t;
+if (localZonedTimestampType.getPrecision() == 3) {
+return new Timestamp64Writer();
+} else {
+return new Timestamp96Writer(localZonedTimestampType.getPrecision());
+}
+case ARRAY:
+ArrayType arrayType = (ArrayType) t;
+LogicalType elementType = arrayType.getElementType();
+FieldWriter elementWriter = createWriter(elementType);
+return new ArrayWriter(elementWriter);
+case MAP:
+MapType mapType = (MapType) t;
+LogicalType keyType = mapType.getKeyType();
+LogicalType valueType = mapType.getValueType();
+FieldWriter keyWriter = createWriter(keyType);
+FieldWriter valueWriter = createWriter(valueType);
+return new MapWriter(keyWriter, valueWriter);
+case ROW:
+RowType rowType = (RowType) t;
+FieldWriter[] fieldWriters = rowType.getFields().stream()
+.map(RowType.RowField::getType).map(this::createWriter).toArray(FieldWriter[]::new);
+String[] fieldNames = rowType.getFields().stream()
+.map(RowType.RowField::getName).toArray(String[]::new);
+return new RowWriter(fieldNames, fieldWriters);
+default:
+throw new UnsupportedOperationException("Unsupported type: " + t);
+}
+}
private interface FieldWriter {
void write(RowData row, int ordinal);
+void write(ArrayData array, int ordinal);
}
private class BooleanWriter implements FieldWriter {
@@ -152,6 +172,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addBoolean(row.getBoolean(ordinal));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addBoolean(array.getBoolean(ordinal));
}
}
private class ByteWriter implements FieldWriter {
@@ -160,6 +185,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addInteger(row.getByte(ordinal));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addInteger(array.getByte(ordinal));
}
}
private class ShortWriter implements FieldWriter {
@@ -168,6 +198,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addInteger(row.getShort(ordinal));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addInteger(array.getShort(ordinal));
}
}
private class LongWriter implements FieldWriter {
@@ -176,6 +211,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addLong(row.getLong(ordinal));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addLong(array.getLong(ordinal));
}
}
private class FloatWriter implements FieldWriter {
@@ -184,6 +224,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addFloat(row.getFloat(ordinal));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addFloat(array.getFloat(ordinal));
}
}
private class DoubleWriter implements FieldWriter {
@@ -192,6 +237,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addDouble(row.getDouble(ordinal));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addDouble(array.getDouble(ordinal));
}
}
private class StringWriter implements FieldWriter {
@@ -200,6 +250,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addBinary(Binary.fromReusedByteArray(row.getString(ordinal).toBytes()));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addBinary(Binary.fromReusedByteArray(array.getString(ordinal).toBytes()));
}
}
private class BinaryWriter implements FieldWriter {
@@ -208,6 +263,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addBinary(Binary.fromReusedByteArray(array.getBinary(ordinal)));
}
}
private class IntWriter implements FieldWriter {
@@ -216,6 +276,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addInteger(row.getInt(ordinal));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addInteger(array.getInt(ordinal));
}
}
/**
@@ -231,6 +296,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addLong(timestampToInt64(row.getTimestamp(ordinal, 3)));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addLong(timestampToInt64(array.getTimestamp(ordinal, 3)));
}
}
private long timestampToInt64(TimestampData timestampData) {
@@ -254,6 +324,11 @@ public class ParquetRowDataWriter {
public void write(RowData row, int ordinal) {
recordConsumer.addBinary(timestampToInt96(row.getTimestamp(ordinal, precision)));
}
@Override
public void write(ArrayData array, int ordinal) {
recordConsumer.addBinary(timestampToInt96(array.getTimestamp(ordinal, precision)));
}
}
private Binary timestampToInt96(TimestampData timestampData) {
@@ -304,10 +379,20 @@ public class ParquetRowDataWriter {
@Override
public void write(RowData row, int ordinal) {
long unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong();
doWrite(unscaledLong);
}
@Override
public void write(ArrayData array, int ordinal) {
long unscaledLong = array.getDecimal(ordinal, precision, scale).toUnscaledLong();
doWrite(unscaledLong);
}
private void doWrite(long unscaled) {
int i = 0;
int shift = initShift;
while (i < numBytes) {
-decimalBuffer[i] = (byte) (unscaledLong >> shift);
+decimalBuffer[i] = (byte) (unscaled >> shift);
i += 1;
shift -= 8;
}
@@ -328,6 +413,16 @@ public class ParquetRowDataWriter {
@Override
public void write(RowData row, int ordinal) {
byte[] bytes = row.getDecimal(ordinal, precision, scale).toUnscaledBytes();
doWrite(bytes);
}
@Override
public void write(ArrayData array, int ordinal) {
byte[] bytes = array.getDecimal(ordinal, precision, scale).toUnscaledBytes();
doWrite(bytes);
}
private void doWrite(byte[] bytes) {
byte[] writtenBytes;
if (bytes.length == numBytes) {
// Avoid copy.
@@ -353,5 +448,132 @@ public class ParquetRowDataWriter {
// 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY
return new UnscaledBytesWriter();
}
private class ArrayWriter implements FieldWriter {
private final FieldWriter elementWriter;
private ArrayWriter(FieldWriter elementWriter) {
this.elementWriter = elementWriter;
}
@Override
public void write(RowData row, int ordinal) {
ArrayData arrayData = row.getArray(ordinal);
doWrite(arrayData);
}
@Override
public void write(ArrayData array, int ordinal) {
ArrayData arrayData = array.getArray(ordinal);
doWrite(arrayData);
}
private void doWrite(ArrayData arrayData) {
recordConsumer.startGroup();
if (arrayData.size() > 0) {
final String repeatedGroup = "list";
final String elementField = "element";
recordConsumer.startField(repeatedGroup, 0);
for (int i = 0; i < arrayData.size(); i++) {
recordConsumer.startGroup();
if (!arrayData.isNullAt(i)) {
// Only creates the element field if the current array element is not null.
recordConsumer.startField(elementField, 0);
elementWriter.write(arrayData, i);
recordConsumer.endField(elementField, 0);
}
recordConsumer.endGroup();
}
recordConsumer.endField(repeatedGroup, 0);
}
recordConsumer.endGroup();
}
}
private class MapWriter implements FieldWriter {
private final FieldWriter keyWriter;
private final FieldWriter valueWriter;
private MapWriter(FieldWriter keyWriter, FieldWriter valueWriter) {
this.keyWriter = keyWriter;
this.valueWriter = valueWriter;
}
@Override
public void write(RowData row, int ordinal) {
MapData map = row.getMap(ordinal);
doWrite(map);
}
@Override
public void write(ArrayData array, int ordinal) {
MapData map = array.getMap(ordinal);
doWrite(map);
}
private void doWrite(MapData mapData) {
ArrayData keyArray = mapData.keyArray();
ArrayData valueArray = mapData.valueArray();
recordConsumer.startGroup();
if (mapData.size() > 0) {
final String repeatedGroup = "key_value";
final String kField = "key";
final String vField = "value";
recordConsumer.startField(repeatedGroup, 0);
for (int i = 0; i < mapData.size(); i++) {
recordConsumer.startGroup();
// key
recordConsumer.startField(kField, 0);
this.keyWriter.write(keyArray, i);
recordConsumer.endField(kField, 0);
// value
if (!valueArray.isNullAt(i)) {
// Only creates the "value" field if the value is not null.
recordConsumer.startField(vField, 1);
this.valueWriter.write(valueArray, i);
recordConsumer.endField(vField, 1);
}
recordConsumer.endGroup();
}
recordConsumer.endField(repeatedGroup, 0);
}
recordConsumer.endGroup();
}
}
private class RowWriter implements FieldWriter {
private final String[] fieldNames;
private final FieldWriter[] fieldWriters;
private RowWriter(String[] fieldNames, FieldWriter[] fieldWriters) {
this.fieldNames = fieldNames;
this.fieldWriters = fieldWriters;
}
@Override
public void write(RowData row, int ordinal) {
RowData nested = row.getRow(ordinal, fieldWriters.length);
doWrite(nested);
}
@Override
public void write(ArrayData array, int ordinal) {
RowData nested = array.getRow(ordinal, fieldWriters.length);
doWrite(nested);
}
private void doWrite(RowData row) {
recordConsumer.startGroup();
for (int i = 0; i < row.getArity(); i++) {
if (!row.isNullAt(i)) {
String fieldName = fieldNames[i];
recordConsumer.startField(fieldName, i);
fieldWriters[i].write(row, i);
recordConsumer.endField(fieldName, i);
}
}
recordConsumer.endGroup();
}
}
}
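
To make the shredding concrete, here is a hand-traced sketch (illustrative only, not part of the patch) of the RecordConsumer call sequence that ArrayWriter.doWrite emits for an ARRAY<STRING> value of ["a", null] under the three-level LIST layout:

// Trace of ArrayWriter.doWrite for ["a", null]; StringWriter actually calls
// Binary.fromReusedByteArray(...), Binary.fromString is used here for brevity.
recordConsumer.startGroup();                       // outer group of the LIST field
recordConsumer.startField("list", 0);              // repeated group "list"
recordConsumer.startGroup();                       // entry for element 0
recordConsumer.startField("element", 0);
recordConsumer.addBinary(Binary.fromString("a"));
recordConsumer.endField("element", 0);
recordConsumer.endGroup();
recordConsumer.startGroup();                       // entry for element 1: the value is
recordConsumer.endGroup();                         // null, so "element" is omitted entirely
recordConsumer.endField("list", 0);
recordConsumer.endGroup();

MapWriter follows the same pattern with a repeated "key_value" group, always writing the key and omitting the "value" field for null values.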

File: ParquetSchemaConverter.java

@@ -25,9 +25,11 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.MapTypeInfo;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.TimestampType;
@@ -616,6 +618,45 @@ public class ParquetSchemaConverter {
return Types.primitive(PrimitiveType.PrimitiveTypeName.INT96, repetition)
.named(name);
}
case ARRAY:
// <list-repetition> group <name> (LIST) {
// repeated group list {
// <element-repetition> <element-type> element;
// }
// }
ArrayType arrayType = (ArrayType) type;
LogicalType elementType = arrayType.getElementType();
return Types
.buildGroup(repetition).as(OriginalType.LIST)
.addField(
Types.repeatedGroup()
.addField(convertToParquetType("element", elementType, repetition))
.named("list"))
.named(name);
case MAP:
// <map-repetition> group <name> (MAP) {
// repeated group key_value {
// required <key-type> key;
// <value-repetition> <value-type> value;
// }
// }
MapType mapType = (MapType) type;
LogicalType keyType = mapType.getKeyType();
LogicalType valueType = mapType.getValueType();
return Types
.buildGroup(repetition).as(OriginalType.MAP)
.addField(
Types
.repeatedGroup()
.addField(convertToParquetType("key", keyType, repetition))
.addField(convertToParquetType("value", valueType, repetition))
.named("key_value"))
.named(name);
case ROW:
RowType rowType = (RowType) type;
Types.GroupBuilder<GroupType> builder = Types.buildGroup(repetition);
rowType.getFields().forEach(field -> builder.addField(convertToParquetType(field.getName(), field.getType(), repetition)));
return builder.named(name);
default:
throw new UnsupportedOperationException("Unsupported type: " + type);
}
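
For reference, assuming OPTIONAL repetition, the ARRAY branch above would convert a hypothetical field scores of type ARRAY<INT> into roughly the following three-level Parquet group (note that the element inherits the repetition passed in for the outer field):

optional group scores (LIST) {
  repeated group list {
    optional int32 element;
  }
}

The MAP branch produces the analogous repeated "key_value" group annotated with MAP, and the ROW branch simply emits one Parquet field per Flink row field.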