in cpp/src/arrow/acero/swiss_join.cc [2246:2412]
// Probe the hash table with one batch of rows and hand the resulting join
// output to output_batch_fn_.
//
// The first num_key_columns_ columns of `keypayload_batch` are used as the key
// columns. The batch is walked in mini-batches of at most
// swiss_table->minibatch_size() rows. For every mini-batch the keys are hashed
// and looked up, the match bit vector is AND-ed with the null key filter, and
// the result is handled according to join_type_:
//  - LEFT_SEMI / LEFT_ANTI: the residual filter produces the passing row ids,
//    which are materialized with AppendProbeOnly.
//  - RIGHT_SEMI / RIGHT_ANTI: only FilterRightSemiAnti is invoked; nothing is
//    materialized in this function.
//  - any other join type: (row id, key id, payload id) tuples are enumerated
//    with JoinMatchIterator, passed through FilterInner and materialized with
//    Append. For LEFT_OUTER / FULL_OUTER the rows collected by
//    CollectPassingBatchIds are materialized afterwards with AppendProbeOnly.
//
// thread_id           selects materialize_[thread_id] and is forwarded to
//                     FilterRightSemiAnti, UpdateHasMatchForPayloads and
//                     output_batch_fn_.
// keypayload_batch    rows to probe with.
// temp_stack          source of the mini-batch scratch buffers; also forwarded
//                     to the lookup input and the residual filter calls.
// temp_column_arrays  forwarded to SwissTableWithKeys::Input.
//
// Returns Status::OK() once all rows were processed, or leaves early through
// RETURN_NOT_OK when a residual filter or materialize call fails.
//
Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
                                       const ExecBatch& keypayload_batch,
                                       arrow::util::TempVectorStack* temp_stack,
                                       std::vector<KeyColumnArray>* temp_column_arrays) {
  // Without a key-to-payload mapping, key ids double as payload ids (see the
  // payload id selection in the match loop below).
  //
  bool unique_keys = (hash_table_->key_to_payload() == nullptr);
  const SwissTable* table = hash_table_->keys()->swiss_table();
  int64_t hw_flags = table->hardware_flags();
  int max_chunk_rows = table->minibatch_size();
  int total_rows = static_cast<int>(keypayload_batch.length);

  // Join type predicates; join_type_ is not modified by this function.
  //
  const bool left_semi = (join_type_ == JoinType::LEFT_SEMI);
  const bool left_anti = (join_type_ == JoinType::LEFT_ANTI);
  const bool right_semi_or_anti =
      (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI);
  const bool updates_has_match =
      (join_type_ == JoinType::RIGHT_OUTER || join_type_ == JoinType::FULL_OUTER);
  const bool outputs_non_matches =
      (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER);

  // View of the input restricted to its leading key columns.
  //
  ExecBatch key_batch({}, keypayload_batch.length);
  key_batch.values.resize(num_key_columns_);
  for (int col = 0; col < num_key_columns_; ++col) {
    key_batch.values[col] = keypayload_batch.values[col];
  }

  // Scratch buffers, each sized for one full mini-batch. They are declared in
  // a fixed order because all of them come from the same temp_stack.
  //
  const uint32_t bitvector_bytes =
      static_cast<uint32_t>(bit_util::BytesForBits(max_chunk_rows));
  arrow::util::TempVectorHolder<uint32_t> hash_scratch(temp_stack, max_chunk_rows);
  arrow::util::TempVectorHolder<uint8_t> match_bits(temp_stack, bitvector_bytes);
  arrow::util::TempVectorHolder<uint32_t> found_key_ids(temp_stack, max_chunk_rows);
  arrow::util::TempVectorHolder<uint16_t> out_row_ids(temp_stack, max_chunk_rows);
  arrow::util::TempVectorHolder<uint32_t> out_key_ids(temp_stack, max_chunk_rows);
  arrow::util::TempVectorHolder<uint32_t> out_payload_ids(temp_stack, max_chunk_rows);
  arrow::util::TempVectorHolder<uint8_t> filter_bits(temp_stack, bitvector_bytes);

  // Every materialize call reports finished output batches through this
  // callback.
  //
  auto emit_batch = [&](ExecBatch batch) {
    return output_batch_fn_(thread_id, std::move(batch));
  };

  int chunk_begin = 0;
  while (chunk_begin < total_rows) {
    uint32_t chunk_rows = std::min(max_chunk_rows, total_rows - chunk_begin);

    // Hash the keys of this mini-batch and look them up in the hash table.
    //
    SwissTableWithKeys::Input lookup_input(&key_batch, chunk_begin,
                                           chunk_begin + chunk_rows, temp_stack,
                                           temp_column_arrays);
    hash_table_->keys()->Hash(&lookup_input, hash_scratch.mutable_data(), hw_flags);
    hash_table_->keys()->MapReadOnly(&lookup_input, hash_scratch.mutable_data(),
                                     match_bits.mutable_data(),
                                     found_key_ids.mutable_data());

    // Combine the match bit vector with the null key filter. The boolean
    // output of the filter is not needed here.
    //
    bool null_filter_flag_unused;
    JoinNullFilter::Filter(key_batch, chunk_begin, chunk_rows, *cmp_,
                           &null_filter_flag_unused,
                           /*and_with_input=*/true, match_bits.mutable_data());

    if (left_semi || left_anti) {
      // Left-semi / left-anti: the residual filter yields the passing row ids
      // and only input columns are materialized.
      //
      int passing_count = 0;
      if (left_semi) {
        RETURN_NOT_OK(residual_filter_->FilterLeftSemi(
            keypayload_batch, chunk_begin, chunk_rows, match_bits.mutable_data(),
            found_key_ids.mutable_data(), unique_keys, temp_stack, &passing_count,
            out_row_ids.mutable_data()));
      } else {
        RETURN_NOT_OK(residual_filter_->FilterLeftAnti(
            keypayload_batch, chunk_begin, chunk_rows, match_bits.mutable_data(),
            found_key_ids.mutable_data(), unique_keys, temp_stack, &passing_count,
            out_row_ids.mutable_data()));
      }
      RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly(
          keypayload_batch, passing_count, out_row_ids.mutable_data(), emit_batch));
    } else if (right_semi_or_anti) {
      // Right-semi / right-anti: no output is materialized from this function.
      //
      RETURN_NOT_OK(residual_filter_->FilterRightSemiAnti(
          thread_id, keypayload_batch, chunk_begin, chunk_rows,
          match_bits.mutable_data(), found_key_ids.mutable_data(), unique_keys,
          temp_stack));
    } else {
      // Remaining join types output pairs of rows from both sides. A single
      // input row may have several matches, so they are enumerated in batches
      // of at most max_chunk_rows id tuples.
      //
      JoinMatchIterator matches;
      matches.SetLookupResult(chunk_rows, chunk_begin, match_bits.mutable_data(),
                              found_key_ids.mutable_data(), unique_keys,
                              hash_table_->key_to_payload());
      int match_count;
      bool track_filter_bits = residual_filter_->NeedFilterBitVector(join_type_);
      if (track_filter_bits) {
        residual_filter_->InitFilterBitVector(chunk_rows, filter_bits.mutable_data());
      }

      while (matches.GetNextBatch(max_chunk_rows, &match_count,
                                  out_row_ids.mutable_data(),
                                  out_key_ids.mutable_data(),
                                  out_payload_ids.mutable_data())) {
        // FilterInner receives the id buffers together with match_count as both
        // the input count and the output location.
        //
        RETURN_NOT_OK(residual_filter_->FilterInner(
            keypayload_batch, match_count, out_row_ids.mutable_data(),
            out_key_ids.mutable_data(), out_payload_ids.mutable_data(), !unique_keys,
            temp_stack, &match_count));

        const uint16_t* row_ids = out_row_ids.mutable_data();
        const uint32_t* key_ids = out_key_ids.mutable_data();
        const uint32_t* payload_ids =
            unique_keys ? key_ids : out_payload_ids.mutable_data();

        // Record the filtered matches in the filter bit vector when one is in
        // use.
        //
        if (track_filter_bits) {
          residual_filter_->UpdateFilterBitVector(chunk_begin, match_count, row_ids,
                                                  filter_bits.mutable_data());
        }

        // Right-outer / full-outer: update the has-match flags of the matched
        // hash table rows.
        //
        if (updates_has_match) {
          hash_table_->UpdateHasMatchForPayloads(thread_id, match_count, payload_ids);
        }

        // Materialize the id tuples that point to matching pairs of rows.
        //
        RETURN_NOT_OK(materialize_[thread_id]->Append(keypayload_batch, match_count,
                                                      row_ids, key_ids, payload_ids,
                                                      emit_batch));
      }

      // Left-outer / full-outer: also output the non-matching input rows; the
      // columns coming from the other side of the join are output as nulls.
      //
      if (outputs_non_matches) {
        int unmatched_count = 0;
        // NOTE(review): the leading 0 presumably selects rows whose bit is
        // cleared - confirm against CollectPassingBatchIds.
        CollectPassingBatchIds(0, hw_flags, chunk_begin, chunk_rows,
                               track_filter_bits ? filter_bits.mutable_data()
                                                 : match_bits.mutable_data(),
                               &unmatched_count, out_row_ids.mutable_data());
        RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly(
            keypayload_batch, unmatched_count, out_row_ids.mutable_data(),
            emit_batch));
      }
    }

    chunk_begin += chunk_rows;
  }

  return Status::OK();
}