diff --git a/parquet-protobuf/pom.xml b/parquet-protobuf/pom.xml index ddf634a777..9f604372be 100644 --- a/parquet-protobuf/pom.xml +++ b/parquet-protobuf/pom.xml @@ -32,6 +32,7 @@ 4.4 3.25.1 + 2.28.0 1.1.5 @@ -67,6 +68,16 @@ protobuf-java ${protobuf.version} + + com.google.protobuf + protobuf-java-util + ${protobuf.version} + + + com.google.api.grpc + proto-google-common-protos + ${common-protos.version} + org.apache.parquet parquet-common @@ -191,6 +202,7 @@ com.google.protobuf:protoc:${protobuf.version} + direct test all direct diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index da51788f2a..5c17af6fe4 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -18,10 +18,21 @@ */ package org.apache.parquet.proto; +import com.google.protobuf.BoolValue; import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; import com.google.protobuf.DescriptorProtos; import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; import com.google.protobuf.Message; +import com.google.protobuf.StringValue; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.util.Timestamps; import com.twitter.elephantbird.util.Protobufs; import org.apache.hadoop.conf.Configuration; import org.apache.parquet.column.Dictionary; @@ -42,6 +53,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.time.LocalDate; +import java.time.LocalTime; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -233,6 +246,21 @@ public Optional visit(LogicalTypeAnnotation.ListLogicalTypeAnnotation public Optional visit(LogicalTypeAnnotation.MapLogicalTypeAnnotation mapLogicalType) { return of(new MapConverter(parentBuilder, fieldDescriptor, parquetType)); } + + @Override + public Optional visit(LogicalTypeAnnotation.TimestampLogicalTypeAnnotation timestampLogicalType) { + return of(new ProtoTimestampConverter(parent, timestampLogicalType)); + } + + @Override + public Optional visit(LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) { + return of(new ProtoDateConverter(parent)); + } + + @Override + public Optional visit(LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeLogicalType) { + return of(new ProtoTimeConverter(parent, timeLogicalType)); + } }).orElseGet(() -> newScalarConverter(parent, parentBuilder, fieldDescriptor, parquetType)); } @@ -250,6 +278,37 @@ protected Converter newScalarConverter(ParentValueContainer pvc, Message.Builder case INT: return new ProtoIntConverter(pvc); case LONG: return new ProtoLongConverter(pvc); case MESSAGE: { + if (parquetType.isPrimitive()) { + // if source is a Primitive type yet target is MESSAGE, it's probably a wrapped message + Descriptor messageType = fieldDescriptor.getMessageType(); + if (messageType.equals(DoubleValue.getDescriptor())) { + return new ProtoDoubleValueConverter(pvc); + } + if (messageType.equals(FloatValue.getDescriptor())) { + return new ProtoFloatValueConverter(pvc); + } + if (messageType.equals(Int64Value.getDescriptor())) { + return new ProtoInt64ValueConverter(pvc); + } + if (messageType.equals(UInt64Value.getDescriptor())) { + return new ProtoUInt64ValueConverter(pvc); + } + if (messageType.equals(Int32Value.getDescriptor())) { + return new ProtoInt32ValueConverter(pvc); + } + if (messageType.equals(UInt32Value.getDescriptor())) { + return new ProtoUInt32ValueConverter(pvc); + } + if (messageType.equals(BoolValue.getDescriptor())) { + return new ProtoBoolValueConverter(pvc); + } + if (messageType.equals(StringValue.getDescriptor())) { + return new ProtoStringValueConverter(pvc); + } + if (messageType.equals(BytesValue.getDescriptor())) { + return new ProtoBytesValueConverter(pvc); + } + } Message.Builder subBuilder = parentBuilder.newBuilderForField(fieldDescriptor); return new ProtoMessageConverter(conf, pvc, subBuilder, parquetType.asGroupType(), extraMetadata); } @@ -295,7 +354,7 @@ public ProtoEnumConverter(ParentValueContainer parent, Descriptors.FieldDescript * Fills lookup structure for translating between parquet enum values and Protocol buffer enum values. * */ private Map makeLookupStructure(Descriptors.EnumDescriptor enumType) { - Map lookupStructure = new HashMap(); + Map lookupStructure = new HashMap<>(); if (extraMetadata.containsKey(METADATA_ENUM_PREFIX + enumType.getFullName())) { String enumNameNumberPairs = extraMetadata.get(METADATA_ENUM_PREFIX + enumType.getFullName()); @@ -366,7 +425,7 @@ private Descriptors.EnumValueDescriptor translateEnumValue(Binary binaryValue) { } @Override - final public void addBinary(Binary binaryValue) { + public void addBinary(Binary binaryValue) { Descriptors.EnumValueDescriptor protoValue = translateEnumValue(binaryValue); parent.add(protoValue); } @@ -392,7 +451,7 @@ public void setDictionary(Dictionary dictionary) { } - final class ProtoBinaryConverter extends PrimitiveConverter { + static final class ProtoBinaryConverter extends PrimitiveConverter { final ParentValueContainer parent; @@ -408,7 +467,7 @@ public void addBinary(Binary binary) { } - final class ProtoBooleanConverter extends PrimitiveConverter { + static final class ProtoBooleanConverter extends PrimitiveConverter { final ParentValueContainer parent; @@ -417,13 +476,13 @@ public ProtoBooleanConverter(ParentValueContainer parent) { } @Override - final public void addBoolean(boolean value) { + public void addBoolean(boolean value) { parent.add(value); } } - final class ProtoDoubleConverter extends PrimitiveConverter { + static final class ProtoDoubleConverter extends PrimitiveConverter { final ParentValueContainer parent; @@ -437,7 +496,7 @@ public void addDouble(double value) { } } - final class ProtoFloatConverter extends PrimitiveConverter { + static final class ProtoFloatConverter extends PrimitiveConverter { final ParentValueContainer parent; @@ -451,7 +510,7 @@ public void addFloat(float value) { } } - final class ProtoIntConverter extends PrimitiveConverter { + static final class ProtoIntConverter extends PrimitiveConverter { final ParentValueContainer parent; @@ -465,7 +524,7 @@ public void addInt(int value) { } } - final class ProtoLongConverter extends PrimitiveConverter { + static final class ProtoLongConverter extends PrimitiveConverter { final ParentValueContainer parent; @@ -479,7 +538,7 @@ public void addLong(long value) { } } - final class ProtoStringConverter extends PrimitiveConverter { + static final class ProtoStringConverter extends PrimitiveConverter { final ParentValueContainer parent; @@ -495,6 +554,218 @@ public void addBinary(Binary binary) { } + static final class ProtoTimestampConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + final LogicalTypeAnnotation.TimestampLogicalTypeAnnotation logicalTypeAnnotation; + + public ProtoTimestampConverter(ParentValueContainer parent, LogicalTypeAnnotation.TimestampLogicalTypeAnnotation logicalTypeAnnotation) { + this.parent = parent; + this.logicalTypeAnnotation = logicalTypeAnnotation; + } + + @Override + public void addLong(long value) { + switch (logicalTypeAnnotation.getUnit()) { + case MICROS: + parent.add(Timestamps.fromMicros(value)); + break; + case MILLIS: + parent.add(Timestamps.fromMillis(value)); + break; + case NANOS: + parent.add(Timestamps.fromNanos(value)); + break; + } + } + } + + static final class ProtoDateConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoDateConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addInt(int value) { + LocalDate localDate = LocalDate.ofEpochDay(value); + com.google.type.Date date = com.google.type.Date.newBuilder() + .setYear(localDate.getYear()) + .setMonth(localDate.getMonthValue()) + .setDay(localDate.getDayOfMonth()) + .build(); + parent.add(date); + } + } + + static final class ProtoTimeConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + final LogicalTypeAnnotation.TimeLogicalTypeAnnotation logicalTypeAnnotation; + + public ProtoTimeConverter(ParentValueContainer parent, LogicalTypeAnnotation.TimeLogicalTypeAnnotation logicalTypeAnnotation) { + this.parent = parent; + this.logicalTypeAnnotation = logicalTypeAnnotation; + } + + @Override + public void addLong(long value) { + LocalTime localTime; + switch (logicalTypeAnnotation.getUnit()) { + case MILLIS: + localTime = LocalTime.ofNanoOfDay(value * 1_000_000); + break; + case MICROS: + localTime = LocalTime.ofNanoOfDay(value * 1_000); + break; + case NANOS: + localTime = LocalTime.ofNanoOfDay(value); + break; + default: + throw new IllegalArgumentException("Unrecognized TimeUnit: " + logicalTypeAnnotation.getUnit()); + } + com.google.type.TimeOfDay timeOfDay = com.google.type.TimeOfDay.newBuilder() + .setHours(localTime.getHour()) + .setMinutes(localTime.getMinute()) + .setSeconds(localTime.getSecond()) + .setNanos(localTime.getNano()) + .build(); + parent.add(timeOfDay); + } + } + + static final class ProtoDoubleValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoDoubleValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addDouble(double value) { + parent.add(DoubleValue.of(value)); + } + } + + static final class ProtoFloatValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoFloatValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addFloat(float value) { + parent.add(FloatValue.of(value)); + } + } + + static final class ProtoInt64ValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoInt64ValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addLong(long value) { + parent.add(Int64Value.of(value)); + } + } + + static final class ProtoUInt64ValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoUInt64ValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addLong(long value) { + parent.add(UInt64Value.of(value)); + } + } + + static final class ProtoInt32ValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoInt32ValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addInt(int value) { + parent.add(Int32Value.of(value)); + } + } + + static final class ProtoUInt32ValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoUInt32ValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addLong(long value) { + parent.add(UInt32Value.of(Math.toIntExact(value))); + } + } + + static final class ProtoBoolValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoBoolValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addBoolean(boolean value) { + parent.add(BoolValue.of(value)); + } + + } + + static final class ProtoStringValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoStringValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addBinary(Binary binary) { + String str = binary.toStringUsingUTF8(); + parent.add(StringValue.of(str)); + } + + } + + static final class ProtoBytesValueConverter extends PrimitiveConverter { + + final ParentValueContainer parent; + + public ProtoBytesValueConverter(ParentValueContainer parent) { + this.parent = parent; + } + + @Override + public void addBinary(Binary binary) { + ByteString byteString = ByteString.copyFrom(binary.toByteBuffer()); + parent.add(BytesValue.of(byteString)); + } + } + /** * This class unwraps the additional LIST wrapper and makes it possible to read the underlying data and then convert * it to protobuf. diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java index a6a779d074..83f3970c23 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java @@ -18,11 +18,24 @@ */ package org.apache.parquet.proto; +import com.google.protobuf.BoolValue; +import com.google.protobuf.BytesValue; import com.google.common.collect.ImmutableSetMultimap; import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor.JavaType; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; import com.google.protobuf.Message; +import com.google.protobuf.StringValue; +import com.google.protobuf.Timestamp; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.type.Date; +import com.google.type.TimeOfDay; import com.twitter.elephantbird.util.Protobufs; import org.apache.hadoop.conf.Configuration; import org.apache.parquet.conf.HadoopParquetConfiguration; @@ -31,6 +44,7 @@ import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; import org.apache.parquet.schema.Types.Builder; import org.apache.parquet.schema.Types.GroupBuilder; @@ -40,11 +54,20 @@ import java.util.List; import javax.annotation.Nullable; +import static org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit; +import static org.apache.parquet.schema.LogicalTypeAnnotation.dateType; import static org.apache.parquet.schema.LogicalTypeAnnotation.enumType; import static org.apache.parquet.schema.LogicalTypeAnnotation.listType; import static org.apache.parquet.schema.LogicalTypeAnnotation.mapType; import static org.apache.parquet.schema.LogicalTypeAnnotation.stringType; -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.*; +import static org.apache.parquet.schema.LogicalTypeAnnotation.timeType; +import static org.apache.parquet.schema.LogicalTypeAnnotation.timestampType; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; /** * Converts a Protocol Buffer Descriptor into a Parquet schema. @@ -55,6 +78,7 @@ public class ProtoSchemaConverter { public static final String PB_MAX_RECURSION = "parquet.proto.maxRecursion"; private final boolean parquetSpecsCompliant; + private final boolean unwrapProtoWrappers; // TODO: use proto custom options to override per field. private final int maxRecursion; @@ -76,7 +100,7 @@ public ProtoSchemaConverter() { * by the parquet specifications. If set to true, specs compliant schemas are used. */ public ProtoSchemaConverter(boolean parquetSpecsCompliant) { - this(parquetSpecsCompliant, 5); + this(parquetSpecsCompliant, 5, false); } /** @@ -98,7 +122,9 @@ public ProtoSchemaConverter(Configuration config) { public ProtoSchemaConverter(ParquetConfiguration config) { this( config.getBoolean(ProtoWriteSupport.PB_SPECS_COMPLIANT_WRITE, false), - config.getInt(PB_MAX_RECURSION, 5)); + config.getInt(PB_MAX_RECURSION, 5), + config.getBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, false) + ); } /** @@ -112,8 +138,25 @@ public ProtoSchemaConverter(ParquetConfiguration config) { * bytes instead of their actual schema. */ public ProtoSchemaConverter(boolean parquetSpecsCompliant, int maxRecursion) { + this(parquetSpecsCompliant, maxRecursion, false); + } + + /** + * Instantiate a schema converter to get the parquet schema corresponding to protobuf classes. + * + * @param parquetSpecsCompliant If set to false, the parquet schema generated will be using the old + * schema style (prior to PARQUET-968) to provide backward-compatibility + * but which does not use LIST and MAP wrappers around collections as required + * by the parquet specifications. If set to true, specs compliant schemas are used. + * @param maxRecursion The maximum recursion depth messages are allowed to go before terminating as + * bytes instead of their actual schema. + * @param unwrapProtoWrappers If set to true, unwrap common Proto wrappers like Timestamp and DoubleValue + * with corresponding OPTIONAL logical annotations. Primitive types become REQUIRED. + */ + public ProtoSchemaConverter(boolean parquetSpecsCompliant, int maxRecursion, boolean unwrapProtoWrappers) { this.parquetSpecsCompliant = parquetSpecsCompliant; this.maxRecursion = maxRecursion; + this.unwrapProtoWrappers = unwrapProtoWrappers; } /** @@ -178,6 +221,46 @@ private static Type.Repetition getRepetition(FieldDescriptor descriptor) { private Builder>, GroupBuilder> addField(FieldDescriptor descriptor, final GroupBuilder builder, ImmutableSetMultimap seen, int depth) { if (descriptor.getJavaType() == JavaType.MESSAGE) { + if (unwrapProtoWrappers) { + Descriptor messageType = descriptor.getMessageType(); + if (messageType.equals(Timestamp.getDescriptor())) { + return builder.primitive(INT64, getRepetition(descriptor)).as(timestampType(true, TimeUnit.NANOS)); + } + if (messageType.equals(Date.getDescriptor())) { + return builder.primitive(INT32, getRepetition(descriptor)).as(dateType()); + } + if (messageType.equals(TimeOfDay.getDescriptor())) { + return builder.primitive(INT64, getRepetition(descriptor)).as(timeType(true, TimeUnit.NANOS)); + } + if (messageType.equals(DoubleValue.getDescriptor())) { + return builder.primitive(DOUBLE, getRepetition(descriptor)); + } + if (messageType.equals(StringValue.getDescriptor())) { + return builder.primitive(BINARY, getRepetition(descriptor)).as(stringType()); + } + if (messageType.equals(BoolValue.getDescriptor())) { + return builder.primitive(BOOLEAN, getRepetition(descriptor)); + } + if (messageType.equals(FloatValue.getDescriptor())) { + return builder.primitive(FLOAT, getRepetition(descriptor)); + } + if (messageType.equals(Int64Value.getDescriptor())) { + return builder.primitive(INT64, getRepetition(descriptor)); + } + if (messageType.equals(UInt64Value.getDescriptor())) { + return builder.primitive(INT64, getRepetition(descriptor)); + } + if (messageType.equals(Int32Value.getDescriptor())) { + return builder.primitive(INT32, getRepetition(descriptor)); + } + if (messageType.equals(UInt32Value.getDescriptor())) { + return builder.primitive(INT64, getRepetition(descriptor)); + } + if (messageType.equals(BytesValue.getDescriptor())) { + return builder.primitive(BINARY, getRepetition(descriptor)); + } + } + return addMessageField(descriptor, builder, seen, depth); } @@ -186,8 +269,8 @@ private Builder>, GroupBuilder> addF // the old schema style did not include the LIST wrapper around repeated fields return addRepeatedPrimitive(parquetType.primitiveType, parquetType.logicalTypeAnnotation, builder); } - - return builder.primitive(parquetType.primitiveType, getRepetition(descriptor)).as(parquetType.logicalTypeAnnotation); + Repetition repetition = unwrapProtoWrappers ? Repetition.REQUIRED : getRepetition(descriptor); + return builder.primitive(parquetType.primitiveType, repetition).as(parquetType.logicalTypeAnnotation); } private static Builder>, GroupBuilder> addRepeatedPrimitive(PrimitiveTypeName primitiveType, diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index b13acd2a57..c5081e759e 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -21,6 +21,9 @@ import com.google.protobuf.*; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.util.Timestamps; +import com.google.type.Date; +import com.google.type.TimeOfDay; import com.twitter.elephantbird.util.Protobufs; import org.apache.hadoop.conf.Configuration; import org.apache.parquet.conf.HadoopParquetConfiguration; @@ -36,6 +39,8 @@ import org.slf4j.LoggerFactory; import java.lang.reflect.Array; +import java.time.LocalDate; +import java.time.LocalTime; import java.util.*; import static java.util.Optional.ofNullable; @@ -57,7 +62,10 @@ public class ProtoWriteSupport extends WriteSupport< // but is set to false by default to keep backward compatibility. public static final String PB_SPECS_COMPLIANT_WRITE = "parquet.proto.writeSpecsCompliant"; + public static final String PB_UNWRAP_PROTO_WRAPPERS = "parquet.proto.unwrapProtoWrappers"; + private boolean writeSpecsCompliant = false; + private boolean unwrapProtoWrappers = false; private RecordConsumer recordConsumer; private Class protoMessage; private Descriptor descriptor; @@ -96,6 +104,10 @@ public static void setWriteSpecsCompliant(Configuration configuration, boolean w configuration.setBoolean(PB_SPECS_COMPLIANT_WRITE, writeSpecsCompliant); } + public static void setUnwrapProtoWrappers(Configuration configuration, boolean unwrapProtoWrappers) { + configuration.setBoolean(PB_UNWRAP_PROTO_WRAPPERS, unwrapProtoWrappers); + } + /** * Writes Protocol buffer to parquet file. * @param record instance of Message.Builder or Message. @@ -144,6 +156,7 @@ public WriteContext init(ParquetConfiguration configuration) { extraMetaData.put(ProtoReadSupport.PB_CLASS, protoMessage.getName()); } + unwrapProtoWrappers = configuration.getBoolean(PB_UNWRAP_PROTO_WRAPPERS, unwrapProtoWrappers); writeSpecsCompliant = configuration.getBoolean(PB_SPECS_COMPLIANT_WRITE, writeSpecsCompliant); MessageType rootSchema = new ProtoSchemaConverter(configuration).convert(descriptor); validatedMapping(descriptor, rootSchema); @@ -152,6 +165,7 @@ public WriteContext init(ParquetConfiguration configuration) { extraMetaData.put(ProtoReadSupport.PB_DESCRIPTOR, descriptor.toProto().toString()); extraMetaData.put(PB_SPECS_COMPLIANT_WRITE, String.valueOf(writeSpecsCompliant)); + extraMetaData.put(PB_UNWRAP_PROTO_WRAPPERS, String.valueOf(unwrapProtoWrappers)); return new WriteContext(rootSchema, extraMetaData); } @@ -265,6 +279,46 @@ private FieldWriter createMessageWriter(FieldDescriptor fieldDescriptor, Type ty return createMapWriter(fieldDescriptor, type); } + if (unwrapProtoWrappers) { + Descriptor messageType = fieldDescriptor.getMessageType(); + if (messageType.equals(Timestamp.getDescriptor())) { + return new TimestampWriter(); + } + if (messageType.equals(Date.getDescriptor())) { + return new DateWriter(); + } + if (messageType.equals(TimeOfDay.getDescriptor())) { + return new TimeWriter(); + } + if (messageType.equals(DoubleValue.getDescriptor())) { + return new DoubleValueWriter(); + } + if (messageType.equals(FloatValue.getDescriptor())) { + return new FloatValueWriter(); + } + if (messageType.equals(Int64Value.getDescriptor())) { + return new Int64ValueWriter(); + } + if (messageType.equals(UInt64Value.getDescriptor())) { + return new UInt64ValueWriter(); + } + if (messageType.equals(Int32Value.getDescriptor())) { + return new Int32ValueWriter(); + } + if (messageType.equals(UInt32Value.getDescriptor())) { + return new UInt32ValueWriter(); + } + if (messageType.equals(BoolValue.getDescriptor())) { + return new BoolValueWriter(); + } + if (messageType.equals(StringValue.getDescriptor())) { + return new StringValueWriter(); + } + if (messageType.equals(BytesValue.getDescriptor())) { + return new BytesValueWriter(); + } + } + // This can happen now that recursive schemas get truncated to bytes. Write the bytes. if (type.isPrimitive() && type.asPrimitiveType().getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) { return new BinaryWriter(); @@ -584,6 +638,98 @@ final void writeRawValue(Object value) { } } + class TimestampWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + Timestamp timestamp = (Timestamp) value; + recordConsumer.addLong(Timestamps.toNanos(timestamp)); + } + } + + class DateWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + Date date = (Date) value; + LocalDate localDate = LocalDate.of(date.getYear(), date.getMonth(), date.getDay()); + recordConsumer.addInteger((int) localDate.toEpochDay()); + } + } + + class TimeWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + com.google.type.TimeOfDay timeOfDay = (com.google.type.TimeOfDay) value; + LocalTime localTime = LocalTime.of(timeOfDay.getHours(), timeOfDay.getMinutes(), timeOfDay.getSeconds(), timeOfDay.getNanos()); + recordConsumer.addLong(localTime.toNanoOfDay()); + } + } + + class DoubleValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + recordConsumer.addDouble(((DoubleValue) value).getValue()); + } + } + + class FloatValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + recordConsumer.addFloat(((FloatValue) value).getValue()); + } + } + + class Int64ValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + recordConsumer.addLong(((Int64Value) value).getValue()); + } + } + + class UInt64ValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + recordConsumer.addLong(((UInt64Value) value).getValue()); + } + } + + class Int32ValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + recordConsumer.addInteger(((Int32Value) value).getValue()); + } + } + + class UInt32ValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + recordConsumer.addLong(((UInt32Value) value).getValue()); + } + } + + class BoolValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + recordConsumer.addBoolean(((BoolValue) value).getValue()); + } + } + + class StringValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + Binary binaryString = Binary.fromString(((StringValue) value).getValue()); + recordConsumer.addBinary(binaryString); + } + } + + class BytesValueWriter extends FieldWriter { + @Override + void writeRawValue(Object value) { + byte[] byteArray = ((BytesValue) value).getValue().toByteArray(); + Binary binary = Binary.fromConstantByteArray(byteArray); + recordConsumer.addBinary(binary); + } + } + private FieldWriter unknownType(FieldDescriptor fieldDescriptor) { String exceptionMsg = "Unknown type with descriptor \"" + fieldDescriptor + "\" and type \"" + fieldDescriptor.getJavaType() + "\"."; diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java index 4debe77f8f..605c322664 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java @@ -18,8 +18,12 @@ */ package org.apache.parquet.proto; +import com.google.protobuf.BoolValue; import com.google.protobuf.ByteString; +import com.google.protobuf.DoubleValue; import com.google.protobuf.Message; +import com.google.protobuf.Timestamp; +import com.google.protobuf.util.Timestamps; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.parquet.proto.test.TestProto3; @@ -32,7 +36,9 @@ import java.util.List; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; public class ProtoInputOutputFormatTest { @@ -607,6 +613,49 @@ public void testProto3RepeatedMessages() throws Exception { assertTrue(third.getTwo().isEmpty()); } + @Test + public void testProto3TimestampMessageClass() throws Exception { + Timestamp timestamp = Timestamps.parse("2021-05-02T15:04:03.748Z"); + TestProto3.DateTimeMessage msgEmpty = TestProto3.DateTimeMessage.newBuilder().build(); + TestProto3.DateTimeMessage msgNonEmpty = TestProto3.DateTimeMessage.newBuilder() + .setTimestamp(timestamp) + .build(); + + Configuration conf = new Configuration(); + conf.setBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, true); + Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); + ReadUsingMR readUsingMR = new ReadUsingMR(); + String customClass = TestProto3.DateTimeMessage.class.getName(); + ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass); + List result = readUsingMR.read(outputPath); + + assertEquals(2, result.size()); + assertEquals(msgEmpty, result.get(0)); + assertEquals(msgNonEmpty, result.get(1)); + } + + @Test + public void testProto3WrappedMessageClass() throws Exception { + TestProto3.WrappedMessage msgEmpty = TestProto3.WrappedMessage.newBuilder().build(); + TestProto3.WrappedMessage msgNonEmpty = TestProto3.WrappedMessage.newBuilder() + .setWrappedDouble(DoubleValue.of(0.577)) + .setWrappedBool(BoolValue.of(true)) + .build(); + + + Configuration conf = new Configuration(); + conf.setBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, true); + Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); + ReadUsingMR readUsingMR = new ReadUsingMR(); + String customClass = TestProto3.WrappedMessage.class.getName(); + ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass); + List result = readUsingMR.read(outputPath); + + assertEquals(2, result.size()); + assertEquals(msgEmpty, result.get(0)); + assertEquals(msgNonEmpty, result.get(1)); + } + /** * Runs job that writes input to file and then job reading data back. */ diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java index 8c64197c33..2871590025 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java @@ -23,7 +23,6 @@ import com.google.protobuf.Struct; import com.google.protobuf.Value; import org.junit.Test; -import org.apache.parquet.proto.TestUtils; import org.apache.parquet.proto.test.TestProto3; import org.apache.parquet.proto.test.TestProtobuf; import org.apache.parquet.proto.test.Trees; @@ -40,12 +39,12 @@ public class ProtoSchemaConverterTest { /** * Converts given pbClass to parquet schema and compares it with expected parquet schema. */ - private static void testConversion(Class pbClass, String parquetSchemaString, boolean parquetSpecsCompliant) { - testConversion(pbClass, parquetSchemaString, new ProtoSchemaConverter(parquetSpecsCompliant)); + private static void testConversion(Class pbClass, String parquetSchemaString, boolean parquetSpecsCompliant, boolean unwrapWrappers) { + testConversion(pbClass, parquetSchemaString, new ProtoSchemaConverter(parquetSpecsCompliant, 5, unwrapWrappers)); } private static void testConversion(Class pbClass, String parquetSchemaString) { - testConversion(pbClass, parquetSchemaString, true); + testConversion(pbClass, parquetSchemaString, true, false); } private static void testConversion(Class pbClass, String parquetSchemaString, ProtoSchemaConverter converter) { @@ -54,6 +53,9 @@ private static void testConversion(Class pbClass, String parq assertEquals(expectedMT.toString(), schema.toString()); } + private void testConversion(Class pbClass, String parquetSchemaString, boolean parquetSpecsCompliant) throws Exception { + testConversion(pbClass, parquetSchemaString, parquetSpecsCompliant, false); + } /** * Tests that all protocol buffer datatypes are converted to correct parquet datatypes. @@ -206,7 +208,7 @@ public void testConvertRepeatedIntMessageNonSpecsCompliant() { " repeated int32 repeatedInt = 1;", "}"); - testConversion(TestProtobuf.RepeatedIntMessage.class, expectedSchema, false); + testConversion(TestProtobuf.RepeatedIntMessage.class, expectedSchema, false, false); } @Test @@ -231,7 +233,7 @@ public void testProto3ConvertRepeatedIntMessageNonSpecsCompliant() { " repeated int32 repeatedInt = 1;", "}"); - testConversion(TestProto3.RepeatedIntMessage.class, expectedSchema, false); + testConversion(TestProto3.RepeatedIntMessage.class, expectedSchema, false, false); } @Test @@ -263,7 +265,7 @@ public void testConvertRepeatedInnerMessageNonSpecsCompliant() { " }", "}"); - testConversion(TestProtobuf.RepeatedInnerMessage.class, expectedSchema, false); + testConversion(TestProtobuf.RepeatedInnerMessage.class, expectedSchema, false, false); } @Test @@ -295,7 +297,7 @@ public void testProto3ConvertRepeatedInnerMessageNonSpecsCompliant() { " }", "}"); - testConversion(TestProto3.RepeatedInnerMessage.class, expectedSchema, false); + testConversion(TestProto3.RepeatedInnerMessage.class, expectedSchema, false, false); } @Test @@ -323,7 +325,7 @@ public void testConvertMapIntMessageNonSpecsCompliant() { " }", "}"); - testConversion(TestProtobuf.MapIntMessage.class, expectedSchema, false); + testConversion(TestProtobuf.MapIntMessage.class, expectedSchema, false, false); } @Test @@ -351,7 +353,61 @@ public void testProto3ConvertMapIntMessageNonSpecsCompliant() { " }", "}"); - testConversion(TestProto3.MapIntMessage.class, expectedSchema, false); + testConversion(TestProto3.MapIntMessage.class, expectedSchema, false, false); + } + + @Test + public void testProto3ConvertDateTimeMessageWrapped() throws Exception { + String expectedSchema = + "message TestProto3.DateTimeMessage {\n" + + " optional group timestamp = 1 {\n" + + " optional int64 seconds = 1;\n" + + " optional int32 nanos = 2;\n" + + " }\n" + + " optional group date = 2 {\n" + + " optional int32 year = 1;\n" + + " optional int32 month = 2;\n" + + " optional int32 day = 3;\n" + + " }\n" + + " optional group time = 3 {\n" + + " optional int32 hours = 1;\n" + + " optional int32 minutes = 2;\n" + + " optional int32 seconds = 3;\n" + + " optional int32 nanos = 4;\n" + + " }\n" + + "}"; + + testConversion(TestProto3.DateTimeMessage.class, expectedSchema, false, false); + } + + @Test + public void testProto3ConvertDateTimeMessageUnwrapped() throws Exception { + String expectedSchema = + "message TestProto3.DateTimeMessage {\n" + + " optional int64 timestamp (TIMESTAMP(NANOS,true)) = 1;\n" + + " optional int32 date (DATE) = 2;\n" + + " optional int64 time (TIME(NANOS,true)) = 3;\n" + + "}"; + + testConversion(TestProto3.DateTimeMessage.class, expectedSchema, false, true); + } + + @Test + public void testProto3ConvertWrappedMessageUnwrapped() throws Exception { + String expectedSchema = + "message TestProto3.WrappedMessage {\n" + + " optional double wrappedDouble = 1;\n" + + " optional float wrappedFloat = 2;\n" + + " optional int64 wrappedInt64 = 3;\n" + + " optional int64 wrappedUInt64 = 4;\n" + + " optional int32 wrappedInt32 = 5;\n" + + " optional int64 wrappedUInt32 = 6;\n" + + " optional boolean wrappedBool = 7;\n" + + " optional binary wrappedString (UTF8) = 8;\n" + + " optional binary wrappedBytes = 9;\n" + + "}"; + + testConversion(TestProto3.WrappedMessage.class, expectedSchema, false, true); } @Test @@ -379,8 +435,8 @@ public void testBinaryTreeRecursion() throws Exception { " optional binary right = 3;", " }", "}"); - testConversion(Trees.BinaryTree.class, expectedSchema, new ProtoSchemaConverter(true, 1)); - testConversion(Trees.BinaryTree.class, TestUtils.readResource("BinaryTree.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH)); + testConversion(Trees.BinaryTree.class, expectedSchema, new ProtoSchemaConverter(true, 1, false)); + testConversion(Trees.BinaryTree.class, TestUtils.readResource("BinaryTree.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false)); } @@ -404,8 +460,8 @@ public void testWideTreeRecursion() throws Exception { " }", " }", "}"); - testConversion(Trees.WideTree.class, expectedSchema, new ProtoSchemaConverter(true, 1)); - testConversion(Trees.WideTree.class, TestUtils.readResource("WideTree.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH)); + testConversion(Trees.WideTree.class, expectedSchema, new ProtoSchemaConverter(true, 1, false)); + testConversion(Trees.WideTree.class, TestUtils.readResource("WideTree.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false)); } @@ -455,8 +511,8 @@ public void testValueRecursion() throws Exception { " }", " }", "}"); - testConversion(Value.class, expectedSchema, new ProtoSchemaConverter(true, 1)); - testConversion(Value.class, TestUtils.readResource("Value.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH)); + testConversion(Value.class, expectedSchema, new ProtoSchemaConverter(true, 1, false)); + testConversion(Value.class, TestUtils.readResource("Value.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false)); } @Test @@ -510,8 +566,8 @@ public void testStructRecursion() throws Exception { " }", " }", "}"); - testConversion(Struct.class, expectedSchema, new ProtoSchemaConverter(true, 1)); - testConversion(Struct.class, TestUtils.readResource("Struct.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH)); + testConversion(Struct.class, expectedSchema, new ProtoSchemaConverter(true, 1, false)); + testConversion(Struct.class, TestUtils.readResource("Struct.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false)); } @Test @@ -521,16 +577,16 @@ public void testDeepRecursion() { long expectedBinaryTreeSize = 4; long expectedStructSize = 7; for (int i = 0; i < 10; ++i) { - MessageType deepSchema = new ProtoSchemaConverter(true, i).convert(Trees.WideTree.class); + MessageType deepSchema = new ProtoSchemaConverter(true, i, false).convert(Trees.WideTree.class); // 3, 5, 7, 9, 11, 13, 15, 17, 19, 21 assertEquals(2 * i + 3, deepSchema.getPaths().size()); - deepSchema = new ProtoSchemaConverter(true, i).convert(Trees.BinaryTree.class); + deepSchema = new ProtoSchemaConverter(true, i, false).convert(Trees.BinaryTree.class); // 4, 10, 22, 46, 94, 190, 382, 766, 1534, 3070 assertEquals(expectedBinaryTreeSize, deepSchema.getPaths().size()); expectedBinaryTreeSize = 2 * expectedBinaryTreeSize + 2; - deepSchema = new ProtoSchemaConverter(true, i).convert(Struct.class); + deepSchema = new ProtoSchemaConverter(true, i, false).convert(Struct.class); // 7, 18, 40, 84, 172, 348, 700, 1404, 2812, 5628 assertEquals(expectedStructSize, deepSchema.getPaths().size()); expectedStructSize = 2 * expectedStructSize + 4; diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java index ef0356d01b..c4c34c9000 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java @@ -18,25 +18,44 @@ */ package org.apache.parquet.proto; +import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; import com.google.protobuf.Descriptors; import com.google.protobuf.DynamicMessage; import com.google.protobuf.Message; +import com.google.protobuf.MessageOrBuilder; +import com.google.protobuf.StringValue; +import com.google.protobuf.Timestamp; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.util.Timestamps; import com.google.protobuf.Value; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; -import org.junit.Test; -import static org.junit.Assert.*; -import org.mockito.InOrder; -import org.mockito.Mockito; +import org.apache.parquet.hadoop.ParquetWriter; import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.RecordConsumer; import org.apache.parquet.proto.test.TestProto3; import org.apache.parquet.proto.test.TestProtobuf; +import org.junit.Test; +import org.mockito.InOrder; +import org.mockito.Mockito; import org.apache.parquet.proto.test.Trees; import java.io.IOException; +import java.time.LocalDate; +import java.time.LocalTime; import java.util.List; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + public class ProtoWriteSupportTest { private ProtoWriteSupport createReadConsumerInstance(Class cls, RecordConsumer readConsumerMock) { @@ -1201,4 +1220,180 @@ public void testMapRecursion() { Mockito.verifyNoMoreInteractions(readConsumerMock); } + + @Test + public void testProto3DateTimeMessageUnwrapped() throws Exception { + Timestamp timestamp = Timestamps.parse("2021-05-02T15:04:03.748Z"); + LocalDate date = LocalDate.of(2021, 5, 2); + LocalTime time = LocalTime.of(15, 4, 3, 748_000_000); + + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setUnwrapProtoWrappers(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance( + TestProto3.DateTimeMessage.class, readConsumerMock, conf); + + TestProto3.DateTimeMessage.Builder msg = TestProto3.DateTimeMessage.newBuilder(); + msg.setTimestamp(timestamp); + msg.setDate(com.google.type.Date.newBuilder() + .setYear(date.getYear()) + .setMonth(date.getMonthValue()) + .setDay(date.getDayOfMonth()) + ); + msg.setTime(com.google.type.TimeOfDay.newBuilder() + .setHours(time.getHour()) + .setMinutes(time.getMinute()) + .setSeconds(time.getSecond()) + .setNanos(time.getNano()) + ); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("timestamp", 0); + inOrder.verify(readConsumerMock).addLong(Timestamps.toNanos(timestamp)); + inOrder.verify(readConsumerMock).endField("timestamp", 0); + inOrder.verify(readConsumerMock).startField("date", 1); + inOrder.verify(readConsumerMock).addInteger((int) date.toEpochDay()); + inOrder.verify(readConsumerMock).endField("date", 1); + inOrder.verify(readConsumerMock).startField("time", 2); + inOrder.verify(readConsumerMock).addLong(time.toNanoOfDay()); + inOrder.verify(readConsumerMock).endField("time", 2); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3DateTimeMessageRoundTrip() throws Exception { + Timestamp timestamp = Timestamps.parse("2021-05-02T15:04:03.748Z"); + LocalDate date = LocalDate.of(2021, 5, 2); + LocalTime time = LocalTime.of(15, 4, 3, 748_000_000); + com.google.type.Date protoDate = com.google.type.Date.newBuilder() + .setYear(date.getYear()) + .setMonth(date.getMonthValue()) + .setDay(date.getDayOfMonth()) + .build(); + com.google.type.TimeOfDay protoTime = com.google.type.TimeOfDay.newBuilder() + .setHours(time.getHour()) + .setMinutes(time.getMinute()) + .setSeconds(time.getSecond()) + .setNanos(time.getNano()) + .build(); + + TestProto3.DateTimeMessage msg = TestProto3.DateTimeMessage.newBuilder() + .setTimestamp(timestamp) + .setDate(protoDate) + .setTime(protoTime) + .build(); + + //Write them out and read them back + Path tmpFilePath = TestUtils.someTemporaryFilePath(); + ParquetWriter writer = + ProtoParquetWriter.builder(tmpFilePath) + .withMessage(TestProto3.DateTimeMessage.class) + .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true") + .build(); + writer.write(msg); + writer.close(); + List gotBack = TestUtils.readMessages(tmpFilePath, TestProto3.DateTimeMessage.class); + + TestProto3.DateTimeMessage gotBackFirst = gotBack.get(0); + assertEquals(timestamp, gotBackFirst.getTimestamp()); + assertEquals(protoDate, gotBackFirst.getDate()); + assertEquals(protoTime, gotBackFirst.getTime()); + } + + @Test + public void testProto3WrappedMessageUnwrapped() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setUnwrapProtoWrappers(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance( + TestProto3.WrappedMessage.class, readConsumerMock, conf); + + TestProto3.WrappedMessage.Builder msg = TestProto3.WrappedMessage.newBuilder(); + msg.setWrappedDouble(DoubleValue.of(3.1415)); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("wrappedDouble", 0); + inOrder.verify(readConsumerMock).addDouble(3.1415); + inOrder.verify(readConsumerMock).endField("wrappedDouble", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3WrappedMessageUnwrappedRoundTrip() throws Exception { + TestProto3.WrappedMessage.Builder msg = TestProto3.WrappedMessage.newBuilder(); + msg.setWrappedDouble(DoubleValue.of(0.577)); + msg.setWrappedFloat(FloatValue.of(3.1415f)); + msg.setWrappedInt64(Int64Value.of(1_000_000_000L * 4)); + msg.setWrappedUInt64(UInt64Value.of(1_000_000_000L * 9)); + msg.setWrappedInt32(Int32Value.of(1_000_000 * 3)); + msg.setWrappedUInt32(UInt32Value.of(1_000_000 * 8)); + msg.setWrappedBool(BoolValue.of(true)); + msg.setWrappedString(StringValue.of("Good Will Hunting")); + msg.setWrappedBytes(BytesValue.of(ByteString.copyFrom("someText", "UTF-8"))); + + //Write them out and read them back + Path tmpFilePath = TestUtils.someTemporaryFilePath(); + ParquetWriter writer = + ProtoParquetWriter.builder(tmpFilePath) + .withMessage(TestProto3.WrappedMessage.class) + .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true") + .build(); + writer.write(msg); + writer.close(); + List gotBack = TestUtils.readMessages(tmpFilePath, TestProto3.WrappedMessage.class); + + TestProto3.WrappedMessage gotBackFirst = gotBack.get(0); + assertEquals(0.577, gotBackFirst.getWrappedDouble().getValue(), 1e-5); + assertEquals(3.1415f, gotBackFirst.getWrappedFloat().getValue(), 1e-5f); + assertEquals(1_000_000_000L * 4, gotBackFirst.getWrappedInt64().getValue()); + assertEquals(1_000_000_000L * 9, gotBackFirst.getWrappedUInt64().getValue()); + assertEquals(1_000_000 * 3, gotBackFirst.getWrappedInt32().getValue()); + assertEquals(1_000_000 * 8, gotBackFirst.getWrappedUInt32().getValue()); + assertEquals(BoolValue.of(true), gotBackFirst.getWrappedBool()); + assertEquals("Good Will Hunting", gotBackFirst.getWrappedString().getValue()); + assertEquals(ByteString.copyFrom("someText", "UTF-8"), gotBackFirst.getWrappedBytes().getValue()); + } + + @Test + public void testProto3WrappedMessageWithNullsRoundTrip() throws Exception { + TestProto3.WrappedMessage.Builder msg = TestProto3.WrappedMessage.newBuilder(); + msg.setWrappedFloat(FloatValue.of(3.1415f)); + msg.setWrappedString(StringValue.of("Good Will Hunting")); + msg.setWrappedInt32(Int32Value.of(0)); + + //Write them out and read them back + Path tmpFilePath = TestUtils.someTemporaryFilePath(); + ParquetWriter writer = + ProtoParquetWriter.builder(tmpFilePath) + .withMessage(TestProto3.WrappedMessage.class) + .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true") + .build(); + writer.write(msg); + writer.close(); + List gotBack = TestUtils.readMessages(tmpFilePath, TestProto3.WrappedMessage.class); + + TestProto3.WrappedMessage gotBackFirst = gotBack.get(0); + assertFalse(gotBackFirst.hasWrappedDouble()); + assertEquals(3.1415f, gotBackFirst.getWrappedFloat().getValue(), 1e-5f); + + // double-check that nulls are honored + assertTrue(gotBackFirst.hasWrappedFloat()); + assertFalse(gotBackFirst.hasWrappedInt64()); + assertFalse(gotBackFirst.hasWrappedUInt64()); + assertTrue(gotBackFirst.hasWrappedInt32()); + assertFalse(gotBackFirst.hasWrappedUInt32()); + assertEquals(0, gotBackFirst.getWrappedUInt32().getValue()); + assertFalse(gotBackFirst.hasWrappedBool()); + assertEquals("Good Will Hunting", gotBackFirst.getWrappedString().getValue()); + assertFalse(gotBackFirst.hasWrappedBytes()); + } } diff --git a/parquet-protobuf/src/test/resources/TestProto3.proto b/parquet-protobuf/src/test/resources/TestProto3.proto index fb4da1b0c1..c303fd1f5d 100644 --- a/parquet-protobuf/src/test/resources/TestProto3.proto +++ b/parquet-protobuf/src/test/resources/TestProto3.proto @@ -23,6 +23,11 @@ package TestProto3; option java_package = "org.apache.parquet.proto.test"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; +import "google/type/date.proto"; +import "google/type/timeofday.proto"; + // original Dremel paper structures: Original paper used groups, not internal // messages but groups were deprecated. @@ -156,3 +161,21 @@ message FirstCustomClassMessage { message SecondCustomClassMessage { string string = 11; } + +message DateTimeMessage { + google.protobuf.Timestamp timestamp = 1; + google.type.Date date = 2; + google.type.TimeOfDay time = 3; +} + +message WrappedMessage { + google.protobuf.DoubleValue wrappedDouble = 1; + google.protobuf.FloatValue wrappedFloat = 2; + google.protobuf.Int64Value wrappedInt64 = 3; + google.protobuf.UInt64Value wrappedUInt64 = 4; + google.protobuf.Int32Value wrappedInt32 = 5; + google.protobuf.UInt32Value wrappedUInt32 = 6; + google.protobuf.BoolValue wrappedBool = 7; + google.protobuf.StringValue wrappedString = 8; + google.protobuf.BytesValue wrappedBytes = 9; +}