Skip to content

Commit

Permalink
Address comments and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangjiashen committed Sep 27, 2023
1 parent 6b47de8 commit efe5d34
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,9 @@

package org.apache.parquet.column.statistics;

import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Types;

public class Float16Statistics extends BinaryStatistics
{
// A fake type object to be used to generate the proper comparator
private static final PrimitiveType DEFAULT_FAKE_TYPE = Types.optional(PrimitiveType.PrimitiveTypeName.BINARY)
.named("fake_binary_float16_type").withLogicalTypeAnnotation(LogicalTypeAnnotation.float16Type());

/**
* @deprecated will be removed in 2.0.0. Use {@link Statistics#createStats(org.apache.parquet.schema.Type)} instead
*/
@Deprecated
public Float16Statistics() {
this(DEFAULT_FAKE_TYPE);
}

public class Float16Statistics extends BinaryStatistics {
Float16Statistics(PrimitiveType type) {
super(type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public Statistics<?> build() {
private static class Float16Builder extends Builder {
public Float16Builder(PrimitiveType type) {
super(type);
assert type.getPrimitiveTypeName() == PrimitiveTypeName.BINARY;
assert type.getPrimitiveTypeName() == PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY;
}

@Override
Expand All @@ -157,18 +157,18 @@ public Statistics<?> build() {
short max = Float16.fromBytesLittleEndian(bMax.getBytes());
// Drop min/max values in case of NaN as the sorting order of values is undefined for this case
if (Float16.isNaN(min) || Float16.isNaN(max)) {
bMin = Binary.fromConstantByteArray(Float16.toBytesLittleEndian(Float16.POSITIVE_ZERO));
bMax = Binary.fromConstantByteArray(Float16.toBytesLittleEndian(Float16.POSITIVE_ZERO));
bMin = Binary.fromConstantByteArray(Float16.POSITIVE_ZERO_BYTES_LITTLE_ENDIAN);
bMax = Binary.fromConstantByteArray(Float16.POSITIVE_ZERO_BYTES_LITTLE_ENDIAN);
stats.setMinMax(bMin, bMax);
((Statistics<?>) stats).hasNonNullValue = false;
} else {
// Updating min to -0.0 and max to +0.0 to ensure that no 0.0 values would be skipped
if (Float16.equals(min, Float16.POSITIVE_ZERO)) {
bMin = Binary.fromConstantByteArray(Float16.toBytesLittleEndian(Float16.NEGATIVE_ZERO));
if (min == Float16.POSITIVE_ZERO) {
bMin = Binary.fromConstantByteArray(Float16.NEGATIVE_ZERO_BYTES_LITTLE_ENDIAN);
stats.setMinMax(bMin, bMax);
}
if (Float16.equals(max, Float16.NEGATIVE_ZERO)) {
bMax = Binary.fromConstantByteArray(Float16.toBytesLittleEndian(Float16.POSITIVE_ZERO));
if (max == Float16.NEGATIVE_ZERO) {
bMax = Binary.fromConstantByteArray(Float16.POSITIVE_ZERO_BYTES_LITTLE_ENDIAN);
stats.setMinMax(bMin, bMax);
}
}
Expand Down Expand Up @@ -268,7 +268,7 @@ public static Builder getBuilderForReading(PrimitiveType type) {
return new FloatBuilder(type);
case DOUBLE:
return new DoubleBuilder(type);
case BINARY:
case FIXED_LEN_BYTE_ARRAY:
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.Float16LogicalTypeAnnotation) {
return new Float16Builder(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,7 @@ public String toString() {
static final PrimitiveComparator<Binary> BINARY_AS_FLOAT16_COMPARATOR = new BinaryComparator() {

@Override
int compareBinary(Binary b1, Binary b2)
{
int compareBinary(Binary b1, Binary b2) {
return Float16.compare(Float16.fromBytesLittleEndian(b1.getBytes()),
Float16.fromBytesLittleEndian(b2.getBytes()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,6 @@ public Optional<PrimitiveComparator> visit(LogicalTypeAnnotation.JsonLogicalType
public Optional<PrimitiveComparator> visit(LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonLogicalType) {
return of(PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR);
}

@Override
public Optional<PrimitiveComparator> visit(LogicalTypeAnnotation.Float16LogicalTypeAnnotation float16LogicalType) {
return of(PrimitiveComparator.BINARY_AS_FLOAT16_COMPARATOR);
}
}).orElseThrow(() -> new ShouldNeverHappenException("No comparator logic implemented for BINARY logical type: " + logicalType));
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,14 @@ public void testBinaryAsSignedIntegerComparatorWithEquals() {
@Test
public void testFloat16Comparator() {
short[] valuesInAscendingOrder = {
(short) 0xfc00,
Float16.MIN_VALUE,
-Float16.MAX_VALUE,
(short) 0xc000,
-Float16.MIN_VALUE,
0,
Float16.MIN_VALUE,
(short) 0x7bff,
Float16.MAX_VALUE,
(short) 0x7c00};
(short) 0xfc00, // -Infinity
(short) 0xc000, // -2.0
-Float16.MAX_VALUE, // -6.109476E-5
Float16.NEGATIVE_ZERO, // -0
Float16.POSITIVE_ZERO, // +0
Float16.MIN_VALUE, // 5.9604645E-8
Float16.MAX_VALUE, // 65504.0
(short) 0x7c00}; // Infinity

for (int i = 0; i < valuesInAscendingOrder.length; ++i) {
for (int j = 0; j < valuesInAscendingOrder.length; ++j) {
Expand All @@ -297,6 +295,10 @@ public void testFloat16Comparator() {
float fj = Float16.toFloat(vj);
assertEquals(Float.compare(fi, fj), BINARY_AS_FLOAT16_COMPARATOR.compare(
Binary.fromConstantByteArray(bbi.array()), Binary.fromConstantByteArray(bbj.array())));
if (i < j) {
assertEquals(-1, Float.compare(fi, fj));
}

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

/**
* The class is a utility class to manipulate half-precision 16-bit
Expand All @@ -46,8 +47,7 @@
* floating points (float32).
* Ref: https://android.googlesource.com/platform/libcore/+/master/luni/src/main/java/libcore/util/FP16.java
*/
public class Float16
{
public class Float16 {
// Smallest negative value a half-precision float may have.
public static final short LOWEST_VALUE = (short) 0xfbff;
// Maximum positive finite value a half-precision float may have.
Expand All @@ -60,6 +60,10 @@ public class Float16
public static final short POSITIVE_ZERO = (short) 0x0000;
// Negative 0 of type half-precision float.
public static final short NEGATIVE_ZERO = (short) 0x8000;
// Byte array in little endian for positive 0 of type half-precision float.
public static final byte[] POSITIVE_ZERO_BYTES_LITTLE_ENDIAN = Float16.toBytesLittleEndian(Float16.POSITIVE_ZERO);
// Byte array in little endian for negative 0 of type half-precision float.
public static final byte[] NEGATIVE_ZERO_BYTES_LITTLE_ENDIAN = Float16.toBytesLittleEndian(Float16.NEGATIVE_ZERO);
// A Not-a-Number representation of a half-precision float.
static final short NaN = (short) 0x7e00;
// Positive infinity of type half-precision float.
Expand Down Expand Up @@ -304,7 +308,7 @@ public static int compare(short x, short y) {
*/
public static short fromBytesLittleEndian(byte[] bytes) {
if (bytes.length != 2) {
throw new InvalidFloat16ValueException(String.valueOf(bytes));
throw new InvalidFloat16ValueException(Arrays.toString(bytes));
}

ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
/**
* Thrown if Binary is invalid as a Float16 value.
*/
public class InvalidFloat16ValueException extends ParquetRuntimeException
{
public class InvalidFloat16ValueException extends ParquetRuntimeException {
private static final long serialVersionUID = 1L;

public InvalidFloat16ValueException(String message) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import static org.apache.parquet.type.Float16.NaN;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;

public class TestFloat16
{
public class TestFloat16 {
@Test
public void testFloat16ToFloat() {
// Zeroes
Expand Down Expand Up @@ -241,4 +241,19 @@ public void testCompare() {
assertEquals(1, Float16.compare(toFloat16(12.462f), toFloat16(-12.462f)));
assertEquals(-1, Float16.compare(toFloat16(-12.462f), toFloat16(12.462f)));
}

@Test
public void testFromBytesLittleEndian() {
// bytes of 0xfbff stored in Little-Endian
byte[] float16Value = new byte[] {-1, -5};
short h = Float16.fromBytesLittleEndian(float16Value);
assertEquals(LOWEST_VALUE, h);
byte[] wrongFloat16Bytes = new byte[] {0, 0, 0};
try {
Float16.fromBytesLittleEndian(wrongFloat16Bytes);
fail("Invalid float16 value");
} catch (InvalidFloat16ValueException e) {
assertTrue(e.getMessage().contains("[0, 0, 0] is invalid as a Float16 value"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.parquet.statistics;

import org.apache.parquet.Preconditions;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.GroupFactory;
import org.apache.parquet.example.data.simple.SimpleGroupFactory;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.internal.column.columnindex.ColumnIndex;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Types;
import org.apache.parquet.type.Float16;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.column.statistics.Statistics;
import org.apache.parquet.hadoop.example.ExampleParquetWriter;
import org.apache.parquet.hadoop.example.GroupWriteSupport;
import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
import org.apache.parquet.hadoop.util.HadoopInputFile;

import static org.apache.parquet.schema.LogicalTypeAnnotation.float16Type;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY;
import static org.junit.Assert.assertEquals;

public class TestFloat16Statistics {

@Rule
public TemporaryFolder temp = new TemporaryFolder();

private short[] valuesInAscendingOrder = {
(short) 0xfc00, // -Infinity
(short) 0xc000, // -2.0
-Float16.MAX_VALUE, // -6.109476E-5
Float16.NEGATIVE_ZERO, // -0
Float16.POSITIVE_ZERO, // +0
Float16.MIN_VALUE, // 5.9604645E-8
Float16.MAX_VALUE, // 65504.0
(short) 0x7c00}; // Infinity

@Test
public void testFloat16ColumnIndex() throws IOException
{
MessageType schema = Types.buildMessage().
required(FIXED_LEN_BYTE_ARRAY).as(float16Type()).length(2).named("col_float16").named("msg");

Configuration conf = new Configuration();
GroupWriteSupport.setSchema(schema, conf);

GroupFactory factory = new SimpleGroupFactory(schema);
Path path = newTempPath();
try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(path)
.withConf(conf)
.withDictionaryEncoding(false)
.build()) {

for (short value : valuesInAscendingOrder) {
writer.write(factory.newGroup().append("col_float16", Binary.fromConstantByteArray(Float16.toBytesLittleEndian(value))));
}
}

try (ParquetFileReader reader = ParquetFileReader.open(HadoopInputFile.fromPath(path, new Configuration()))) {

ColumnChunkMetaData column = reader.getFooter().getBlocks().get(0).getColumns().get(0);
ColumnIndex index = reader.readColumnIndex(column);
assertEquals(Collections.singletonList((short) 0xfc00), toFloat16List(index.getMinValues()));
assertEquals(Collections.singletonList((short) 0x7c00), toFloat16List(index.getMaxValues()));
}
}

@Test
public void testFloat16Statistics() throws IOException {
for (int i = 0; i < valuesInAscendingOrder.length; ++i) {
for (int j = 0; j < valuesInAscendingOrder.length; ++j) {
int minIndex = i;
int maxIndex = j;

if (Float16.compare(valuesInAscendingOrder[i], valuesInAscendingOrder[j]) > 0) {
minIndex = j;
maxIndex = i;
}

// Refer to Float16Builder class
if (valuesInAscendingOrder[minIndex] == Float16.POSITIVE_ZERO) {
minIndex = 3;
}
if (valuesInAscendingOrder[maxIndex] == Float16.NEGATIVE_ZERO) {
maxIndex = 4;
}

MessageType schema = Types.buildMessage().
required(FIXED_LEN_BYTE_ARRAY).as(float16Type()).length(2).named("col_float16").named("msg");

Configuration conf = new Configuration();
GroupWriteSupport.setSchema(schema, conf);

GroupFactory factory = new SimpleGroupFactory(schema);
Path path = newTempPath();
try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(path)
.withConf(conf)
.withDictionaryEncoding(false)
.build()) {
writer.write(factory.newGroup().append("col_float16", Binary.fromConstantByteArray(Float16.toBytesLittleEndian(valuesInAscendingOrder[i]))));
writer.write(factory.newGroup().append("col_float16", Binary.fromConstantByteArray(Float16.toBytesLittleEndian(valuesInAscendingOrder[j]))));
}

try (ParquetFileReader reader = ParquetFileReader.open(HadoopInputFile.fromPath(path, new Configuration()))) {
ColumnChunkMetaData column = reader.getFooter().getBlocks().get(0).getColumns().get(0);
Statistics<?> statistics = column.getStatistics();

assertEquals(valuesInAscendingOrder[minIndex], Float16.fromBytesLittleEndian(statistics.getMinBytes()));
assertEquals(valuesInAscendingOrder[maxIndex], Float16.fromBytesLittleEndian(statistics.getMaxBytes()));
}
}
}
}

private Path newTempPath() throws IOException {
File file = temp.newFile();
Preconditions.checkArgument(file.delete(), "Could not remove temp file");
return new Path(file.getAbsolutePath());
}

private static List<Short> toFloat16List(List<ByteBuffer> buffers) {
return buffers.stream()
.map(buffer -> Float16.fromBytesLittleEndian(buffer.array()))
.collect(Collectors.toList());
}
}

0 comments on commit efe5d34

Please sign in to comment.