void FilterDatasetIterator::MaybeScheduleBackgroundTask()

in lib/data/filter_dataset.cc [61:173]


void FilterDatasetIterator::MaybeScheduleBackgroundTask(
    const ExecutionContext& exec_ctx, bool is_token_owner, int callback_count) {
  {
    mutex_lock lock(mu_);
    // There is no more output value to update. Release the token if the caller
    // owns the token and then return.
    if (output_buffer_.empty()) {
      if (is_token_owner) {
        token_owned_ = false;
      }
      return;
    }
    // Return since the token is already owned by another thread.
    if (!is_token_owner && token_owned_) return;
    // Take the token if the thread does not already own the token.
    token_owned_ = true;
  }
  // Only the thread that owns the token can execute the code below. This
  // ensures in-order delivery since at most one thread can take value from the
  // input_iterator_ and update the output value in the output_buffer_.

  // Fetch enough number of values from the `input_iterator_` to 1) satisfy the
  // values newly added in the `output_buffer_` and 2) compensate for the values
  // in the `input_and_predicate_buffer_` whose predicate value evaluates to
  // false. And schedule tasks to run the filter_fn for newly fetched values in
  // parallel.
  auto host = exec_ctx.host();
  int input_fetch_num = OutputBufferSize() -
                        input_and_predicate_buffer_.size() +
                        std::max(num_false_predicate_.load(), 0);
  const Function* filter_fn = parent_dataset_->filter_fn_.get();
  for (int i = 0; i < input_fetch_num; i++) {
    auto input = input_iterator_->GetNext(exec_ctx);
    auto predicate_values =
        RunFunctionWhenReady(filter_fn, input.CopyRef().values, exec_ctx);
    assert(predicate_values.size() == 1);

    predicate_values[0]->AndThen([predicate_values = predicate_values[0],
                                  iterator = FormRef(this)]() mutable {
      if (!predicate_values->IsError() && !predicate_values->get<bool>()) {
        iterator->num_false_predicate_.fetch_add(1);
      }
    });
    auto predicate = IterationResult::Pending(std::move(predicate_values),
                                              input.eof.CopyRef());
    input_and_predicate_buffer_.push(
        std::make_pair(std::move(input), std::move(predicate)));
  }
  // After the first value in the `input_and_predicate_buffer_` becomes
  // available, the token owner should update `output_buffer` as appropriate,
  // then call MaybeScheduleBackgroundTask() again to schedule more tasks if
  // there are still unfilled outputs.
  assert(!input_and_predicate_buffer_.empty() &&
         "input_and_predicate_buffer should not be empty");
  auto next = std::move(input_and_predicate_buffer_.front());
  input_and_predicate_buffer_.pop();
  auto input = std::move(next.first);
  auto predicate = std::move(next.second);

  llvm::SmallVector<AsyncValue*, 4> async_value_ptrs;
  for (auto& value : input.values) async_value_ptrs.push_back(value.get());
  async_value_ptrs.push_back(input.eof.GetAsyncValue());
  async_value_ptrs.push_back(predicate.values[0].get());
  async_value_ptrs.push_back(predicate.eof.GetAsyncValue());
  RunWhenReady(async_value_ptrs, [exec_ctx, host, callback_count,
                                  input = std::move(input),
                                  predicate = std::move(predicate),
                                  iterator = FormRef(this)]() mutable {
    auto predicate_value = std::move(predicate.values[0]);
    auto predicate_eof = std::move(predicate.eof);
    if (predicate_eof.IsError()) {
      auto output = iterator->DequeueOutputBuffer();
      for (auto& value : output.values) {
        value->SetError(predicate_eof.GetError());
      }
      output.eof.SetError(predicate_eof.GetError());
    } else if (predicate_eof.get()) {
      // The input_iterator_ has been exhausted. Note that predicate_eof and
      // input.eof should have the same value.
      auto error = MakeErrorAsyncValueRef(host, "iterator reached end");
      while (iterator->OutputBufferSize() > 0) {
        auto output = iterator->DequeueOutputBuffer();
        for (auto& value : output.values) {
          value->SetError(error->GetError());
        }
        output.eof.emplace(true);
      }
    } else if (predicate_value->IsError()) {
      auto output = iterator->DequeueOutputBuffer();
      for (auto& value : output.values) {
        value->SetError(predicate_value->GetError());
      }
      output.eof.SetError(predicate_value->GetError());
    } else if (predicate_value->get<bool>()) {
      // The input satisfies the predicate.
      auto output = iterator->DequeueOutputBuffer();
      for (int i = 0; i < output.values.size(); ++i) {
        auto* output_value = cast<IndirectAsyncValue>(output.values[i].get());
        output_value->ForwardTo(std::move(input.values[i]));
      }
      output.eof.emplace(false);
    } else {
      iterator->num_false_predicate_.fetch_sub(1);
    }
    if (callback_count >= MAX_RECURSIVE_CALLS) {
      EnqueueWork(exec_ctx, [exec_ctx, iterator = std::move(iterator)] {
        iterator->MaybeScheduleBackgroundTask(exec_ctx, true, 0);
      });
    } else {
      iterator->MaybeScheduleBackgroundTask(exec_ctx, true, callback_count + 1);
    }
  });
}