Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update naming of bool template parameters, move internal methods to private #5

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,11 @@ if(ARROW_COMPUTE)
compute/exec/key_map.cc
compute/exec/options.cc
compute/exec/order_by_impl.cc
compute/exec/partition_util.cc
compute/exec/project_node.cc
compute/exec/sink_node.cc
compute/exec/source_node.cc
compute/exec/swiss_join.cc
compute/exec/task_util.cc
compute/exec/union_node.cc
compute/exec/util.cc
Expand Down Expand Up @@ -445,6 +447,7 @@ if(ARROW_COMPUTE)
append_avx2_src(compute/exec/key_encode_avx2.cc)
append_avx2_src(compute/exec/key_hash_avx2.cc)
append_avx2_src(compute/exec/key_map_avx2.cc)
append_avx2_src(compute/exec/swiss_join_avx2.cc)
append_avx2_src(compute/exec/util_avx2.cc)

list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc)
Expand Down
102 changes: 47 additions & 55 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
Expand All @@ -98,7 +99,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
ctx_ = ctx;
join_type_ = join_type;
num_threads_ = num_threads;
schema_mgr_ = schema_mgr;
schema_[0] = proj_map_left;
schema_[1] = proj_map_right;
key_cmp_ = std::move(key_cmp);
filter_ = std::move(filter);
output_batch_callback_ = std::move(output_batch_callback);
Expand Down Expand Up @@ -139,12 +141,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
private:
void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
std::vector<ValueDescr> data_types;
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
int num_cols = schema_[side]->num_cols(projection_handle);
data_types.resize(num_cols);
for (int icol = 0; icol < num_cols; ++icol) {
data_types[icol] =
ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol),
ValueDescr::ARRAY);
data_types[icol] = ValueDescr(schema_[side]->data_type(projection_handle, icol),
ValueDescr::ARRAY);
}
encoder->Init(data_types, ctx_);
encoder->Clear();
Expand All @@ -155,8 +156,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ThreadLocalState& local_state = local_states_[thread_index];
if (!local_state.is_initialized) {
InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
bool has_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
}
Expand All @@ -168,11 +168,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder,
const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) {
ExecBatch projected({}, batch.length);
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
int num_cols = schema_[side]->num_cols(projection_handle);
projected.values.resize(num_cols);

auto to_input =
schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT);
auto to_input = schema_[side]->map(projection_handle, HashJoinProjection::INPUT);
for (int icol = 0; icol < num_cols; ++icol) {
projected.values[icol] = batch.values[to_input.get(icol)];
}
Expand Down Expand Up @@ -235,16 +234,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
ExecBatch* opt_left_payload, ExecBatch* opt_right_key,
ExecBatch* opt_right_payload) {
ExecBatch result({}, batch_size_next);
int num_out_cols_left =
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right =
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT);

result.values.resize(num_out_cols_left + num_out_cols_right);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
auto from_key = schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
auto from_payload =
schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_left; ++icol) {
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
Expand All @@ -262,10 +258,9 @@ class HashJoinBasicImpl : public HashJoinImpl {
? opt_left_key->values[from_key.get(icol)]
: opt_left_payload->values[from_payload.get(icol)];
}
from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
from_key = schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
from_payload =
schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_right; ++icol) {
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
Expand All @@ -284,7 +279,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
: opt_right_payload->values[from_payload.get(icol)];
}

output_batch_callback_(std::move(result));
output_batch_callback_(0, std::move(result));

// Update the counter of produced batches
//
Expand All @@ -310,13 +305,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
hash_table_keys_.Decode(match_right.size(), match_right.data()));

ExecBatch left_payload;
if (!schema_mgr_->LeftPayloadIsEmpty()) {
if (!schema_[0]->is_empty(HashJoinProjection::PAYLOAD)) {
ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
match_left.size(), match_left.data()));
}

ExecBatch right_payload;
if (!schema_mgr_->RightPayloadIsEmpty()) {
if (!schema_[1]->is_empty(HashJoinProjection::PAYLOAD)) {
ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
match_right.size(), match_right.data()));
}
Expand All @@ -336,14 +331,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
};

SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap left_to_key =
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay =
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key =
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay =
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);

AppendFields(left_to_key, left_to_pay, left_key, left_payload);
AppendFields(right_to_key, right_to_pay, right_key, right_payload);
Expand Down Expand Up @@ -419,15 +414,14 @@ class HashJoinBasicImpl : public HashJoinImpl {

bool has_left =
(join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI &&
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0);
schema_[0]->num_cols(HashJoinProjection::OUTPUT) > 0);
bool has_right =
(join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI &&
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0);
schema_[1]->num_cols(HashJoinProjection::OUTPUT) > 0);
bool has_left_payload =
has_left && (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
has_left && (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_right_payload =
has_right &&
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
has_right && (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);

ThreadLocalState& local_state = local_states_[thread_index];
InitLocalStateIfNeeded(thread_index);
Expand All @@ -450,7 +444,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ARROW_ASSIGN_OR_RAISE(right_key,
hash_table_keys_.Decode(batch_size_next, opt_right_ids));
// Post process build side keys that use dictionary
RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_));
RETURN_NOT_OK(dict_build_.PostDecode(*schema_[1], &right_key, ctx_));
}
if (has_right_payload) {
ARROW_ASSIGN_OR_RAISE(right_payload,
Expand Down Expand Up @@ -550,8 +544,7 @@ class HashJoinBasicImpl : public HashJoinImpl {

RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys,
batch, &batch_key_for_lookups));
bool has_left_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_left_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_left_payload) {
local_state.exec_batch_payloads.Clear();
RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD,
Expand All @@ -563,13 +556,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_state.match_left.clear();
local_state.match_right.clear();

bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_);
bool use_key_batch_for_dicts =
dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1], ctx_);
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
if (use_key_batch_for_dicts) {
RETURN_NOT_OK(dict_probe_.EncodeBatch(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_,
batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_));
RETURN_NOT_OK(dict_probe_.EncodeBatch(thread_index, *schema_[0], *schema_[1],
dict_build_, batch, &row_encoder_for_lookups,
&batch_key_for_lookups, ctx_));
}

// Collect information about all nulls in key columns.
Expand Down Expand Up @@ -609,9 +602,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
if (batches.empty()) {
hash_table_empty_ = true;
} else {
dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_);
bool has_payload =
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_);
bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
}
Expand All @@ -626,11 +618,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
} else if (hash_table_empty_) {
hash_table_empty_ = false;

RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_));
RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_));
}
int32_t num_rows_before = hash_table_keys_.num_rows();
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1],
batch, &hash_table_keys_, ctx_));
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch,
&hash_table_keys_, ctx_));
if (has_payload) {
RETURN_NOT_OK(
EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
Expand All @@ -643,7 +635,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

if (hash_table_empty_) {
RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_));
RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_));
}

return Status::OK();
Expand Down Expand Up @@ -869,7 +861,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ExecContext* ctx_;
JoinType join_type_;
size_t num_threads_;
HashJoinSchema* schema_mgr_;
const HashJoinProjectionMaps* schema_[2];
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
std::unique_ptr<TaskScheduler> scheduler_;
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class ARROW_EXPORT HashJoinSchema {
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

bool HasDictionaries() const;

bool HasLargeBinary() const;

Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
const Schema& right_schema);
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
Expand Down Expand Up @@ -98,12 +102,13 @@ class ARROW_EXPORT HashJoinSchema {

class HashJoinImpl {
public:
using OutputBatchCallback = std::function<void(ExecBatch)>;
using OutputBatchCallback = std::function<void(int64_t, ExecBatch)>;
using FinishedCallback = std::function<void(int64_t)>;

virtual ~HashJoinImpl() = default;
virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
Expand All @@ -113,6 +118,7 @@ class HashJoinImpl {
virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0;

static Result<std::unique_ptr<HashJoinImpl>> MakeBasic();
static Result<std::unique_ptr<HashJoinImpl>> MakeSwiss();

protected:
util::tracing::Span span_;
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/exec/hash_join_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ class JoinBenchmark {

DCHECK_OK(join_->Init(
ctx_.get(), settings.join_type, !is_parallel, settings.num_threads,
schema_mgr_.get(), {JoinKeyCmp::EQ}, std::move(filter), [](ExecBatch) {},
[](int64_t x) {}, schedule_callback));
&(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), {JoinKeyCmp::EQ},
std::move(filter), [](int64_t, ExecBatch) {}, [](int64_t x) {},
schedule_callback));
}

void RunJoin() {
Expand Down
54 changes: 51 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,34 @@ Status HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
return Status::OK();
}

bool HashJoinSchema::HasDictionaries() const {
for (int side = 0; side <= 1; ++side) {
for (int icol = 0; icol < proj_maps[side].num_cols(HashJoinProjection::INPUT);
++icol) {
const std::shared_ptr<DataType>& column_type =
proj_maps[side].data_type(HashJoinProjection::INPUT, icol);
if (column_type->id() == Type::DICTIONARY) {
return true;
}
}
}
return false;
}

bool HashJoinSchema::HasLargeBinary() const {
for (int side = 0; side <= 1; ++side) {
for (int icol = 0; icol < proj_maps[side].num_cols(HashJoinProjection::INPUT);
++icol) {
const std::shared_ptr<DataType>& column_type =
proj_maps[side].data_type(HashJoinProjection::INPUT, icol);
if (is_large_binary_like(column_type->id())) {
return true;
}
}
}
return false;
}

class HashJoinNode : public ExecNode {
public:
HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options,
Expand Down Expand Up @@ -504,8 +532,26 @@ class HashJoinNode : public ExecNode {
// Generate output schema
std::shared_ptr<Schema> output_schema = schema_mgr->MakeOutputSchema(
join_options.output_suffix_for_left, join_options.output_suffix_for_right);

// Create hash join implementation object
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<HashJoinImpl> impl, HashJoinImpl::MakeBasic());
// SwissJoin does not support:
// a) 64-bit string offsets
// b) residual predicates
// c) dictionaries
//
bool use_swiss_join;
#if ARROW_LITTLE_ENDIAN
use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() &&
!schema_mgr->HasLargeBinary();
#else
use_swiss_join = false;
#endif
std::unique_ptr<HashJoinImpl> impl;
if (use_swiss_join) {
ARROW_ASSIGN_OR_RAISE(impl, HashJoinImpl::MakeSwiss());
} else {
ARROW_ASSIGN_OR_RAISE(impl, HashJoinImpl::MakeBasic());
}

return plan->EmplaceNode<HashJoinNode>(
plan, inputs, join_options, std::move(output_schema), std::move(schema_mgr),
Expand Down Expand Up @@ -584,8 +630,10 @@ class HashJoinNode : public ExecNode {

RETURN_NOT_OK(impl_->Init(
plan_->exec_context(), join_type_, use_sync_execution, num_threads,
schema_mgr_.get(), key_cmp_, filter_,
[this](ExecBatch batch) { this->OutputBatchCallback(batch); },
&(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), key_cmp_, filter_,
[this](int64_t /*ignored*/, ExecBatch batch) {
this->OutputBatchCallback(batch);
},
[this](int64_t total_num_batches) { this->FinishedCallback(total_num_batches); },
[this](std::function<Status(size_t)> func) -> Status {
return this->ScheduleTaskCallback(std::move(func));
Expand Down
Loading