Skip to content

Commit

Permalink
PARQUET-1647: Add logical type FLOAT16
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangjiashen committed Sep 24, 2023
1 parent 9b5a962 commit c170a47
Show file tree
Hide file tree
Showing 19 changed files with 1,228 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.parquet.column.statistics;

import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Types;

public class Float16Statistics extends BinaryStatistics
{
// A fake type object to be used to generate the proper comparator
private static final PrimitiveType DEFAULT_FAKE_TYPE = Types.optional(PrimitiveType.PrimitiveTypeName.BINARY)
.named("fake_binary_float16_type").withLogicalTypeAnnotation(LogicalTypeAnnotation.float16Type());

/**
* @deprecated will be removed in 2.0.0. Use {@link Statistics#createStats(org.apache.parquet.schema.Type)} instead
*/
@Deprecated
public Float16Statistics() {
this(DEFAULT_FAKE_TYPE);
}

Float16Statistics(PrimitiveType type) {
super(type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
import java.util.Arrays;
import org.apache.parquet.column.UnknownColumnTypeException;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveComparator;
import org.apache.parquet.schema.PrimitiveStringifier;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;

import org.apache.parquet.type.Float16;

/**
* Statistics class to keep track of statistics in parquet pages and column chunks
Expand Down Expand Up @@ -139,6 +140,43 @@ public Statistics<?> build() {
}
}

// Builder for FLOAT16 type to handle special cases of min/max values like NaN, -0.0, and 0.0
private static class Float16Builder extends Builder {
public Float16Builder(PrimitiveType type) {
super(type);
assert type.getPrimitiveTypeName() == PrimitiveTypeName.BINARY;
}

@Override
public Statistics<?> build() {
Float16Statistics stats = (Float16Statistics) super.build();
if (stats.hasNonNullValue()) {
Binary bMin = stats.genericGetMin();
Binary bMax = stats.genericGetMax();
short min = Float16.fromBytesLittleEndian(bMin.getBytes());
short max = Float16.fromBytesLittleEndian(bMax.getBytes());
// Drop min/max values in case of NaN as the sorting order of values is undefined for this case
if (Float16.isNaN(min) || Float16.isNaN(max)) {
bMin = Binary.fromConstantByteArray(Float16.toBytesLittleEndian(Float16.POSITIVE_ZERO));
bMax = Binary.fromConstantByteArray(Float16.toBytesLittleEndian(Float16.POSITIVE_ZERO));
stats.setMinMax(bMin, bMax);
((Statistics<?>) stats).hasNonNullValue = false;
} else {
// Updating min to -0.0 and max to +0.0 to ensure that no 0.0 values would be skipped
if (Float16.equals(min, Float16.POSITIVE_ZERO)) {
bMin = Binary.fromConstantByteArray(Float16.toBytesLittleEndian(Float16.NEGATIVE_ZERO));
stats.setMinMax(bMin, bMax);
}
if (Float16.equals(max, Float16.NEGATIVE_ZERO)) {
bMax = Binary.fromConstantByteArray(Float16.toBytesLittleEndian(Float16.POSITIVE_ZERO));
stats.setMinMax(bMin, bMax);
}
}
}
return stats;
}
}

private final PrimitiveType type;
private final PrimitiveComparator<T> comparator;
private boolean hasNonNullValue;
Expand Down Expand Up @@ -207,6 +245,10 @@ public static Statistics<?> createStats(Type type) {
case BINARY:
case INT96:
case FIXED_LEN_BYTE_ARRAY:
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.Float16LogicalTypeAnnotation) {
return new Float16Statistics(primitive);
}
return new BinaryStatistics(primitive);
default:
throw new UnknownColumnTypeException(primitive.getPrimitiveTypeName());
Expand All @@ -226,6 +268,11 @@ public static Builder getBuilderForReading(PrimitiveType type) {
return new FloatBuilder(type);
case DOUBLE:
return new DoubleBuilder(type);
case BINARY:
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.Float16LogicalTypeAnnotation) {
return new Float16Builder(type);
}
default:
return new Builder(type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ protected LogicalTypeAnnotation fromString(List<String> params) {
protected LogicalTypeAnnotation fromString(List<String> params) {
return IntervalLogicalTypeAnnotation.getInstance();
}
},
FLOAT16 {
@Override
protected LogicalTypeAnnotation fromString(List<String> params) {
return float16Type();
}
};

protected abstract LogicalTypeAnnotation fromString(List<String> params);
Expand Down Expand Up @@ -296,6 +302,10 @@ public static UUIDLogicalTypeAnnotation uuidType() {
return UUIDLogicalTypeAnnotation.INSTANCE;
}

public static Float16LogicalTypeAnnotation float16Type() {
return Float16LogicalTypeAnnotation.INSTANCE;
}

public static class StringLogicalTypeAnnotation extends LogicalTypeAnnotation {
private static final StringLogicalTypeAnnotation INSTANCE = new StringLogicalTypeAnnotation();

Expand Down Expand Up @@ -901,6 +911,36 @@ PrimitiveStringifier valueStringifier(PrimitiveType primitiveType) {
}
}

public static class Float16LogicalTypeAnnotation extends LogicalTypeAnnotation {
private static final Float16LogicalTypeAnnotation INSTANCE = new Float16LogicalTypeAnnotation();
public static final int BYTES = 2;

private Float16LogicalTypeAnnotation() {
}

@Override
@InterfaceAudience.Private
public OriginalType toOriginalType() {
// No OriginalType for Float16
return null;
}

@Override
public <T> Optional<T> accept(LogicalTypeAnnotationVisitor<T> logicalTypeAnnotationVisitor) {
return logicalTypeAnnotationVisitor.visit(this);
}

@Override
LogicalTypeToken getType() {
return LogicalTypeToken.FLOAT16;
}

@Override
PrimitiveStringifier valueStringifier(PrimitiveType primitiveType) {
return PrimitiveStringifier.FLOAT16_STRINGIFIER;
}
}

// This logical type annotation is implemented to support backward compatibility with ConvertedType.
// The new logical type representation in parquet-format doesn't have any interval type,
// thus this annotation is mapped to UNKNOWN.
Expand Down Expand Up @@ -1060,5 +1100,9 @@ default Optional<T> visit(IntervalLogicalTypeAnnotation intervalLogicalType) {
default Optional<T> visit(MapKeyValueTypeAnnotation mapKeyValueLogicalType) {
return empty();
}

default Optional<T> visit(Float16LogicalTypeAnnotation float16LogicalType) {
return empty();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.nio.ByteBuffer;
import java.util.Comparator;

import org.apache.parquet.type.Float16;

/**
* {@link Comparator} implementation that also supports the comparison of the related primitive type to avoid the
* performance penalty of boxing/unboxing. The {@code compare} methods for the not supported primitive types throw
Expand Down Expand Up @@ -276,4 +278,22 @@ public String toString() {
return "BINARY_AS_SIGNED_INTEGER_COMPARATOR";
}
};

/**
* This comparator is for comparing two float16 values represented in 2 bytes binary.
*/
static final PrimitiveComparator<Binary> BINARY_AS_FLOAT16_COMPARATOR = new BinaryComparator() {

@Override
int compareBinary(Binary b1, Binary b2)
{
return Float16.compare(Float16.fromBytesLittleEndian(b1.getBytes()),
Float16.fromBytesLittleEndian(b2.getBytes()));
}

@Override
public String toString() {
return "BINARY_AS_FLOAT16_COMPARATOR";
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

import javax.naming.OperationNotSupportedException;

import org.apache.parquet.io.api.Binary;
import org.apache.parquet.type.Float16;

/**
* Class that provides string representations for the primitive values. These string values are to be used for
Expand Down Expand Up @@ -448,4 +448,12 @@ private void appendHex(byte[] array, int offset, int length, StringBuilder build
}
}
};

static final PrimitiveStringifier FLOAT16_STRINGIFIER = new BinaryStringifierBase("FLOAT16_STRINGIFIER") {

@Override
String stringifyNotNull(Binary value) {
return Float16.toFloatString(Float16.fromBytesLittleEndian(value.getBytes()));
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.parquet.io.api.PrimitiveConverter;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.ColumnOrder.ColumnOrderName;
import org.apache.parquet.schema.LogicalTypeAnnotation.LogicalTypeAnnotationVisitor;
import org.apache.parquet.schema.LogicalTypeAnnotation.UUIDLogicalTypeAnnotation;

import static java.util.Optional.empty;
Expand Down Expand Up @@ -261,6 +260,11 @@ public Optional<PrimitiveComparator> visit(LogicalTypeAnnotation.JsonLogicalType
public Optional<PrimitiveComparator> visit(LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonLogicalType) {
return of(PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR);
}

@Override
public Optional<PrimitiveComparator> visit(LogicalTypeAnnotation.Float16LogicalTypeAnnotation float16LogicalType) {
return of(PrimitiveComparator.BINARY_AS_FLOAT16_COMPARATOR);
}
}).orElseThrow(() -> new ShouldNeverHappenException("No comparator logic implemented for BINARY logical type: " + logicalType));
}
},
Expand Down Expand Up @@ -390,6 +394,11 @@ public Optional<PrimitiveComparator> visit(LogicalTypeAnnotation.IntervalLogical
public Optional<PrimitiveComparator> visit(UUIDLogicalTypeAnnotation uuidLogicalType) {
return of(PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR);
}

@Override
public Optional<PrimitiveComparator> visit(LogicalTypeAnnotation.Float16LogicalTypeAnnotation float16LogicalType) {
return of(PrimitiveComparator.BINARY_AS_FLOAT16_COMPARATOR);
}
}).orElseThrow(() -> new ShouldNeverHappenException(
"No comparator logic implemented for FIXED_LEN_BYTE_ARRAY logical type: " + logicalType));
}
Expand Down Expand Up @@ -564,6 +573,14 @@ public PrimitiveType withId(int id) {
columnOrder);
}

/**
* @param logicalType LogicalTypeAnnotation
* @return a new PrimitiveType with the same fields and a new id null
*/
public PrimitiveType withLogicalTypeAnnotation(LogicalTypeAnnotation logicalType) {
return new PrimitiveType(getRepetition(), primitive, length, getName(), logicalType, getId());
}

/**
* @return the primitive type
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,11 @@ public Optional<Boolean> visit(LogicalTypeAnnotation.UUIDLogicalTypeAnnotation u
return checkFixedPrimitiveType(LogicalTypeAnnotation.UUIDLogicalTypeAnnotation.BYTES, uuidLogicalType);
}

@Override
public Optional<Boolean> visit(LogicalTypeAnnotation.Float16LogicalTypeAnnotation float16LogicalType) {
return checkFixedPrimitiveType(LogicalTypeAnnotation.Float16LogicalTypeAnnotation.BYTES, float16LogicalType);
}

@Override
public Optional<Boolean> visit(LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalLogicalType) {
Preconditions.checkState(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
package org.apache.parquet.schema;

import org.apache.parquet.io.api.Binary;
import org.apache.parquet.type.Float16;
import org.junit.Test;

import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;

import static org.apache.parquet.schema.PrimitiveComparator.BINARY_AS_FLOAT16_COMPARATOR;
import static org.apache.parquet.schema.PrimitiveComparator.BOOLEAN_COMPARATOR;
import static org.apache.parquet.schema.PrimitiveComparator.DOUBLE_COMPARATOR;
import static org.apache.parquet.schema.PrimitiveComparator.FLOAT_COMPARATOR;
Expand Down Expand Up @@ -268,6 +271,36 @@ public void testBinaryAsSignedIntegerComparatorWithEquals() {
}
}

@Test
public void testFloat16Comparator() {
short[] valuesInAscendingOrder = {
(short) 0xfc00,
Float16.MIN_VALUE,
-Float16.MAX_VALUE,
(short) 0xc000,
-Float16.MIN_VALUE,
0,
Float16.MIN_VALUE,
(short) 0x7bff,
Float16.MAX_VALUE,
(short) 0x7c00};

for (int i = 0; i < valuesInAscendingOrder.length; ++i) {
for (int j = 0; j < valuesInAscendingOrder.length; ++j) {
short vi = valuesInAscendingOrder[i];
short vj = valuesInAscendingOrder[j];
ByteBuffer bbi = ByteBuffer.allocate(2).order(ByteOrder.LITTLE_ENDIAN);
bbi.putShort(vi).flip();
ByteBuffer bbj = ByteBuffer.allocate(2).order(ByteOrder.LITTLE_ENDIAN);
bbj.putShort(vj).flip();
float fi = Float16.toFloat(vi);
float fj = Float16.toFloat(vj);
assertEquals(Float.compare(fi, fj), BINARY_AS_FLOAT16_COMPARATOR.compare(
Binary.fromConstantByteArray(bbi.array()), Binary.fromConstantByteArray(bbj.array())));
}
}
}

private <T> void testObjectComparator(PrimitiveComparator<T> comparator, T... valuesInAscendingOrder) {
for (int i = 0; i < valuesInAscendingOrder.length; ++i) {
for (int j = 0; j < valuesInAscendingOrder.length; ++j) {
Expand Down
Loading

0 comments on commit c170a47

Please sign in to comment.