diff --git a/cpp/src/lance/io/reader.cc b/cpp/src/lance/io/reader.cc index e3abb32195..303278bfa7 100644 --- a/cpp/src/lance/io/reader.cc +++ b/cpp/src/lance/io/reader.cc @@ -299,11 +299,15 @@ ::arrow::Result> FileReader::GetArray( storage_arr = GetPrimitiveArray(field, batch_id, params); } + if (!storage_arr.ok()) { + return storage_arr; + } + if (field->is_extension_type()) { - std::shared_ptr<::arrow::ExtensionType> ext_type = ::arrow::GetExtensionType( - field->extension_name()); + std::shared_ptr<::arrow::ExtensionType> ext_type = + ::arrow::GetExtensionType(field->extension_name()); if (ext_type != nullptr && storage_arr.ok()) { - return ::arrow::ExtensionType::WrapArray(ext_type, storage_arr.ValueOrDie()); + return ::arrow::ExtensionType::WrapArray(ext_type, storage_arr.ValueOrDie()); } } return storage_arr; @@ -335,11 +339,15 @@ ::arrow::Result> FileReader::GetListArray( "FileReader::GetListArray: indices is empty: field={}({})", field->name(), field->id())); } auto start = static_cast(indices->Value(0)); - auto length = static_cast(indices->Value(indices->length() - 1) - start); + auto length = static_cast(indices->Value(indices->length() - 1) - start + 1); + ARROW_ASSIGN_OR_RAISE(auto unfiltered_arr, GetListArray(field, batch_id, ArrayReadParams(start, length))); - ARROW_ASSIGN_OR_RAISE(auto datum, - ::arrow::compute::CallFunction("take", {unfiltered_arr, indices})); + ARROW_ASSIGN_OR_RAISE(auto offsets_datum, + ::arrow::compute::Subtract(indices, ::arrow::Datum(indices->Value(0)))); + ARROW_ASSIGN_OR_RAISE( + auto datum, + ::arrow::compute::CallFunction("take", {unfiltered_arr, offsets_datum.make_array()})); return datum.make_array(); } @@ -363,10 +371,11 @@ ::arrow::Result> FileReader::GetListArray( // Realigned offsets to be zero-started ARROW_ASSIGN_OR_RAISE(auto shifted_offsets, ResetOffsets(offsets)); // Setup null bitmap - ARROW_ASSIGN_OR_RAISE(auto null_bitmap, ::arrow::AllocateBitmap(shifted_offsets->length() - 1, pool_)); + ARROW_ASSIGN_OR_RAISE(auto null_bitmap, + ::arrow::AllocateBitmap(shifted_offsets->length() - 1, pool_)); for (int i = 0; i < shifted_offsets->length() - 1; i++) { - ::arrow::bit_util::SetBitTo(null_bitmap->mutable_data(), i, - offsets->Value(i + 1) - offsets->Value(i) > 0); + ::arrow::bit_util::SetBitTo( + null_bitmap->mutable_data(), i, offsets->Value(i + 1) - offsets->Value(i) > 0); } return std::make_shared<::arrow::ListArray>(field->type(), shifted_offsets->length() - 1, @@ -394,11 +403,10 @@ ::arrow::Result> FileReader::GetPrimitiveArray( decoder->Reset(position, length); decltype(decoder->ToArray()) result; if (params.indices) { - result = decoder->Take(params.indices.value()); + return decoder->Take(params.indices.value()); } else { - result = decoder->ToArray(params.offset.value(), params.length); + return decoder->ToArray(params.offset.value(), params.length); } - return result; } FileReader::ArrayReadParams::ArrayReadParams(int32_t off, std::optional len) diff --git a/cpp/src/lance/io/reader_test.cc b/cpp/src/lance/io/reader_test.cc index f60d717e92..b1427bc225 100644 --- a/cpp/src/lance/io/reader_test.cc +++ b/cpp/src/lance/io/reader_test.cc @@ -25,6 +25,7 @@ #include "lance/arrow/stl.h" #include "lance/arrow/type.h" #include "lance/arrow/writer.h" +#include "lance/io/reader.h" TEST_CASE("Test List Array With Nulls") { auto int_builder = std::make_shared<::arrow::Int32Builder>(); @@ -64,3 +65,38 @@ TEST_CASE("Test List Array With Nulls") { CHECK(scalar->Equals(::arrow::NullScalar())); } } + +TEST_CASE("Get List Array With Indices") { + auto value_builder = std::make_shared<::arrow::Int32Builder>(); + auto list_builder = ::arrow::ListBuilder(::arrow::default_memory_pool(), value_builder); + for (int i = 0; i < 10; i++) { + CHECK(list_builder.Append().ok()); + CHECK(value_builder->AppendValues({1 * i, 2 * i, 3 * i}).ok()); + } + + auto arr = std::static_pointer_cast<::arrow::ListArray>(list_builder.Finish().ValueOrDie()); + auto schema = ::arrow::schema({::arrow::field("values", ::arrow::list(::arrow::int32()))}); + auto table = ::arrow::Table::Make(schema, {arr}); + + auto sink = arrow::io::BufferOutputStream::Create().ValueOrDie(); + CHECK(lance::arrow::WriteTable(*table, sink).ok()); + auto infile = make_shared(sink->Finish().ValueOrDie()); + auto reader = lance::io::FileReader(infile); + CHECK(reader.Open().ok()); + + for (auto& indices : std::vector>({{0, 1, 3}, {2, 3, 4}, {0, 5, 9}})) { + list_builder.Reset(); + value_builder->Reset(); + for (int idx : indices) { + CHECK(list_builder.Append().ok()); + CHECK(value_builder->AppendValues({idx * 1, idx * 2, idx * 3}).ok()); + } + auto expected_arr = + std::static_pointer_cast<::arrow::ListArray>(list_builder.Finish().ValueOrDie()); + auto expected_table = ::arrow::Table::Make(schema, {expected_arr}); + + auto batch = reader.ReadBatch(reader.schema(), 0, lance::arrow::ToArray(indices).ValueOrDie()) + .ValueOrDie(); + CHECK(batch->Equals(*expected_table->CombineChunksToBatch().ValueOrDie())); + } +} \ No newline at end of file