Status JoinProbeProcessor::OnNextBatch()

in cpp/src/arrow/acero/swiss_join.cc [2246:2412]


// Probes the hash table with one batch of probe-side rows and hands the join
// results to output_batch_fn_.
//
// The first num_key_columns_ columns of keypayload_batch are used as the join
// keys. Rows are processed in mini-batches of at most
// swiss_table->minibatch_size() rows, using scratch buffers taken from
// temp_stack. For every mini-batch:
//   1. the keys are hashed and looked up with MapReadOnly, producing a match
//      bit-vector and key ids; that bit-vector is then ANDed with the null-key
//      filter computed by JoinNullFilter,
//   2. semi/anti joins run the residual filter; the left variants emit the
//      passing probe rows, the right variants emit nothing from here,
//   3. every other join type enumerates all matches, runs the residual filter,
//      updates has-match flags for right/full outer joins and emits the
//      matching pairs; left/full outer joins then also emit the probe rows left
//      without a match, with nulls in the columns of the other side.
//
// thread_id selects the per-thread materializer and is forwarded to the output
// callback and to UpdateHasMatchForPayloads. Returns the first non-OK Status
// reported by the residual filter or by the materializer (which is given the
// output callback).
Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
                                       const ExecBatch& keypayload_batch,
                                       arrow::util::TempVectorStack* temp_stack,
                                       std::vector<KeyColumnArray>* temp_column_arrays) {
  // A null key-to-payload mapping is treated as "no duplicate keys"; payload ids
  // are then taken to be the key ids themselves (see payload_id_ptr below).
  const bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr);
  const SwissTable* table = hash_table_->keys()->swiss_table();
  const int64_t hw_flags = table->hardware_flags();
  const int max_minibatch_rows = table->minibatch_size();
  const int total_rows = static_cast<int>(keypayload_batch.length);

  // Key-only view of the input: its leading num_key_columns_ columns.
  ExecBatch key_batch({}, keypayload_batch.length);
  key_batch.values.assign(keypayload_batch.values.begin(),
                          keypayload_batch.values.begin() + num_key_columns_);

  // Scratch buffers sized for one mini-batch, allocated once from temp_stack
  // and reused by every mini-batch below.
  using arrow::util::TempVectorHolder;
  const auto bitvector_bytes =
      static_cast<uint32_t>(bit_util::BytesForBits(max_minibatch_rows));
  TempVectorHolder<uint32_t> hashes(temp_stack, max_minibatch_rows);
  TempVectorHolder<uint8_t> match_bits(temp_stack, bitvector_bytes);
  TempVectorHolder<uint32_t> lookup_key_ids(temp_stack, max_minibatch_rows);
  TempVectorHolder<uint16_t> out_batch_ids(temp_stack, max_minibatch_rows);
  TempVectorHolder<uint32_t> out_key_ids(temp_stack, max_minibatch_rows);
  TempVectorHolder<uint32_t> out_payload_ids(temp_stack, max_minibatch_rows);
  TempVectorHolder<uint8_t> filter_bits(temp_stack, bitvector_bytes);

  // Every materialized batch is passed on to output_batch_fn_ with thread_id.
  auto emit_batch = [&](ExecBatch batch) {
    return output_batch_fn_(thread_id, std::move(batch));
  };

  const bool left_semi_or_anti =
      join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI;
  const bool right_semi_or_anti =
      join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI;
  const bool update_has_match =
      join_type_ == JoinType::RIGHT_OUTER || join_type_ == JoinType::FULL_OUTER;
  const bool emit_unmatched_probe_rows =
      join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER;

  int minibatch_start = 0;
  while (minibatch_start < total_rows) {
    const uint32_t minibatch_rows =
        std::min(max_minibatch_rows, total_rows - minibatch_start);

    // Hash this mini-batch's keys and look them up; MapReadOnly fills the
    // match bit-vector and the matching key ids.
    SwissTableWithKeys::Input input(&key_batch, minibatch_start,
                                    minibatch_start + minibatch_rows, temp_stack,
                                    temp_column_arrays);
    hash_table_->keys()->Hash(&input, hashes.mutable_data(), hw_flags);
    hash_table_->keys()->MapReadOnly(&input, hashes.mutable_data(),
                                     match_bits.mutable_data(),
                                     lookup_key_ids.mutable_data());

    // AND the match bit-vector with the join's null-key filter.
    bool ignored;
    JoinNullFilter::Filter(key_batch, minibatch_start, minibatch_rows, *cmp_, &ignored,
                           /*and_with_input=*/true, match_bits.mutable_data());

    if (left_semi_or_anti || right_semi_or_anti) {
      if (right_semi_or_anti) {
        // Right semi/anti: the filter is only told about matches; no output here.
        RETURN_NOT_OK(residual_filter_->FilterRightSemiAnti(
            thread_id, keypayload_batch, minibatch_start, minibatch_rows,
            match_bits.mutable_data(), lookup_key_ids.mutable_data(),
            no_duplicate_keys, temp_stack));
      } else {
        // Left semi/anti: the filter writes the selected probe row ids into
        // out_batch_ids, which are then materialized probe-side only.
        int num_passing_ids = 0;
        if (join_type_ == JoinType::LEFT_SEMI) {
          RETURN_NOT_OK(residual_filter_->FilterLeftSemi(
              keypayload_batch, minibatch_start, minibatch_rows,
              match_bits.mutable_data(), lookup_key_ids.mutable_data(),
              no_duplicate_keys, temp_stack, &num_passing_ids,
              out_batch_ids.mutable_data()));
        } else {
          RETURN_NOT_OK(residual_filter_->FilterLeftAnti(
              keypayload_batch, minibatch_start, minibatch_rows,
              match_bits.mutable_data(), lookup_key_ids.mutable_data(),
              no_duplicate_keys, temp_stack, &num_passing_ids,
              out_batch_ids.mutable_data()));
        }
        RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly(
            keypayload_batch, num_passing_ids, out_batch_ids.mutable_data(),
            emit_batch));
      }
    } else {
      // A probe row may match several hash-table rows, so matches are produced
      // in chunks of at most max_minibatch_rows by JoinMatchIterator.
      JoinMatchIterator match_iterator;
      match_iterator.SetLookupResult(minibatch_rows, minibatch_start,
                                     match_bits.mutable_data(),
                                     lookup_key_ids.mutable_data(), no_duplicate_keys,
                                     hash_table_->key_to_payload());

      // When the residual filter needs it, filter_bits tracks which probe rows
      // still have a match after filtering; it starts from InitFilterBitVector.
      const bool use_filter_bitvector = residual_filter_->NeedFilterBitVector(join_type_);
      if (use_filter_bitvector) {
        residual_filter_->InitFilterBitVector(minibatch_rows, filter_bits.mutable_data());
      }

      int num_matches;
      while (match_iterator.GetNextBatch(max_minibatch_rows, &num_matches,
                                         out_batch_ids.mutable_data(),
                                         out_key_ids.mutable_data(),
                                         out_payload_ids.mutable_data())) {
        // The filter may shrink num_matches in place.
        RETURN_NOT_OK(residual_filter_->FilterInner(
            keypayload_batch, num_matches, out_batch_ids.mutable_data(),
            out_key_ids.mutable_data(), out_payload_ids.mutable_data(),
            !no_duplicate_keys, temp_stack, &num_matches));

        const uint16_t* batch_id_ptr = out_batch_ids.mutable_data();
        const uint32_t* key_id_ptr = out_key_ids.mutable_data();
        const uint32_t* payload_id_ptr =
            no_duplicate_keys ? key_id_ptr : out_payload_ids.mutable_data();

        if (use_filter_bitvector) {
          residual_filter_->UpdateFilterBitVector(minibatch_start, num_matches,
                                                  batch_id_ptr,
                                                  filter_bits.mutable_data());
        }
        // Right/full outer joins need to know which hash-table rows matched.
        if (update_has_match) {
          hash_table_->UpdateHasMatchForPayloads(thread_id, num_matches,
                                                 payload_id_ptr);
        }
        // Emit the matching (probe row, key, payload) tuples.
        RETURN_NOT_OK(materialize_[thread_id]->Append(keypayload_batch, num_matches,
                                                      batch_id_ptr, key_id_ptr,
                                                      payload_id_ptr, emit_batch));
      }

      if (emit_unmatched_probe_rows) {
        // The leading 0 presumably selects rows whose bit is cleared, i.e. probe
        // rows with no (surviving) match: read from filter_bits when the residual
        // filter is in use, otherwise from match_bits. These rows are emitted
        // with nulls in every column that comes from the other side of the join.
        uint8_t* survivor_bits =
            use_filter_bitvector ? filter_bits.mutable_data() : match_bits.mutable_data();
        int num_unmatched = 0;
        CollectPassingBatchIds(0, hw_flags, minibatch_start, minibatch_rows,
                               survivor_bits, &num_unmatched,
                               out_batch_ids.mutable_data());
        RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly(
            keypayload_batch, num_unmatched, out_batch_ids.mutable_data(), emit_batch));
      }
    }

    minibatch_start += minibatch_rows;
  }

  return Status::OK();
}