diff --git a/parquet-protobuf/pom.xml b/parquet-protobuf/pom.xml
index ddf634a777..9f604372be 100644
--- a/parquet-protobuf/pom.xml
+++ b/parquet-protobuf/pom.xml
@@ -32,6 +32,7 @@
4.4
3.25.1
+ 2.28.0
1.1.5
@@ -67,6 +68,16 @@
protobuf-java
${protobuf.version}
+
+ com.google.protobuf
+ protobuf-java-util
+ ${protobuf.version}
+
+
+ com.google.api.grpc
+ proto-google-common-protos
+ ${common-protos.version}
+
org.apache.parquet
parquet-common
@@ -191,6 +202,7 @@
com.google.protobuf:protoc:${protobuf.version}
+ direct
test
all
direct
diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java
index da51788f2a..5c17af6fe4 100644
--- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java
+++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java
@@ -18,10 +18,21 @@
*/
package org.apache.parquet.proto;
+import com.google.protobuf.BoolValue;
import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.DoubleValue;
+import com.google.protobuf.FloatValue;
+import com.google.protobuf.Int32Value;
+import com.google.protobuf.Int64Value;
import com.google.protobuf.Message;
+import com.google.protobuf.StringValue;
+import com.google.protobuf.UInt32Value;
+import com.google.protobuf.UInt64Value;
+import com.google.protobuf.util.Timestamps;
import com.twitter.elephantbird.util.Protobufs;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.column.Dictionary;
@@ -42,6 +53,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.time.LocalDate;
+import java.time.LocalTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -233,6 +246,21 @@ public Optional visit(LogicalTypeAnnotation.ListLogicalTypeAnnotation
public Optional visit(LogicalTypeAnnotation.MapLogicalTypeAnnotation mapLogicalType) {
return of(new MapConverter(parentBuilder, fieldDescriptor, parquetType));
}
+
+ @Override
+ public Optional visit(LogicalTypeAnnotation.TimestampLogicalTypeAnnotation timestampLogicalType) {
+ return of(new ProtoTimestampConverter(parent, timestampLogicalType));
+ }
+
+ @Override
+ public Optional visit(LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) {
+ return of(new ProtoDateConverter(parent));
+ }
+
+ @Override
+ public Optional visit(LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeLogicalType) {
+ return of(new ProtoTimeConverter(parent, timeLogicalType));
+ }
}).orElseGet(() -> newScalarConverter(parent, parentBuilder, fieldDescriptor, parquetType));
}
@@ -250,6 +278,37 @@ protected Converter newScalarConverter(ParentValueContainer pvc, Message.Builder
case INT: return new ProtoIntConverter(pvc);
case LONG: return new ProtoLongConverter(pvc);
case MESSAGE: {
+ if (parquetType.isPrimitive()) {
+ // if source is a Primitive type yet target is MESSAGE, it's probably a wrapped message
+ Descriptor messageType = fieldDescriptor.getMessageType();
+ if (messageType.equals(DoubleValue.getDescriptor())) {
+ return new ProtoDoubleValueConverter(pvc);
+ }
+ if (messageType.equals(FloatValue.getDescriptor())) {
+ return new ProtoFloatValueConverter(pvc);
+ }
+ if (messageType.equals(Int64Value.getDescriptor())) {
+ return new ProtoInt64ValueConverter(pvc);
+ }
+ if (messageType.equals(UInt64Value.getDescriptor())) {
+ return new ProtoUInt64ValueConverter(pvc);
+ }
+ if (messageType.equals(Int32Value.getDescriptor())) {
+ return new ProtoInt32ValueConverter(pvc);
+ }
+ if (messageType.equals(UInt32Value.getDescriptor())) {
+ return new ProtoUInt32ValueConverter(pvc);
+ }
+ if (messageType.equals(BoolValue.getDescriptor())) {
+ return new ProtoBoolValueConverter(pvc);
+ }
+ if (messageType.equals(StringValue.getDescriptor())) {
+ return new ProtoStringValueConverter(pvc);
+ }
+ if (messageType.equals(BytesValue.getDescriptor())) {
+ return new ProtoBytesValueConverter(pvc);
+ }
+ }
Message.Builder subBuilder = parentBuilder.newBuilderForField(fieldDescriptor);
return new ProtoMessageConverter(conf, pvc, subBuilder, parquetType.asGroupType(), extraMetadata);
}
@@ -295,7 +354,7 @@ public ProtoEnumConverter(ParentValueContainer parent, Descriptors.FieldDescript
* Fills lookup structure for translating between parquet enum values and Protocol buffer enum values.
* */
private Map makeLookupStructure(Descriptors.EnumDescriptor enumType) {
- Map lookupStructure = new HashMap();
+ Map lookupStructure = new HashMap<>();
if (extraMetadata.containsKey(METADATA_ENUM_PREFIX + enumType.getFullName())) {
String enumNameNumberPairs = extraMetadata.get(METADATA_ENUM_PREFIX + enumType.getFullName());
@@ -366,7 +425,7 @@ private Descriptors.EnumValueDescriptor translateEnumValue(Binary binaryValue) {
}
@Override
- final public void addBinary(Binary binaryValue) {
+ public void addBinary(Binary binaryValue) {
Descriptors.EnumValueDescriptor protoValue = translateEnumValue(binaryValue);
parent.add(protoValue);
}
@@ -392,7 +451,7 @@ public void setDictionary(Dictionary dictionary) {
}
- final class ProtoBinaryConverter extends PrimitiveConverter {
+ static final class ProtoBinaryConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -408,7 +467,7 @@ public void addBinary(Binary binary) {
}
- final class ProtoBooleanConverter extends PrimitiveConverter {
+ static final class ProtoBooleanConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -417,13 +476,13 @@ public ProtoBooleanConverter(ParentValueContainer parent) {
}
@Override
- final public void addBoolean(boolean value) {
+ public void addBoolean(boolean value) {
parent.add(value);
}
}
- final class ProtoDoubleConverter extends PrimitiveConverter {
+ static final class ProtoDoubleConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -437,7 +496,7 @@ public void addDouble(double value) {
}
}
- final class ProtoFloatConverter extends PrimitiveConverter {
+ static final class ProtoFloatConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -451,7 +510,7 @@ public void addFloat(float value) {
}
}
- final class ProtoIntConverter extends PrimitiveConverter {
+ static final class ProtoIntConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -465,7 +524,7 @@ public void addInt(int value) {
}
}
- final class ProtoLongConverter extends PrimitiveConverter {
+ static final class ProtoLongConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -479,7 +538,7 @@ public void addLong(long value) {
}
}
- final class ProtoStringConverter extends PrimitiveConverter {
+ static final class ProtoStringConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -495,6 +554,218 @@ public void addBinary(Binary binary) {
}
+ static final class ProtoTimestampConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+ final LogicalTypeAnnotation.TimestampLogicalTypeAnnotation logicalTypeAnnotation;
+
+ public ProtoTimestampConverter(ParentValueContainer parent, LogicalTypeAnnotation.TimestampLogicalTypeAnnotation logicalTypeAnnotation) {
+ this.parent = parent;
+ this.logicalTypeAnnotation = logicalTypeAnnotation;
+ }
+
+ @Override
+ public void addLong(long value) {
+ switch (logicalTypeAnnotation.getUnit()) {
+ case MICROS:
+ parent.add(Timestamps.fromMicros(value));
+ break;
+ case MILLIS:
+ parent.add(Timestamps.fromMillis(value));
+ break;
+ case NANOS:
+ parent.add(Timestamps.fromNanos(value));
+ break;
+ }
+ }
+ }
+
+ static final class ProtoDateConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoDateConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addInt(int value) {
+ LocalDate localDate = LocalDate.ofEpochDay(value);
+ com.google.type.Date date = com.google.type.Date.newBuilder()
+ .setYear(localDate.getYear())
+ .setMonth(localDate.getMonthValue())
+ .setDay(localDate.getDayOfMonth())
+ .build();
+ parent.add(date);
+ }
+ }
+
+ static final class ProtoTimeConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+ final LogicalTypeAnnotation.TimeLogicalTypeAnnotation logicalTypeAnnotation;
+
+ public ProtoTimeConverter(ParentValueContainer parent, LogicalTypeAnnotation.TimeLogicalTypeAnnotation logicalTypeAnnotation) {
+ this.parent = parent;
+ this.logicalTypeAnnotation = logicalTypeAnnotation;
+ }
+
+ @Override
+ public void addLong(long value) {
+ LocalTime localTime;
+ switch (logicalTypeAnnotation.getUnit()) {
+ case MILLIS:
+ localTime = LocalTime.ofNanoOfDay(value * 1_000_000);
+ break;
+ case MICROS:
+ localTime = LocalTime.ofNanoOfDay(value * 1_000);
+ break;
+ case NANOS:
+ localTime = LocalTime.ofNanoOfDay(value);
+ break;
+ default:
+ throw new IllegalArgumentException("Unrecognized TimeUnit: " + logicalTypeAnnotation.getUnit());
+ }
+ com.google.type.TimeOfDay timeOfDay = com.google.type.TimeOfDay.newBuilder()
+ .setHours(localTime.getHour())
+ .setMinutes(localTime.getMinute())
+ .setSeconds(localTime.getSecond())
+ .setNanos(localTime.getNano())
+ .build();
+ parent.add(timeOfDay);
+ }
+ }
+
+ static final class ProtoDoubleValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoDoubleValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addDouble(double value) {
+ parent.add(DoubleValue.of(value));
+ }
+ }
+
+ static final class ProtoFloatValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoFloatValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addFloat(float value) {
+ parent.add(FloatValue.of(value));
+ }
+ }
+
+ static final class ProtoInt64ValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoInt64ValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addLong(long value) {
+ parent.add(Int64Value.of(value));
+ }
+ }
+
+ static final class ProtoUInt64ValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoUInt64ValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addLong(long value) {
+ parent.add(UInt64Value.of(value));
+ }
+ }
+
+ static final class ProtoInt32ValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoInt32ValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addInt(int value) {
+ parent.add(Int32Value.of(value));
+ }
+ }
+
+ static final class ProtoUInt32ValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoUInt32ValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addLong(long value) {
+ parent.add(UInt32Value.of(Math.toIntExact(value)));
+ }
+ }
+
+ static final class ProtoBoolValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoBoolValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addBoolean(boolean value) {
+ parent.add(BoolValue.of(value));
+ }
+
+ }
+
+ static final class ProtoStringValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoStringValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addBinary(Binary binary) {
+ String str = binary.toStringUsingUTF8();
+ parent.add(StringValue.of(str));
+ }
+
+ }
+
+ static final class ProtoBytesValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoBytesValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addBinary(Binary binary) {
+ ByteString byteString = ByteString.copyFrom(binary.toByteBuffer());
+ parent.add(BytesValue.of(byteString));
+ }
+ }
+
/**
* This class unwraps the additional LIST wrapper and makes it possible to read the underlying data and then convert
* it to protobuf.
diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java
index a6a779d074..83f3970c23 100644
--- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java
+++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java
@@ -18,11 +18,24 @@
*/
package org.apache.parquet.proto;
+import com.google.protobuf.BoolValue;
+import com.google.protobuf.BytesValue;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.protobuf.Descriptors;
+import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType;
+import com.google.protobuf.DoubleValue;
+import com.google.protobuf.FloatValue;
+import com.google.protobuf.Int32Value;
+import com.google.protobuf.Int64Value;
import com.google.protobuf.Message;
+import com.google.protobuf.StringValue;
+import com.google.protobuf.Timestamp;
+import com.google.protobuf.UInt32Value;
+import com.google.protobuf.UInt64Value;
+import com.google.type.Date;
+import com.google.type.TimeOfDay;
import com.twitter.elephantbird.util.Protobufs;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.conf.HadoopParquetConfiguration;
@@ -31,6 +44,7 @@
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
+import org.apache.parquet.schema.Type.Repetition;
import org.apache.parquet.schema.Types;
import org.apache.parquet.schema.Types.Builder;
import org.apache.parquet.schema.Types.GroupBuilder;
@@ -40,11 +54,20 @@
import java.util.List;
import javax.annotation.Nullable;
+import static org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit;
+import static org.apache.parquet.schema.LogicalTypeAnnotation.dateType;
import static org.apache.parquet.schema.LogicalTypeAnnotation.enumType;
import static org.apache.parquet.schema.LogicalTypeAnnotation.listType;
import static org.apache.parquet.schema.LogicalTypeAnnotation.mapType;
import static org.apache.parquet.schema.LogicalTypeAnnotation.stringType;
-import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.*;
+import static org.apache.parquet.schema.LogicalTypeAnnotation.timeType;
+import static org.apache.parquet.schema.LogicalTypeAnnotation.timestampType;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
/**
* Converts a Protocol Buffer Descriptor into a Parquet schema.
@@ -55,6 +78,7 @@ public class ProtoSchemaConverter {
public static final String PB_MAX_RECURSION = "parquet.proto.maxRecursion";
private final boolean parquetSpecsCompliant;
+ private final boolean unwrapProtoWrappers;
// TODO: use proto custom options to override per field.
private final int maxRecursion;
@@ -76,7 +100,7 @@ public ProtoSchemaConverter() {
* by the parquet specifications. If set to true, specs compliant schemas are used.
*/
public ProtoSchemaConverter(boolean parquetSpecsCompliant) {
- this(parquetSpecsCompliant, 5);
+ this(parquetSpecsCompliant, 5, false);
}
/**
@@ -98,7 +122,9 @@ public ProtoSchemaConverter(Configuration config) {
public ProtoSchemaConverter(ParquetConfiguration config) {
this(
config.getBoolean(ProtoWriteSupport.PB_SPECS_COMPLIANT_WRITE, false),
- config.getInt(PB_MAX_RECURSION, 5));
+ config.getInt(PB_MAX_RECURSION, 5),
+ config.getBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, false)
+ );
}
/**
@@ -112,8 +138,25 @@ public ProtoSchemaConverter(ParquetConfiguration config) {
* bytes instead of their actual schema.
*/
public ProtoSchemaConverter(boolean parquetSpecsCompliant, int maxRecursion) {
+ this(parquetSpecsCompliant, maxRecursion, false);
+ }
+
+ /**
+ * Instantiate a schema converter to get the parquet schema corresponding to protobuf classes.
+ *
+ * @param parquetSpecsCompliant If set to false, the parquet schema generated will be using the old
+ * schema style (prior to PARQUET-968) to provide backward-compatibility
+ * but which does not use LIST and MAP wrappers around collections as required
+ * by the parquet specifications. If set to true, specs compliant schemas are used.
+ * @param maxRecursion The maximum recursion depth messages are allowed to go before terminating as
+ * bytes instead of their actual schema.
+ * @param unwrapProtoWrappers If set to true, unwrap common Proto wrappers like Timestamp and DoubleValue
+ * with corresponding OPTIONAL logical annotations. Primitive types become REQUIRED.
+ */
+ public ProtoSchemaConverter(boolean parquetSpecsCompliant, int maxRecursion, boolean unwrapProtoWrappers) {
this.parquetSpecsCompliant = parquetSpecsCompliant;
this.maxRecursion = maxRecursion;
+ this.unwrapProtoWrappers = unwrapProtoWrappers;
}
/**
@@ -178,6 +221,46 @@ private static Type.Repetition getRepetition(FieldDescriptor descriptor) {
private Builder extends Builder, GroupBuilder>, GroupBuilder> addField(FieldDescriptor descriptor, final GroupBuilder builder, ImmutableSetMultimap seen, int depth) {
if (descriptor.getJavaType() == JavaType.MESSAGE) {
+ if (unwrapProtoWrappers) {
+ Descriptor messageType = descriptor.getMessageType();
+ if (messageType.equals(Timestamp.getDescriptor())) {
+ return builder.primitive(INT64, getRepetition(descriptor)).as(timestampType(true, TimeUnit.NANOS));
+ }
+ if (messageType.equals(Date.getDescriptor())) {
+ return builder.primitive(INT32, getRepetition(descriptor)).as(dateType());
+ }
+ if (messageType.equals(TimeOfDay.getDescriptor())) {
+ return builder.primitive(INT64, getRepetition(descriptor)).as(timeType(true, TimeUnit.NANOS));
+ }
+ if (messageType.equals(DoubleValue.getDescriptor())) {
+ return builder.primitive(DOUBLE, getRepetition(descriptor));
+ }
+ if (messageType.equals(StringValue.getDescriptor())) {
+ return builder.primitive(BINARY, getRepetition(descriptor)).as(stringType());
+ }
+ if (messageType.equals(BoolValue.getDescriptor())) {
+ return builder.primitive(BOOLEAN, getRepetition(descriptor));
+ }
+ if (messageType.equals(FloatValue.getDescriptor())) {
+ return builder.primitive(FLOAT, getRepetition(descriptor));
+ }
+ if (messageType.equals(Int64Value.getDescriptor())) {
+ return builder.primitive(INT64, getRepetition(descriptor));
+ }
+ if (messageType.equals(UInt64Value.getDescriptor())) {
+ return builder.primitive(INT64, getRepetition(descriptor));
+ }
+ if (messageType.equals(Int32Value.getDescriptor())) {
+ return builder.primitive(INT32, getRepetition(descriptor));
+ }
+ if (messageType.equals(UInt32Value.getDescriptor())) {
+ return builder.primitive(INT64, getRepetition(descriptor));
+ }
+ if (messageType.equals(BytesValue.getDescriptor())) {
+ return builder.primitive(BINARY, getRepetition(descriptor));
+ }
+ }
+
return addMessageField(descriptor, builder, seen, depth);
}
@@ -186,8 +269,8 @@ private Builder extends Builder, GroupBuilder>, GroupBuilder> addF
// the old schema style did not include the LIST wrapper around repeated fields
return addRepeatedPrimitive(parquetType.primitiveType, parquetType.logicalTypeAnnotation, builder);
}
-
- return builder.primitive(parquetType.primitiveType, getRepetition(descriptor)).as(parquetType.logicalTypeAnnotation);
+ Repetition repetition = unwrapProtoWrappers ? Repetition.REQUIRED : getRepetition(descriptor);
+ return builder.primitive(parquetType.primitiveType, repetition).as(parquetType.logicalTypeAnnotation);
}
private static Builder extends Builder, GroupBuilder>, GroupBuilder> addRepeatedPrimitive(PrimitiveTypeName primitiveType,
diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java
index b13acd2a57..c5081e759e 100644
--- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java
+++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java
@@ -21,6 +21,9 @@
import com.google.protobuf.*;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
+import com.google.protobuf.util.Timestamps;
+import com.google.type.Date;
+import com.google.type.TimeOfDay;
import com.twitter.elephantbird.util.Protobufs;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.conf.HadoopParquetConfiguration;
@@ -36,6 +39,8 @@
import org.slf4j.LoggerFactory;
import java.lang.reflect.Array;
+import java.time.LocalDate;
+import java.time.LocalTime;
import java.util.*;
import static java.util.Optional.ofNullable;
@@ -57,7 +62,10 @@ public class ProtoWriteSupport extends WriteSupport<
// but is set to false by default to keep backward compatibility.
public static final String PB_SPECS_COMPLIANT_WRITE = "parquet.proto.writeSpecsCompliant";
+ public static final String PB_UNWRAP_PROTO_WRAPPERS = "parquet.proto.unwrapProtoWrappers";
+
private boolean writeSpecsCompliant = false;
+ private boolean unwrapProtoWrappers = false;
private RecordConsumer recordConsumer;
private Class extends Message> protoMessage;
private Descriptor descriptor;
@@ -96,6 +104,10 @@ public static void setWriteSpecsCompliant(Configuration configuration, boolean w
configuration.setBoolean(PB_SPECS_COMPLIANT_WRITE, writeSpecsCompliant);
}
+ public static void setUnwrapProtoWrappers(Configuration configuration, boolean unwrapProtoWrappers) {
+ configuration.setBoolean(PB_UNWRAP_PROTO_WRAPPERS, unwrapProtoWrappers);
+ }
+
/**
* Writes Protocol buffer to parquet file.
* @param record instance of Message.Builder or Message.
@@ -144,6 +156,7 @@ public WriteContext init(ParquetConfiguration configuration) {
extraMetaData.put(ProtoReadSupport.PB_CLASS, protoMessage.getName());
}
+ unwrapProtoWrappers = configuration.getBoolean(PB_UNWRAP_PROTO_WRAPPERS, unwrapProtoWrappers);
writeSpecsCompliant = configuration.getBoolean(PB_SPECS_COMPLIANT_WRITE, writeSpecsCompliant);
MessageType rootSchema = new ProtoSchemaConverter(configuration).convert(descriptor);
validatedMapping(descriptor, rootSchema);
@@ -152,6 +165,7 @@ public WriteContext init(ParquetConfiguration configuration) {
extraMetaData.put(ProtoReadSupport.PB_DESCRIPTOR, descriptor.toProto().toString());
extraMetaData.put(PB_SPECS_COMPLIANT_WRITE, String.valueOf(writeSpecsCompliant));
+ extraMetaData.put(PB_UNWRAP_PROTO_WRAPPERS, String.valueOf(unwrapProtoWrappers));
return new WriteContext(rootSchema, extraMetaData);
}
@@ -265,6 +279,46 @@ private FieldWriter createMessageWriter(FieldDescriptor fieldDescriptor, Type ty
return createMapWriter(fieldDescriptor, type);
}
+ if (unwrapProtoWrappers) {
+ Descriptor messageType = fieldDescriptor.getMessageType();
+ if (messageType.equals(Timestamp.getDescriptor())) {
+ return new TimestampWriter();
+ }
+ if (messageType.equals(Date.getDescriptor())) {
+ return new DateWriter();
+ }
+ if (messageType.equals(TimeOfDay.getDescriptor())) {
+ return new TimeWriter();
+ }
+ if (messageType.equals(DoubleValue.getDescriptor())) {
+ return new DoubleValueWriter();
+ }
+ if (messageType.equals(FloatValue.getDescriptor())) {
+ return new FloatValueWriter();
+ }
+ if (messageType.equals(Int64Value.getDescriptor())) {
+ return new Int64ValueWriter();
+ }
+ if (messageType.equals(UInt64Value.getDescriptor())) {
+ return new UInt64ValueWriter();
+ }
+ if (messageType.equals(Int32Value.getDescriptor())) {
+ return new Int32ValueWriter();
+ }
+ if (messageType.equals(UInt32Value.getDescriptor())) {
+ return new UInt32ValueWriter();
+ }
+ if (messageType.equals(BoolValue.getDescriptor())) {
+ return new BoolValueWriter();
+ }
+ if (messageType.equals(StringValue.getDescriptor())) {
+ return new StringValueWriter();
+ }
+ if (messageType.equals(BytesValue.getDescriptor())) {
+ return new BytesValueWriter();
+ }
+ }
+
// This can happen now that recursive schemas get truncated to bytes. Write the bytes.
if (type.isPrimitive() && type.asPrimitiveType().getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) {
return new BinaryWriter();
@@ -584,6 +638,98 @@ final void writeRawValue(Object value) {
}
}
+ class TimestampWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ Timestamp timestamp = (Timestamp) value;
+ recordConsumer.addLong(Timestamps.toNanos(timestamp));
+ }
+ }
+
+ class DateWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ Date date = (Date) value;
+ LocalDate localDate = LocalDate.of(date.getYear(), date.getMonth(), date.getDay());
+ recordConsumer.addInteger((int) localDate.toEpochDay());
+ }
+ }
+
+ class TimeWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ com.google.type.TimeOfDay timeOfDay = (com.google.type.TimeOfDay) value;
+ LocalTime localTime = LocalTime.of(timeOfDay.getHours(), timeOfDay.getMinutes(), timeOfDay.getSeconds(), timeOfDay.getNanos());
+ recordConsumer.addLong(localTime.toNanoOfDay());
+ }
+ }
+
+ class DoubleValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addDouble(((DoubleValue) value).getValue());
+ }
+ }
+
+ class FloatValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addFloat(((FloatValue) value).getValue());
+ }
+ }
+
+ class Int64ValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addLong(((Int64Value) value).getValue());
+ }
+ }
+
+ class UInt64ValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addLong(((UInt64Value) value).getValue());
+ }
+ }
+
+ class Int32ValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addInteger(((Int32Value) value).getValue());
+ }
+ }
+
+ class UInt32ValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addLong(((UInt32Value) value).getValue());
+ }
+ }
+
+ class BoolValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addBoolean(((BoolValue) value).getValue());
+ }
+ }
+
+ class StringValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ Binary binaryString = Binary.fromString(((StringValue) value).getValue());
+ recordConsumer.addBinary(binaryString);
+ }
+ }
+
+ class BytesValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ byte[] byteArray = ((BytesValue) value).getValue().toByteArray();
+ Binary binary = Binary.fromConstantByteArray(byteArray);
+ recordConsumer.addBinary(binary);
+ }
+ }
+
private FieldWriter unknownType(FieldDescriptor fieldDescriptor) {
String exceptionMsg = "Unknown type with descriptor \"" + fieldDescriptor
+ "\" and type \"" + fieldDescriptor.getJavaType() + "\".";
diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java
index 4debe77f8f..605c322664 100644
--- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java
+++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java
@@ -18,8 +18,12 @@
*/
package org.apache.parquet.proto;
+import com.google.protobuf.BoolValue;
import com.google.protobuf.ByteString;
+import com.google.protobuf.DoubleValue;
import com.google.protobuf.Message;
+import com.google.protobuf.Timestamp;
+import com.google.protobuf.util.Timestamps;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.proto.test.TestProto3;
@@ -32,7 +36,9 @@
import java.util.List;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
public class ProtoInputOutputFormatTest {
@@ -607,6 +613,49 @@ public void testProto3RepeatedMessages() throws Exception {
assertTrue(third.getTwo().isEmpty());
}
+ @Test
+ public void testProto3TimestampMessageClass() throws Exception {
+ Timestamp timestamp = Timestamps.parse("2021-05-02T15:04:03.748Z");
+ TestProto3.DateTimeMessage msgEmpty = TestProto3.DateTimeMessage.newBuilder().build();
+ TestProto3.DateTimeMessage msgNonEmpty = TestProto3.DateTimeMessage.newBuilder()
+ .setTimestamp(timestamp)
+ .build();
+
+ Configuration conf = new Configuration();
+ conf.setBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, true);
+ Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty);
+ ReadUsingMR readUsingMR = new ReadUsingMR();
+ String customClass = TestProto3.DateTimeMessage.class.getName();
+ ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass);
+ List result = readUsingMR.read(outputPath);
+
+ assertEquals(2, result.size());
+ assertEquals(msgEmpty, result.get(0));
+ assertEquals(msgNonEmpty, result.get(1));
+ }
+
+ @Test
+ public void testProto3WrappedMessageClass() throws Exception {
+ TestProto3.WrappedMessage msgEmpty = TestProto3.WrappedMessage.newBuilder().build();
+ TestProto3.WrappedMessage msgNonEmpty = TestProto3.WrappedMessage.newBuilder()
+ .setWrappedDouble(DoubleValue.of(0.577))
+ .setWrappedBool(BoolValue.of(true))
+ .build();
+
+
+ Configuration conf = new Configuration();
+ conf.setBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, true);
+ Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty);
+ ReadUsingMR readUsingMR = new ReadUsingMR();
+ String customClass = TestProto3.WrappedMessage.class.getName();
+ ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass);
+ List result = readUsingMR.read(outputPath);
+
+ assertEquals(2, result.size());
+ assertEquals(msgEmpty, result.get(0));
+ assertEquals(msgNonEmpty, result.get(1));
+ }
+
/**
* Runs job that writes input to file and then job reading data back.
*/
diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java
index 8c64197c33..2871590025 100644
--- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java
+++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java
@@ -23,7 +23,6 @@
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import org.junit.Test;
-import org.apache.parquet.proto.TestUtils;
import org.apache.parquet.proto.test.TestProto3;
import org.apache.parquet.proto.test.TestProtobuf;
import org.apache.parquet.proto.test.Trees;
@@ -40,12 +39,12 @@ public class ProtoSchemaConverterTest {
/**
* Converts given pbClass to parquet schema and compares it with expected parquet schema.
*/
- private static void testConversion(Class extends Message> pbClass, String parquetSchemaString, boolean parquetSpecsCompliant) {
- testConversion(pbClass, parquetSchemaString, new ProtoSchemaConverter(parquetSpecsCompliant));
+ private static void testConversion(Class extends Message> pbClass, String parquetSchemaString, boolean parquetSpecsCompliant, boolean unwrapWrappers) {
+ testConversion(pbClass, parquetSchemaString, new ProtoSchemaConverter(parquetSpecsCompliant, 5, unwrapWrappers));
}
private static void testConversion(Class extends Message> pbClass, String parquetSchemaString) {
- testConversion(pbClass, parquetSchemaString, true);
+ testConversion(pbClass, parquetSchemaString, true, false);
}
private static void testConversion(Class extends Message> pbClass, String parquetSchemaString, ProtoSchemaConverter converter) {
@@ -54,6 +53,9 @@ private static void testConversion(Class extends Message> pbClass, String parq
assertEquals(expectedMT.toString(), schema.toString());
}
+ private void testConversion(Class extends Message> pbClass, String parquetSchemaString, boolean parquetSpecsCompliant) throws Exception {
+ testConversion(pbClass, parquetSchemaString, parquetSpecsCompliant, false);
+ }
/**
* Tests that all protocol buffer datatypes are converted to correct parquet datatypes.
@@ -206,7 +208,7 @@ public void testConvertRepeatedIntMessageNonSpecsCompliant() {
" repeated int32 repeatedInt = 1;",
"}");
- testConversion(TestProtobuf.RepeatedIntMessage.class, expectedSchema, false);
+ testConversion(TestProtobuf.RepeatedIntMessage.class, expectedSchema, false, false);
}
@Test
@@ -231,7 +233,7 @@ public void testProto3ConvertRepeatedIntMessageNonSpecsCompliant() {
" repeated int32 repeatedInt = 1;",
"}");
- testConversion(TestProto3.RepeatedIntMessage.class, expectedSchema, false);
+ testConversion(TestProto3.RepeatedIntMessage.class, expectedSchema, false, false);
}
@Test
@@ -263,7 +265,7 @@ public void testConvertRepeatedInnerMessageNonSpecsCompliant() {
" }",
"}");
- testConversion(TestProtobuf.RepeatedInnerMessage.class, expectedSchema, false);
+ testConversion(TestProtobuf.RepeatedInnerMessage.class, expectedSchema, false, false);
}
@Test
@@ -295,7 +297,7 @@ public void testProto3ConvertRepeatedInnerMessageNonSpecsCompliant() {
" }",
"}");
- testConversion(TestProto3.RepeatedInnerMessage.class, expectedSchema, false);
+ testConversion(TestProto3.RepeatedInnerMessage.class, expectedSchema, false, false);
}
@Test
@@ -323,7 +325,7 @@ public void testConvertMapIntMessageNonSpecsCompliant() {
" }",
"}");
- testConversion(TestProtobuf.MapIntMessage.class, expectedSchema, false);
+ testConversion(TestProtobuf.MapIntMessage.class, expectedSchema, false, false);
}
@Test
@@ -351,7 +353,61 @@ public void testProto3ConvertMapIntMessageNonSpecsCompliant() {
" }",
"}");
- testConversion(TestProto3.MapIntMessage.class, expectedSchema, false);
+ testConversion(TestProto3.MapIntMessage.class, expectedSchema, false, false);
+ }
+
+ @Test
+ public void testProto3ConvertDateTimeMessageWrapped() throws Exception {
+ String expectedSchema =
+ "message TestProto3.DateTimeMessage {\n" +
+ " optional group timestamp = 1 {\n" +
+ " optional int64 seconds = 1;\n" +
+ " optional int32 nanos = 2;\n" +
+ " }\n" +
+ " optional group date = 2 {\n" +
+ " optional int32 year = 1;\n" +
+ " optional int32 month = 2;\n" +
+ " optional int32 day = 3;\n" +
+ " }\n" +
+ " optional group time = 3 {\n" +
+ " optional int32 hours = 1;\n" +
+ " optional int32 minutes = 2;\n" +
+ " optional int32 seconds = 3;\n" +
+ " optional int32 nanos = 4;\n" +
+ " }\n" +
+ "}";
+
+ testConversion(TestProto3.DateTimeMessage.class, expectedSchema, false, false);
+ }
+
+ @Test
+ public void testProto3ConvertDateTimeMessageUnwrapped() throws Exception {
+ String expectedSchema =
+ "message TestProto3.DateTimeMessage {\n" +
+ " optional int64 timestamp (TIMESTAMP(NANOS,true)) = 1;\n" +
+ " optional int32 date (DATE) = 2;\n" +
+ " optional int64 time (TIME(NANOS,true)) = 3;\n" +
+ "}";
+
+ testConversion(TestProto3.DateTimeMessage.class, expectedSchema, false, true);
+ }
+
+ @Test
+ public void testProto3ConvertWrappedMessageUnwrapped() throws Exception {
+ String expectedSchema =
+ "message TestProto3.WrappedMessage {\n" +
+ " optional double wrappedDouble = 1;\n" +
+ " optional float wrappedFloat = 2;\n" +
+ " optional int64 wrappedInt64 = 3;\n" +
+ " optional int64 wrappedUInt64 = 4;\n" +
+ " optional int32 wrappedInt32 = 5;\n" +
+ " optional int64 wrappedUInt32 = 6;\n" +
+ " optional boolean wrappedBool = 7;\n" +
+ " optional binary wrappedString (UTF8) = 8;\n" +
+ " optional binary wrappedBytes = 9;\n" +
+ "}";
+
+ testConversion(TestProto3.WrappedMessage.class, expectedSchema, false, true);
}
@Test
@@ -379,8 +435,8 @@ public void testBinaryTreeRecursion() throws Exception {
" optional binary right = 3;",
" }",
"}");
- testConversion(Trees.BinaryTree.class, expectedSchema, new ProtoSchemaConverter(true, 1));
- testConversion(Trees.BinaryTree.class, TestUtils.readResource("BinaryTree.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH));
+ testConversion(Trees.BinaryTree.class, expectedSchema, new ProtoSchemaConverter(true, 1, false));
+ testConversion(Trees.BinaryTree.class, TestUtils.readResource("BinaryTree.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false));
}
@@ -404,8 +460,8 @@ public void testWideTreeRecursion() throws Exception {
" }",
" }",
"}");
- testConversion(Trees.WideTree.class, expectedSchema, new ProtoSchemaConverter(true, 1));
- testConversion(Trees.WideTree.class, TestUtils.readResource("WideTree.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH));
+ testConversion(Trees.WideTree.class, expectedSchema, new ProtoSchemaConverter(true, 1, false));
+ testConversion(Trees.WideTree.class, TestUtils.readResource("WideTree.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false));
}
@@ -455,8 +511,8 @@ public void testValueRecursion() throws Exception {
" }",
" }",
"}");
- testConversion(Value.class, expectedSchema, new ProtoSchemaConverter(true, 1));
- testConversion(Value.class, TestUtils.readResource("Value.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH));
+ testConversion(Value.class, expectedSchema, new ProtoSchemaConverter(true, 1, false));
+ testConversion(Value.class, TestUtils.readResource("Value.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false));
}
@Test
@@ -510,8 +566,8 @@ public void testStructRecursion() throws Exception {
" }",
" }",
"}");
- testConversion(Struct.class, expectedSchema, new ProtoSchemaConverter(true, 1));
- testConversion(Struct.class, TestUtils.readResource("Struct.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH));
+ testConversion(Struct.class, expectedSchema, new ProtoSchemaConverter(true, 1, false));
+ testConversion(Struct.class, TestUtils.readResource("Struct.par"), new ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false));
}
@Test
@@ -521,16 +577,16 @@ public void testDeepRecursion() {
long expectedBinaryTreeSize = 4;
long expectedStructSize = 7;
for (int i = 0; i < 10; ++i) {
- MessageType deepSchema = new ProtoSchemaConverter(true, i).convert(Trees.WideTree.class);
+ MessageType deepSchema = new ProtoSchemaConverter(true, i, false).convert(Trees.WideTree.class);
// 3, 5, 7, 9, 11, 13, 15, 17, 19, 21
assertEquals(2 * i + 3, deepSchema.getPaths().size());
- deepSchema = new ProtoSchemaConverter(true, i).convert(Trees.BinaryTree.class);
+ deepSchema = new ProtoSchemaConverter(true, i, false).convert(Trees.BinaryTree.class);
// 4, 10, 22, 46, 94, 190, 382, 766, 1534, 3070
assertEquals(expectedBinaryTreeSize, deepSchema.getPaths().size());
expectedBinaryTreeSize = 2 * expectedBinaryTreeSize + 2;
- deepSchema = new ProtoSchemaConverter(true, i).convert(Struct.class);
+ deepSchema = new ProtoSchemaConverter(true, i, false).convert(Struct.class);
// 7, 18, 40, 84, 172, 348, 700, 1404, 2812, 5628
assertEquals(expectedStructSize, deepSchema.getPaths().size());
expectedStructSize = 2 * expectedStructSize + 4;
diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java
index ef0356d01b..c4c34c9000 100644
--- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java
+++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java
@@ -18,25 +18,44 @@
*/
package org.apache.parquet.proto;
+import com.google.protobuf.BoolValue;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.DoubleValue;
+import com.google.protobuf.FloatValue;
+import com.google.protobuf.Int32Value;
+import com.google.protobuf.Int64Value;
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.Message;
+import com.google.protobuf.MessageOrBuilder;
+import com.google.protobuf.StringValue;
+import com.google.protobuf.Timestamp;
+import com.google.protobuf.UInt32Value;
+import com.google.protobuf.UInt64Value;
+import com.google.protobuf.util.Timestamps;
import com.google.protobuf.Value;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
-import org.junit.Test;
-import static org.junit.Assert.*;
-import org.mockito.InOrder;
-import org.mockito.Mockito;
+import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.proto.test.TestProto3;
import org.apache.parquet.proto.test.TestProtobuf;
+import org.junit.Test;
+import org.mockito.InOrder;
+import org.mockito.Mockito;
import org.apache.parquet.proto.test.Trees;
import java.io.IOException;
+import java.time.LocalDate;
+import java.time.LocalTime;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
public class ProtoWriteSupportTest {
private ProtoWriteSupport createReadConsumerInstance(Class cls, RecordConsumer readConsumerMock) {
@@ -1201,4 +1220,180 @@ public void testMapRecursion() {
Mockito.verifyNoMoreInteractions(readConsumerMock);
}
+
+ @Test
+ public void testProto3DateTimeMessageUnwrapped() throws Exception {
+ Timestamp timestamp = Timestamps.parse("2021-05-02T15:04:03.748Z");
+ LocalDate date = LocalDate.of(2021, 5, 2);
+ LocalTime time = LocalTime.of(15, 4, 3, 748_000_000);
+
+ RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class);
+ Configuration conf = new Configuration();
+ ProtoWriteSupport.setUnwrapProtoWrappers(conf, true);
+ ProtoWriteSupport instance = createReadConsumerInstance(
+ TestProto3.DateTimeMessage.class, readConsumerMock, conf);
+
+ TestProto3.DateTimeMessage.Builder msg = TestProto3.DateTimeMessage.newBuilder();
+ msg.setTimestamp(timestamp);
+ msg.setDate(com.google.type.Date.newBuilder()
+ .setYear(date.getYear())
+ .setMonth(date.getMonthValue())
+ .setDay(date.getDayOfMonth())
+ );
+ msg.setTime(com.google.type.TimeOfDay.newBuilder()
+ .setHours(time.getHour())
+ .setMinutes(time.getMinute())
+ .setSeconds(time.getSecond())
+ .setNanos(time.getNano())
+ );
+ instance.write(msg.build());
+
+ InOrder inOrder = Mockito.inOrder(readConsumerMock);
+
+ inOrder.verify(readConsumerMock).startMessage();
+ inOrder.verify(readConsumerMock).startField("timestamp", 0);
+ inOrder.verify(readConsumerMock).addLong(Timestamps.toNanos(timestamp));
+ inOrder.verify(readConsumerMock).endField("timestamp", 0);
+ inOrder.verify(readConsumerMock).startField("date", 1);
+ inOrder.verify(readConsumerMock).addInteger((int) date.toEpochDay());
+ inOrder.verify(readConsumerMock).endField("date", 1);
+ inOrder.verify(readConsumerMock).startField("time", 2);
+ inOrder.verify(readConsumerMock).addLong(time.toNanoOfDay());
+ inOrder.verify(readConsumerMock).endField("time", 2);
+ inOrder.verify(readConsumerMock).endMessage();
+ Mockito.verifyNoMoreInteractions(readConsumerMock);
+ }
+
+ @Test
+ public void testProto3DateTimeMessageRoundTrip() throws Exception {
+ Timestamp timestamp = Timestamps.parse("2021-05-02T15:04:03.748Z");
+ LocalDate date = LocalDate.of(2021, 5, 2);
+ LocalTime time = LocalTime.of(15, 4, 3, 748_000_000);
+ com.google.type.Date protoDate = com.google.type.Date.newBuilder()
+ .setYear(date.getYear())
+ .setMonth(date.getMonthValue())
+ .setDay(date.getDayOfMonth())
+ .build();
+ com.google.type.TimeOfDay protoTime = com.google.type.TimeOfDay.newBuilder()
+ .setHours(time.getHour())
+ .setMinutes(time.getMinute())
+ .setSeconds(time.getSecond())
+ .setNanos(time.getNano())
+ .build();
+
+ TestProto3.DateTimeMessage msg = TestProto3.DateTimeMessage.newBuilder()
+ .setTimestamp(timestamp)
+ .setDate(protoDate)
+ .setTime(protoTime)
+ .build();
+
+ //Write them out and read them back
+ Path tmpFilePath = TestUtils.someTemporaryFilePath();
+ ParquetWriter writer =
+ ProtoParquetWriter.builder(tmpFilePath)
+ .withMessage(TestProto3.DateTimeMessage.class)
+ .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true")
+ .build();
+ writer.write(msg);
+ writer.close();
+ List gotBack = TestUtils.readMessages(tmpFilePath, TestProto3.DateTimeMessage.class);
+
+ TestProto3.DateTimeMessage gotBackFirst = gotBack.get(0);
+ assertEquals(timestamp, gotBackFirst.getTimestamp());
+ assertEquals(protoDate, gotBackFirst.getDate());
+ assertEquals(protoTime, gotBackFirst.getTime());
+ }
+
+ @Test
+ public void testProto3WrappedMessageUnwrapped() throws Exception {
+ RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class);
+ Configuration conf = new Configuration();
+ ProtoWriteSupport.setUnwrapProtoWrappers(conf, true);
+ ProtoWriteSupport instance = createReadConsumerInstance(
+ TestProto3.WrappedMessage.class, readConsumerMock, conf);
+
+ TestProto3.WrappedMessage.Builder msg = TestProto3.WrappedMessage.newBuilder();
+ msg.setWrappedDouble(DoubleValue.of(3.1415));
+
+ instance.write(msg.build());
+
+ InOrder inOrder = Mockito.inOrder(readConsumerMock);
+
+ inOrder.verify(readConsumerMock).startMessage();
+ inOrder.verify(readConsumerMock).startField("wrappedDouble", 0);
+ inOrder.verify(readConsumerMock).addDouble(3.1415);
+ inOrder.verify(readConsumerMock).endField("wrappedDouble", 0);
+ inOrder.verify(readConsumerMock).endMessage();
+ Mockito.verifyNoMoreInteractions(readConsumerMock);
+ }
+
+ @Test
+ public void testProto3WrappedMessageUnwrappedRoundTrip() throws Exception {
+ TestProto3.WrappedMessage.Builder msg = TestProto3.WrappedMessage.newBuilder();
+ msg.setWrappedDouble(DoubleValue.of(0.577));
+ msg.setWrappedFloat(FloatValue.of(3.1415f));
+ msg.setWrappedInt64(Int64Value.of(1_000_000_000L * 4));
+ msg.setWrappedUInt64(UInt64Value.of(1_000_000_000L * 9));
+ msg.setWrappedInt32(Int32Value.of(1_000_000 * 3));
+ msg.setWrappedUInt32(UInt32Value.of(1_000_000 * 8));
+ msg.setWrappedBool(BoolValue.of(true));
+ msg.setWrappedString(StringValue.of("Good Will Hunting"));
+ msg.setWrappedBytes(BytesValue.of(ByteString.copyFrom("someText", "UTF-8")));
+
+ //Write them out and read them back
+ Path tmpFilePath = TestUtils.someTemporaryFilePath();
+ ParquetWriter writer =
+ ProtoParquetWriter.builder(tmpFilePath)
+ .withMessage(TestProto3.WrappedMessage.class)
+ .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true")
+ .build();
+ writer.write(msg);
+ writer.close();
+ List gotBack = TestUtils.readMessages(tmpFilePath, TestProto3.WrappedMessage.class);
+
+ TestProto3.WrappedMessage gotBackFirst = gotBack.get(0);
+ assertEquals(0.577, gotBackFirst.getWrappedDouble().getValue(), 1e-5);
+ assertEquals(3.1415f, gotBackFirst.getWrappedFloat().getValue(), 1e-5f);
+ assertEquals(1_000_000_000L * 4, gotBackFirst.getWrappedInt64().getValue());
+ assertEquals(1_000_000_000L * 9, gotBackFirst.getWrappedUInt64().getValue());
+ assertEquals(1_000_000 * 3, gotBackFirst.getWrappedInt32().getValue());
+ assertEquals(1_000_000 * 8, gotBackFirst.getWrappedUInt32().getValue());
+ assertEquals(BoolValue.of(true), gotBackFirst.getWrappedBool());
+ assertEquals("Good Will Hunting", gotBackFirst.getWrappedString().getValue());
+ assertEquals(ByteString.copyFrom("someText", "UTF-8"), gotBackFirst.getWrappedBytes().getValue());
+ }
+
+ @Test
+ public void testProto3WrappedMessageWithNullsRoundTrip() throws Exception {
+ TestProto3.WrappedMessage.Builder msg = TestProto3.WrappedMessage.newBuilder();
+ msg.setWrappedFloat(FloatValue.of(3.1415f));
+ msg.setWrappedString(StringValue.of("Good Will Hunting"));
+ msg.setWrappedInt32(Int32Value.of(0));
+
+ //Write them out and read them back
+ Path tmpFilePath = TestUtils.someTemporaryFilePath();
+ ParquetWriter writer =
+ ProtoParquetWriter.builder(tmpFilePath)
+ .withMessage(TestProto3.WrappedMessage.class)
+ .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true")
+ .build();
+ writer.write(msg);
+ writer.close();
+ List gotBack = TestUtils.readMessages(tmpFilePath, TestProto3.WrappedMessage.class);
+
+ TestProto3.WrappedMessage gotBackFirst = gotBack.get(0);
+ assertFalse(gotBackFirst.hasWrappedDouble());
+ assertEquals(3.1415f, gotBackFirst.getWrappedFloat().getValue(), 1e-5f);
+
+ // double-check that nulls are honored
+ assertTrue(gotBackFirst.hasWrappedFloat());
+ assertFalse(gotBackFirst.hasWrappedInt64());
+ assertFalse(gotBackFirst.hasWrappedUInt64());
+ assertTrue(gotBackFirst.hasWrappedInt32());
+ assertFalse(gotBackFirst.hasWrappedUInt32());
+ assertEquals(0, gotBackFirst.getWrappedUInt32().getValue());
+ assertFalse(gotBackFirst.hasWrappedBool());
+ assertEquals("Good Will Hunting", gotBackFirst.getWrappedString().getValue());
+ assertFalse(gotBackFirst.hasWrappedBytes());
+ }
}
diff --git a/parquet-protobuf/src/test/resources/TestProto3.proto b/parquet-protobuf/src/test/resources/TestProto3.proto
index fb4da1b0c1..c303fd1f5d 100644
--- a/parquet-protobuf/src/test/resources/TestProto3.proto
+++ b/parquet-protobuf/src/test/resources/TestProto3.proto
@@ -23,6 +23,11 @@ package TestProto3;
option java_package = "org.apache.parquet.proto.test";
+import "google/protobuf/timestamp.proto";
+import "google/protobuf/wrappers.proto";
+import "google/type/date.proto";
+import "google/type/timeofday.proto";
+
// original Dremel paper structures: Original paper used groups, not internal
// messages but groups were deprecated.
@@ -156,3 +161,21 @@ message FirstCustomClassMessage {
message SecondCustomClassMessage {
string string = 11;
}
+
+message DateTimeMessage {
+ google.protobuf.Timestamp timestamp = 1;
+ google.type.Date date = 2;
+ google.type.TimeOfDay time = 3;
+}
+
+message WrappedMessage {
+ google.protobuf.DoubleValue wrappedDouble = 1;
+ google.protobuf.FloatValue wrappedFloat = 2;
+ google.protobuf.Int64Value wrappedInt64 = 3;
+ google.protobuf.UInt64Value wrappedUInt64 = 4;
+ google.protobuf.Int32Value wrappedInt32 = 5;
+ google.protobuf.UInt32Value wrappedUInt32 = 6;
+ google.protobuf.BoolValue wrappedBool = 7;
+ google.protobuf.StringValue wrappedString = 8;
+ google.protobuf.BytesValue wrappedBytes = 9;
+}