Skip to content

Commit

Permalink
Fix Partition<Filtered>, add separate sliceByCached for joins, filter…
Browse files Browse the repository at this point in the history
…ed and nested filtered (#8626)

* Introduce separate sliceByCached for joins, filtered, partitions. Fix
sliceByCached for partitions.

* Simplify logic

* Nitpick

* Lambda capture
  • Loading branch information
saganatt authored Apr 22, 2022
1 parent b6c90b3 commit 360de51
Showing 1 changed file with 91 additions and 28 deletions.
119 changes: 91 additions & 28 deletions Framework/Core/include/Framework/ASoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -1189,14 +1189,12 @@ class Table
{
uint64_t offset = 0;
std::shared_ptr<arrow::Table> result = nullptr;
auto status = this->getSliceFor(value, node.name.c_str(), result, offset);
if (status.ok()) {
auto t = table_t({result}, offset);
copyIndexBindings(t);
return t;
if (!this->getSliceFor(value, node.name.c_str(), result, offset).ok()) {
o2::framework::throw_error(o2::framework::runtime_error("Failed to slice table"));
}
o2::framework::throw_error(o2::framework::runtime_error("Failed to slice table"));
O2_BUILTIN_UNREACHABLE();
auto t = table_t({result}, offset);
copyIndexBindings(t);
return t;
}

auto sliceBy(framework::expressions::BindingNode const& node, int value) const
Expand Down Expand Up @@ -2126,6 +2124,18 @@ struct Join : JoinBase<Ts...> {
using const_iterator = iterator;
using filtered_iterator = typename table_t::template RowViewFiltered<Join<Ts...>, Ts...>;
using filtered_const_iterator = filtered_iterator;

/// Slice this join using a cached slice lookup.
///
/// Calls getSliceFor with `value` and the name of `node`; on success the
/// returned arrow table and offset are wrapped in a new Join<Ts...>, which is
/// passed through copyIndexBindings before being returned.
/// If getSliceFor does not report ok(), a framework runtime_error
/// ("Failed to slice table") is raised through throw_error.
auto sliceByCached(framework::expressions::BindingNode const& node, int value)
{
  std::shared_ptr<arrow::Table> sliceTable = nullptr;
  uint64_t sliceOffset = 0;
  auto status = this->getSliceFor(value, node.name.c_str(), sliceTable, sliceOffset);
  if (!status.ok()) {
    o2::framework::throw_error(o2::framework::runtime_error("Failed to slice table"));
  }
  auto sliced = Join<Ts...>({sliceTable}, sliceOffset);
  this->copyIndexBindings(sliced);
  return sliced;
}
};

template <typename... Ts>
Expand Down Expand Up @@ -2318,23 +2328,26 @@ class FilteredBase : public T
{
uint64_t offset = 0;
std::shared_ptr<arrow::Table> result = nullptr;
auto status = ((table_t*)this)->getSliceFor(value, node.name.c_str(), result, offset);
if (status.ok()) {
auto start = offset;
auto end = start + result->num_rows();
auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start);
auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end);
SelectionVector slicedSelection{start_iterator, stop_iterator};
std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(),
[&](int64_t idx) {
return idx - static_cast<int64_t>(start);
});
self_t fresult{{result}, std::move(slicedSelection), start};
copyIndexBindings(fresult);
if (!((table_t*)this)->getSliceFor(value, node.name.c_str(), result, offset).ok()) {
o2::framework::throw_error(o2::framework::runtime_error("Failed to slice table"));
}
if (offset >= this->tableSize()) {
self_t fresult{{result}, SelectionVector{}, 0}; // empty slice
this->copyIndexBindings(fresult);
return fresult;
}
o2::framework::throw_error(o2::framework::runtime_error("Failed to slice table"));
O2_BUILTIN_UNREACHABLE();
auto start = offset;
auto end = start + result->num_rows();
auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start);
auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end);
SelectionVector slicedSelection{start_iterator, stop_iterator};
std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(),
[&start](int64_t idx) {
return idx - static_cast<int64_t>(start);
});
self_t fresult{{result}, std::move(slicedSelection), start};
copyIndexBindings(fresult);
return fresult;
}

auto sliceBy(framework::expressions::BindingNode const& node, int value) const
Expand Down Expand Up @@ -2440,6 +2453,8 @@ class Filtered : public FilteredBase<T>
{
public:
using self_t = Filtered<T>;
using table_t = typename FilteredBase<T>::table_t;

Filtered(std::vector<std::shared_ptr<arrow::Table>>&& tables, gandiva::Selection const& selection, uint64_t offset = 0)
: FilteredBase<T>(std::move(tables), selection, offset) {}

Expand Down Expand Up @@ -2520,7 +2535,33 @@ class Filtered : public FilteredBase<T>
{
return operator*=(other.getSelectedRows());
}
using FilteredBase<T>::sliceByCached;

/// Slice this filtered table using a cached slice lookup, restricting the
/// selection to the rows that fall inside the slice.
///
/// @param node  binding node; its name is forwarded to getSliceFor
/// @param value value forwarded to getSliceFor to pick the slice
/// @return a Filtered<T> (self_t) built on the sliced arrow table. Its selection
///         holds the selected row indices in [offset, offset + slice rows),
///         shifted so they are relative to the start of the slice. When the
///         reported offset is at or beyond tableSize(), an empty selection with
///         offset 0 is returned instead.
/// If getSliceFor does not report ok(), a framework runtime_error
/// ("Failed to slice table") is raised through throw_error.
auto sliceByCached(framework::expressions::BindingNode const& node, int value)
{
  uint64_t offset = 0;
  std::shared_ptr<arrow::Table> result = nullptr;
  if (!((table_t*)this)->getSliceFor(value, node.name.c_str(), result, offset).ok()) {
    o2::framework::throw_error(o2::framework::runtime_error("Failed to slice table"));
  }
  if (offset >= this->tableSize()) {
    self_t fresult{{result}, SelectionVector{}, 0}; // empty slice
    this->copyIndexBindings(fresult);
    return fresult;
  }
  auto start = offset;
  auto end = start + result->num_rows();
  // Bind the selection once so every iterator below refers to the same object.
  auto const& selectedRows = this->getSelectedRows();
  // NOTE(review): the two binary searches assume the selection is sorted in
  // ascending order - confirm against where the selection is produced.
  auto start_iterator = std::lower_bound(selectedRows.begin(), selectedRows.end(), start);
  auto stop_iterator = std::lower_bound(start_iterator, selectedRows.end(), end);
  SelectionVector slicedSelection{start_iterator, stop_iterator};
  // Rebase the global row indices so that the first row of the slice is 0.
  std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(),
                 [&start](int64_t idx) {
                   return idx - static_cast<int64_t>(start);
                 });
  self_t fresult{{result}, std::move(slicedSelection), start};
  this->copyIndexBindings(fresult);
  return fresult;
}
};

template <typename T>
Expand All @@ -2531,23 +2572,23 @@ class Filtered<Filtered<T>> : public FilteredBase<typename T::table_t>
using table_t = typename FilteredBase<typename T::table_t>::table_t;

Filtered(std::vector<Filtered<T>>&& tables, gandiva::Selection const& selection, uint64_t offset = 0)
: FilteredBase<typename T::table_t>(std::move(extractTablesFromFiltered(std::move(tables))), selection, offset)
: FilteredBase<typename T::table_t>(std::move(extractTablesFromFiltered(tables)), selection, offset)
{
for (auto& table : tables) {
*this *= table;
}
}

Filtered(std::vector<Filtered<T>>&& tables, SelectionVector&& selection, uint64_t offset = 0)
: FilteredBase<typename T::table_t>(std::move(extractTablesFromFiltered(std::move(tables))), std::forward<SelectionVector>(selection), offset)
: FilteredBase<typename T::table_t>(std::move(extractTablesFromFiltered(tables)), std::forward<SelectionVector>(selection), offset)
{
for (auto& table : tables) {
*this *= table;
}
}

Filtered(std::vector<Filtered<T>>&& tables, gsl::span<int64_t const> const& selection, uint64_t offset = 0)
: FilteredBase<typename T::table_t>(std::move(extractTablesFromFiltered(std::move(tables))), selection, offset)
: FilteredBase<typename T::table_t>(std::move(extractTablesFromFiltered(tables)), selection, offset)
{
for (auto& table : tables) {
*this *= table;
Expand Down Expand Up @@ -2626,10 +2667,32 @@ class Filtered<Filtered<T>> : public FilteredBase<typename T::table_t>
return operator*=(other.getSelectedRows());
}

using FilteredBase<typename T::table_t>::sliceByCached;
/// Slice this nested filtered table using a cached slice lookup.
///
/// Calls getSliceFor with `value` and the name of `node`, then narrows the
/// current selection to the indices inside [offset, offset + slice rows) and
/// shifts them to be relative to the slice start. The narrowed selection is
/// used twice: once to build an inner Filtered<T> over the sliced arrow table,
/// and once (as a copy) for the outer self_t that wraps that inner table.
/// If getSliceFor does not report ok(), a framework runtime_error
/// ("Failed to slice table") is raised through throw_error.
auto sliceByCached(framework::expressions::BindingNode const& node, int value)
{
  std::shared_ptr<arrow::Table> sliceTable = nullptr;
  uint64_t sliceOffset = 0;
  if (!((table_t*)this)->getSliceFor(value, node.name.c_str(), sliceTable, sliceOffset).ok()) {
    o2::framework::throw_error(o2::framework::runtime_error("Failed to slice table"));
  }
  auto first = sliceOffset;
  auto last = first + sliceTable->num_rows();
  auto lower = std::lower_bound(this->getSelectedRows().begin(), this->getSelectedRows().end(), first);
  auto upper = std::lower_bound(lower, this->getSelectedRows().end(), last);
  SelectionVector localRows{lower, upper};
  // Shift global row indices so they are relative to the first row of the slice.
  auto rebase = [&first](int64_t idx) {
    return idx - static_cast<int64_t>(first);
  };
  std::transform(localRows.begin(), localRows.end(), localRows.begin(), rebase);
  // Keep a second copy: localRows is moved into the inner table below.
  SelectionVector outerRows = localRows;
  Filtered<T> inner{{sliceTable}, std::move(localRows), first};
  std::vector<Filtered<T>> innerTables{inner};
  self_t sliced{std::move(innerTables), std::move(outerRows), first};
  this->copyIndexBindings(sliced);
  return sliced;
}

private:
std::vector<std::shared_ptr<arrow::Table>> extractTablesFromFiltered(std::vector<Filtered<T>>&& tables)
std::vector<std::shared_ptr<arrow::Table>> extractTablesFromFiltered(std::vector<Filtered<T>>& tables)
{
std::vector<std::shared_ptr<arrow::Table>> outTables;
for (auto& table : tables) {
Expand Down

0 comments on commit 360de51

Please sign in to comment.