#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Worker operations executor.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file

import logging
import sys
import threading
import traceback
from enum import Enum
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple

from apache_beam.coders import TupleCoder
from apache_beam.coders import coders
from apache_beam.internal import util
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.pvalue import TaggedOutput
from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
from apache_beam.runners.sdf_utils import RestrictionTrackerView
from apache_beam.runners.sdf_utils import SplitResultPrimary
from apache_beam.runners.sdf_utils import SplitResultResidual
from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
from apache_beam.transforms import DoFn
from apache_beam.transforms import core
from apache_beam.transforms import userstate
from apache_beam.transforms.core import RestrictionProvider
from apache_beam.transforms.core import WatermarkEstimatorProvider
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.typehints.batch import BatchConverter
from apache_beam.utils.counters import Counter
from apache_beam.utils.counters import CounterName
from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import HomogeneousWindowedBatch
from apache_beam.utils.windowed_value import WindowedBatch
from apache_beam.utils.windowed_value import WindowedValue

if TYPE_CHECKING:
  from apache_beam.runners.worker.bundle_processor import ExecutionContext
  from apache_beam.transforms import sideinputs
  from apache_beam.transforms.core import TimerSpec
  from apache_beam.io.iobase import RestrictionProgress
  from apache_beam.iobase import RestrictionTracker
  from apache_beam.iobase import WatermarkEstimator

# Coder implementation for byte-string values wrapped in a WindowedValue
# whose windows are encoded with the GlobalWindowCoder.
IMPULSE_VALUE_CODER_IMPL = coders.WindowedValueCoder(
    coders.BytesCoder(), coders.GlobalWindowCoder()).get_impl()

# The empty byte string in the global window, nested-encoded once at import
# time with the coder implementation above.
ENCODED_IMPULSE_VALUE = IMPULSE_VALUE_CODER_IMPL.encode_nested(
    GlobalWindows.windowed_value(b''))

# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)


class NameContext(object):
  """Holds the name information for a step.

  Two contexts are equal (and hash alike) when their ``step_name`` matches;
  ``transform_id`` takes no part in equality or hashing.
  """
  def __init__(self, step_name, transform_id=None):
    # type: (str, Optional[str]) -> None

    """Creates a new step NameContext.

    Args:
      step_name: The name of the step.
      transform_id: Optional transform id; stored on the instance as given.
    """
    self.step_name = step_name
    self.transform_id = transform_id

  def __eq__(self, other):
    """Compares by ``step_name``.

    Returns ``NotImplemented`` for objects without a ``step_name`` attribute
    so that comparisons such as ``ctx == None`` evaluate to ``False`` instead
    of raising ``AttributeError``.
    """
    if not hasattr(other, 'step_name'):
      return NotImplemented
    return self.step_name == other.step_name

  def __repr__(self):
    return 'NameContext(%s)' % self.__dict__

  def __hash__(self):
    # Must stay consistent with __eq__, which only looks at step_name.
    return hash(self.step_name)

  def metrics_name(self):
    """Returns the step name used for metrics reporting."""
    return self.step_name

  def logging_name(self):
    """Returns the step name used for logging."""
    return self.step_name


class Receiver(object):
  """For internal use only; no backwards-compatibility guarantees.

  Abstract consumer of ``WindowedValue`` objects (and batches of them).

  Intended as an efficient hand-off point for values travelling between the
  sdk and worker harnesses. Every method here raises ``NotImplementedError``;
  concrete receivers override them.
  """
  def receive(self, windowed_value):
    # type: (WindowedValue) -> None

    """Consumes a single windowed value. Must be overridden."""
    raise NotImplementedError

  def receive_batch(self, windowed_batch):
    # type: (WindowedBatch) -> None

    """Consumes a batch of windowed values. Must be overridden."""
    raise NotImplementedError

  def flush(self):
    """Flush hook with no base implementation. Must be overridden."""
    raise NotImplementedError


class MethodWrapper(object):
  """For internal use only; no backwards-compatibility guarantees.

  Represents a method that can be invoked by `DoFnInvoker`."""
  def __init__(self, obj_to_invoke, method_name):
    """
    Initiates a ``MethodWrapper``.

    Looks up the method's argument names and default values and records which
    argument names carry special ``DoFn`` parameter defaults (state, timers,
    timestamp, window, key, restriction, watermark estimator, dynamic timer
    tag) so they can be substituted at invocation time.

    Args:
      obj_to_invoke: the object that contains the method. Has to be a `DoFn`,
                    a `RestrictionProvider` or a `WatermarkEstimatorProvider`
                    object.
      method_name: name of the method as a string.

    Raises:
      ValueError: if ``obj_to_invoke`` is none of the supported types.
    """

    if not isinstance(obj_to_invoke,
                      (DoFn, RestrictionProvider, WatermarkEstimatorProvider)):
      # The argument is wrapped in a 1-tuple so that a tuple-valued
      # obj_to_invoke still yields this ValueError rather than a formatting
      # TypeError.
      raise ValueError(
          '\'obj_to_invoke\' has to be either a \'DoFn\', '
          'a \'RestrictionProvider\' or a \'WatermarkEstimatorProvider\'. '
          'Received %r instead.' % (obj_to_invoke, ))

    self.args, self.defaults = core.get_function_arguments(obj_to_invoke,
                                                           method_name)
    # TODO(BEAM-5878) support kwonlyargs on Python 3.
    self.method_value = getattr(obj_to_invoke, method_name)
    self.method_name = method_name

    self.has_userstate_arguments = False
    self.state_args_to_replace = {}  # type: Dict[str, core.StateSpec]
    self.timer_args_to_replace = {}  # type: Dict[str, core.TimerSpec]
    self.timestamp_arg_name = None  # type: Optional[str]
    self.window_arg_name = None  # type: Optional[str]
    self.key_arg_name = None  # type: Optional[str]
    self.restriction_provider = None
    self.restriction_provider_arg_name = None
    self.watermark_estimator_provider = None
    self.watermark_estimator_provider_arg_name = None
    self.dynamic_timer_tag_arg_name = None

    # True when the method object carries an 'unbounded_per_element'
    # attribute, whatever its value.
    self.unbounded_per_element = hasattr(
        self.method_value, 'unbounded_per_element')

    # Defaults belong to the trailing arguments, so pair them up from the end.
    # With no defaults the zip is empty because it stops at the shorter input.
    for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
      if isinstance(v, core.DoFn.StateParam):
        self.state_args_to_replace[kw] = v.state_spec
        self.has_userstate_arguments = True
      elif isinstance(v, core.DoFn.TimerParam):
        self.timer_args_to_replace[kw] = v.timer_spec
        self.has_userstate_arguments = True
      elif core.DoFn.TimestampParam == v:
        self.timestamp_arg_name = kw
      elif core.DoFn.WindowParam == v:
        self.window_arg_name = kw
      elif core.DoFn.WindowedValueParam == v:
        # NOTE(review): recorded under the same name slot as WindowParam;
        # confirm this is intended.
        self.window_arg_name = kw
      elif core.DoFn.KeyParam == v:
        self.key_arg_name = kw
      elif isinstance(v, core.DoFn.RestrictionParam):
        # A param without its own provider falls back to the wrapped object.
        self.restriction_provider = v.restriction_provider or obj_to_invoke
        self.restriction_provider_arg_name = kw
      elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
        self.watermark_estimator_provider = (
            v.watermark_estimator_provider or obj_to_invoke)
        self.watermark_estimator_provider_arg_name = kw
      elif core.DoFn.DynamicTimerTagParam == v:
        self.dynamic_timer_tag_arg_name = kw

    # Create NoOpWatermarkEstimatorProvider if there is no
    # WatermarkEstimatorParam provided.
    if self.watermark_estimator_provider is None:
      self.watermark_estimator_provider = NoOpWatermarkEstimatorProvider()

  def invoke_timer_callback(
      self,
      user_state_context,
      key,
      window,
      timestamp,
      pane_info,
      dynamic_timer_tag):
    """Invokes the wrapped method as a timer callback.

    Only the special arguments recorded in ``__init__`` are passed, each as a
    keyword argument: state and timer objects obtained from
    ``user_state_context``, the timestamp (converted with ``Timestamp.of``),
    the window, the key and the dynamic timer tag.

    Returns:
      Whatever the wrapped method returns.
    """
    # TODO(ccy): support side inputs.
    kwargs = {}
    if self.has_userstate_arguments:
      for kw, state_spec in self.state_args_to_replace.items():
        kwargs[kw] = user_state_context.get_state(state_spec, key, window)
      for kw, timer_spec in self.timer_args_to_replace.items():
        kwargs[kw] = user_state_context.get_timer(
            timer_spec, key, window, timestamp, pane_info)

    if self.timestamp_arg_name:
      kwargs[self.timestamp_arg_name] = Timestamp.of(timestamp)
    if self.window_arg_name:
      kwargs[self.window_arg_name] = window
    if self.key_arg_name:
      kwargs[self.key_arg_name] = key
    if self.dynamic_timer_tag_arg_name:
      kwargs[self.dynamic_timer_tag_arg_name] = dynamic_timer_tag

    # Unpacking an empty dict is the same as calling with no arguments.
    return self.method_value(**kwargs)


class BatchingPreference(Enum):
  """Whether an operation consumes batches, single elements, or either."""
  # This operation can operate on batches or element-at-a-time
  DO_NOT_CARE = 1
  # TODO: Should we also store batching parameters here? (time/size preferences)
  # This operation can only operate on batches
  BATCH_REQUIRED = 2
  # This operation can only work element-at-a-time
  BATCH_FORBIDDEN = 3
  # Other possibilities: BATCH_PREFERRED (with min batch size specified)

  @property
  def supports_batches(self) -> bool:
    """True for BATCH_REQUIRED and DO_NOT_CARE."""
    return (
        self is BatchingPreference.BATCH_REQUIRED or
        self is BatchingPreference.DO_NOT_CARE)

  @property
  def supports_elements(self) -> bool:
    """True for BATCH_FORBIDDEN and DO_NOT_CARE."""
    return (
        self is BatchingPreference.BATCH_FORBIDDEN or
        self is BatchingPreference.DO_NOT_CARE)

  @property
  def requires_batches(self) -> bool:
    """True only for BATCH_REQUIRED."""
    return self is BatchingPreference.BATCH_REQUIRED


class DoFnSignature(object):
  """Represents the signature of a given ``DoFn`` object.

  Signature of a ``DoFn`` provides a view of the properties of a given ``DoFn``.
  Among other things, this will give an extensible way for (1) accessing the
  structure of the ``DoFn`` including methods and method parameters
  (2) identifying features that a given ``DoFn`` support, for example, whether
  a given ``DoFn`` is a Splittable ``DoFn`` (
  https://s.apache.org/splittable-do-fn) (3) validating a ``DoFn`` based on the
  feature set offered by it.
  """
  def __init__(self, do_fn):
    # type: (core.DoFn) -> None

    """Wraps the lifecycle methods of ``do_fn`` and validates them.

    Args:
      do_fn: the ``DoFn`` instance to describe.

    Raises:
      ValueError, NotImplementedError: from the validation helpers when a
        method declares parameters it may not use.
    """
    # We add a property here for all methods defined by Beam DoFn features.

    assert isinstance(do_fn, core.DoFn)
    self.do_fn = do_fn

    self.process_method = MethodWrapper(do_fn, 'process')
    self.process_batch_method = MethodWrapper(do_fn, 'process_batch')
    self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
    self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')
    self.setup_lifecycle_method = MethodWrapper(do_fn, 'setup')
    self.teardown_lifecycle_method = MethodWrapper(do_fn, 'teardown')

    restriction_provider = self.get_restriction_provider()
    watermark_estimator_provider = self.get_watermark_estimator_provider()
    self.create_watermark_estimator_method = (
        MethodWrapper(
            watermark_estimator_provider, 'create_watermark_estimator'))
    # The restriction-related wrappers only exist for splittable DoFns, i.e.
    # when process() declared a restriction parameter.
    self.initial_restriction_method = (
        MethodWrapper(restriction_provider, 'initial_restriction')
        if restriction_provider else None)
    self.create_tracker_method = (
        MethodWrapper(restriction_provider, 'create_tracker')
        if restriction_provider else None)
    self.split_method = (
        MethodWrapper(restriction_provider, 'split')
        if restriction_provider else None)

    self._validate()

    # Handle stateful DoFns.
    self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
    self.timer_methods = {}  # type: Dict[TimerSpec, MethodWrapper]
    if self._is_stateful_dofn:
      # Populate timer firing methods, keyed by TimerSpec.
      _, all_timer_specs = userstate.get_dofn_specs(do_fn)
      for timer_spec in all_timer_specs:
        method = timer_spec._attached_callback
        self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__)

  def get_restriction_provider(self):
    # type: () -> Optional[RestrictionProvider]

    """Returns the restriction provider of process(), or None if it has none.
    """
    return self.process_method.restriction_provider

  def get_watermark_estimator_provider(self):
    # type: () -> WatermarkEstimatorProvider

    """Returns the watermark estimator provider recorded for process()."""
    return self.process_method.watermark_estimator_provider

  def is_unbounded_per_element(self):
    """Returns whether process() is flagged ``unbounded_per_element``."""
    return self.process_method.unbounded_per_element

  def _validate(self):
    # type: () -> None

    """Runs every validation helper against the wrapped ``DoFn``."""
    self._validate_process()
    self._validate_process_batch()
    self._validate_bundle_method(self.start_bundle_method)
    self._validate_bundle_method(self.finish_bundle_method)
    self._validate_stateful_dofn()

  def _check_duplicate_dofn_params(self, method: MethodWrapper):
    """Raises ValueError if two defaults of ``method`` share a ``param_id``."""
    param_ids = [
        d.param_id for d in method.defaults if isinstance(d, core._DoFnParam)
    ]
    if len(param_ids) != len(set(param_ids)):
      raise ValueError(
          'DoFn %r has duplicate %s method parameters: %s.' %
          (self.do_fn, method.method_name, param_ids))

  def _validate_process(self):
    # type: () -> None

    """Validate that none of the DoFnParameters are repeated in the function
    """
    self._check_duplicate_dofn_params(self.process_method)

  def _validate_process_batch(self):
    # type: () -> None

    """Validates the DoFn params declared by process_batch().

    Raises:
      ValueError: on duplicate params or on a param that is not supported.
      NotImplementedError: for ElementParam and for the per-key params
        (KeyParam, StateParam, TimerParam), which are not supported yet.
    """
    self._check_duplicate_dofn_params(self.process_batch_method)

    for d in self.process_batch_method.defaults:
      if not isinstance(d, core._DoFnParam):
        continue

      # Helpful errors for params which will be supported in the future
      if d == (core.DoFn.ElementParam):
        # We currently assume we can just get the typehint from the first
        # parameter. ElementParam breaks this assumption
        raise NotImplementedError(
            f"DoFn {self.do_fn!r} uses unsupported DoFn param ElementParam.")

      # StateParam and TimerParam are classes whose instances appear as
      # defaults, so they need an isinstance check; KeyParam is compared by
      # equality like elsewhere in this module.
      if (d == core.DoFn.KeyParam or
          isinstance(d, (core.DoFn.StateParam, core.DoFn.TimerParam))):
        raise NotImplementedError(
            f"DoFn {self.do_fn!r} has unsupported per-key DoFn param {d}. "
            "Per-key DoFn params are not yet supported for process_batch "
            "(https://github.com/apache/beam/issues/21653).")

      # Fallback to catch anything not explicitly supported
      if d not in (core.DoFn.WindowParam,
                   core.DoFn.TimestampParam,
                   core.DoFn.PaneInfoParam):
        raise ValueError(
            f"DoFn {self.do_fn!r} has unsupported process_batch "
            f"method parameter {d}")

  def _validate_bundle_method(self, method_wrapper: MethodWrapper):
    """Validate that none of the DoFnParameters are used in the function
    """
    for param in core.DoFn.DoFnProcessParams:
      if param in method_wrapper.defaults:
        # Report the method's name rather than the wrapper object's repr.
        raise ValueError(
            'DoFn.process() method-only parameter %s cannot be used in %s.' %
            (param, method_wrapper.method_name))

  def _validate_stateful_dofn(self):
    # type: () -> None

    """Delegates state/timer validation to the ``userstate`` module."""
    userstate.validate_stateful_dofn(self.do_fn)

  def is_splittable_dofn(self):
    # type: () -> bool

    """Returns True when process() declared a restriction provider."""
    return self.get_restriction_provider() is not None

  def get_restriction_coder(self):
    # type: () -> Optional[TupleCoder]

    """Get coder for a restriction when processing an SDF.

    Returns:
      A ``TupleCoder`` of (restriction coder, estimator state coder) for a
      splittable DoFn, otherwise None.
    """
    if not self.is_splittable_dofn():
      return None
    return TupleCoder([
        (self.get_restriction_provider().restriction_coder()),
        (self.get_watermark_estimator_provider().estimator_state_coder())
    ])

  def is_stateful_dofn(self):
    # type: () -> bool

    """Returns the stateful flag computed in ``__init__``."""
    return self._is_stateful_dofn

  def has_timers(self):
    # type: () -> bool

    """Returns True when the DoFn declares at least one timer spec."""
    _, all_timer_specs = userstate.get_dofn_specs(self.do_fn)
    return bool(all_timer_specs)

  def has_bundle_finalization(self):
    """Returns True if start_bundle, process or finish_bundle takes a
    ``BundleFinalizerParam`` default."""
    for sig in (self.start_bundle_method,
                self.process_method,
                self.finish_bundle_method):
      for d in sig.defaults:
        try:
          if d == DoFn.BundleFinalizerParam:
            return True
        except Exception:  # pylint: disable=broad-except
          # Default value might be incomparable.
          pass
    return False

  def _iter_unique_context_params(self, param_type):
    """Yields each distinct default of type ``param_type`` once.

    All lifecycle, bundle and process methods are scanned, in that order.
    Defaults that cannot be type-checked, hashed or compared are skipped.
    """
    seen = set()
    for sig in (self.setup_lifecycle_method,
                self.start_bundle_method,
                self.process_method,
                self.process_batch_method,
                self.finish_bundle_method,
                self.teardown_lifecycle_method):
      for d in sig.defaults:
        try:
          is_new = isinstance(d, param_type) and d not in seen
          if is_new:
            seen.add(d)
        except Exception:  # pylint: disable=broad-except
          # Default value might be incomparable.
          continue
        if is_new:
          yield d

  def get_bundle_contexts(self):
    """Iterates over the distinct ``BundleContextParam`` defaults in use."""
    return self._iter_unique_context_params(DoFn.BundleContextParam)

  def get_setup_contexts(self):
    """Iterates over the distinct ``SetupContextParam`` defaults in use."""
    return self._iter_unique_context_params(DoFn.SetupContextParam)


class DoFnInvoker(object):
  """An abstraction that can be used to execute DoFn methods.

  A DoFnInvoker describes a particular way for invoking methods of a DoFn
  represented by a given DoFnSignature."""

  def __init__(self,
               output_handler,  # type: _OutputHandler
               signature  # type: DoFnSignature
              ):
    # type: (...) -> None

    """
    Initializes `DoFnInvoker`

    :param output_handler: an OutputHandler for receiving elements produced
                             by invoking functions of the DoFn.
    :param signature: a DoFnSignature for the DoFn being invoked
    """
    self.output_handler = output_handler
    self.signature = signature
    self.user_state_context = None  # type: Optional[userstate.UserStateContext]
    self.bundle_finalizer_param = None  # type: Optional[core._BundleFinalizerParam]
    # Entered context params, keyed by param. Filled by invoke_setup() and
    # invoke_start_bundle(), reset to None by invoke_teardown() and
    # invoke_finish_bundle(). Starting at None lets the exit side run safely
    # without a matching enter.
    self._setup_context_values = None
    self._bundle_context_values = None

  @staticmethod
  def create_invoker(
      signature,  # type: DoFnSignature
      output_handler,  # type: OutputHandler
      context=None,  # type: Optional[DoFnContext]
      side_inputs=None,   # type: Optional[List[sideinputs.SideInputMap]]
      input_args=None, input_kwargs=None,
      process_invocation=True,
      user_state_context=None,  # type: Optional[userstate.UserStateContext]
      bundle_finalizer_param=None  # type: Optional[core._BundleFinalizerParam]
  ):
    # type: (...) -> DoFnInvoker

    """ Creates a new DoFnInvoker based on given arguments.

    Args:
        output_handler: an OutputHandler for receiving elements produced by
                          invoking functions of the DoFn.
        signature: a DoFnSignature for the DoFn being invoked.
        context: Context to be used when invoking the DoFn (deprecated).
        side_inputs: side inputs to be used when invoking the process method.
        input_args: arguments to be used when invoking the process method. Some
                    of the arguments given here might be placeholders (for
                    example for side inputs) that get filled before invoking the
                    process method.
        input_kwargs: keyword arguments to be used when invoking the process
                      method. Some of the keyword arguments given here might be
                      placeholders (for example for side inputs) that get filled
                      before invoking the process method.
        process_invocation: If True, this function may return an invoker that
                            performs extra optimizations for invoking process()
                            method efficiently.
        user_state_context: The UserStateContext instance for the current
                            Stateful DoFn.
        bundle_finalizer_param: The param that passed to a process method, which
                                allows a callback to be registered.

    Returns:
        A SimpleInvoker when nothing requires per-window handling, otherwise a
        PerWindowInvoker.

    Raises:
        TypeError: if a PerWindowInvoker is needed but no context was given.
    """
    side_inputs = side_inputs or []
    use_per_window_invoker = process_invocation and (
        side_inputs or input_args or input_kwargs or
        signature.process_method.defaults or
        signature.process_batch_method.defaults or signature.is_stateful_dofn())
    if not use_per_window_invoker:
      return SimpleInvoker(output_handler, signature)

    if context is None:
      raise TypeError("Must provide context when not using SimpleInvoker")
    return PerWindowInvoker(
        output_handler,
        signature,
        context,
        side_inputs,
        input_args,
        input_kwargs,
        user_state_context,
        bundle_finalizer_param)

  def invoke_process(self,
                     windowed_value,  # type: WindowedValue
                     restriction=None,
                     watermark_estimator_state=None,
                     additional_args=None,
                     additional_kwargs=None
                    ):
    # type: (...) -> Iterable[SplitResultResidual]

    """Invokes the DoFn.process() function.

    Args:
      windowed_value: a WindowedValue object that gives the element for which
                      process() method should be invoked along with the window
                      the element belongs to.
      restriction: The restriction to use when executing this splittable DoFn.
                   Should only be specified for splittable DoFns.
      watermark_estimator_state: The watermark estimator state to use when
                                 executing this splittable DoFn. Should only
                                 be specified for splittable DoFns.
      additional_args: additional arguments to be passed to the current
                      `DoFn.process()` invocation, usually as side inputs.
      additional_kwargs: additional keyword arguments to be passed to the
                         current `DoFn.process()` invocation.
    """
    raise NotImplementedError

  def invoke_process_batch(self,
                     windowed_batch,  # type: WindowedBatch
                     additional_args=None,
                     additional_kwargs=None
                    ):
    # type: (...) -> None

    """Invokes the DoFn.process_batch() function.

    Args:
      windowed_batch: a WindowedBatch object that gives a batch of elements for
                      which process_batch() method should be invoked, along with
                      the window each element belongs to.
      additional_args: additional arguments to be passed to the current
                      `DoFn.process_batch()` invocation, usually as side inputs.
      additional_kwargs: additional keyword arguments to be passed to the
                         current `DoFn.process_batch()` invocation.
    """
    raise NotImplementedError

  @staticmethod
  def _exit_contexts(context_values):
    """Exits every entered context in ``context_values``; None is a no-op.

    Each value is whatever ``create_and_enter()`` returned; its first item is
    the object whose ``__exit__`` gets called.
    """
    if not context_values:
      return
    for entered in context_values.values():
      entered[0].__exit__(None, None, None)

  def invoke_setup(self):
    # type: () -> None

    """Invokes the DoFn.setup() method

    Setup-scoped context params are entered first.
    """
    self._setup_context_values = {
        c: c.create_and_enter()
        for c in self.signature.get_setup_contexts()
    }
    self.signature.setup_lifecycle_method.method_value()

  def invoke_start_bundle(self):
    # type: () -> None

    """Invokes the DoFn.start_bundle() method.

    Bundle-scoped context params are entered first; the method's result is
    handed to the output handler.
    """
    self._bundle_context_values = {
        c: c.create_and_enter()
        for c in self.signature.get_bundle_contexts()
    }
    self.output_handler.start_bundle_outputs(
        self.signature.start_bundle_method.method_value())

  def invoke_finish_bundle(self):
    # type: () -> None

    """Invokes the DoFn.finish_bundle() method.

    Afterwards the bundle-scoped contexts are exited. Calling this without a
    preceding invoke_start_bundle() simply has no contexts to exit.
    """
    self.output_handler.finish_bundle_outputs(
        self.signature.finish_bundle_method.method_value())
    self._exit_contexts(self._bundle_context_values)
    self._bundle_context_values = None

  def invoke_teardown(self):
    # type: () -> None

    """Invokes the DoFn.teardown() method

    Afterwards the setup-scoped contexts are exited. Safe to call without a
    preceding invoke_setup() and safe to call more than once.
    """
    self.signature.teardown_lifecycle_method.method_value()
    self._exit_contexts(self._setup_context_values)
    self._setup_context_values = None

  def invoke_user_timer(
      self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag):
    """Fires the callback registered for ``timer_spec`` and emits its outputs.

    The outputs are handled relative to a value-less WindowedValue carrying
    the firing timestamp and window.
    """
    # self.output_handler is Optional, but in practice it won't be None here
    self.output_handler.handle_process_outputs(
        WindowedValue(None, timestamp, (window, )),
        self.signature.timer_methods[timer_spec].invoke_timer_callback(
            self.user_state_context,
            key,
            window,
            timestamp,
            pane_info,
            dynamic_timer_tag))

  def invoke_create_watermark_estimator(self, estimator_state):
    """Calls the provider's create_watermark_estimator(estimator_state)."""
    return self.signature.create_watermark_estimator_method.method_value(
        estimator_state)

  def invoke_split(self, element, restriction):
    """Calls the restriction provider's split(element, restriction)."""
    return self.signature.split_method.method_value(element, restriction)

  def invoke_initial_restriction(self, element):
    """Calls the restriction provider's initial_restriction(element)."""
    return self.signature.initial_restriction_method.method_value(element)

  def invoke_create_tracker(self, restriction):
    """Calls the restriction provider's create_tracker(restriction)."""
    return self.signature.create_tracker_method.method_value(restriction)


class SimpleInvoker(DoFnInvoker):
  """An invoker that processes elements ignoring windowing information.

  The user methods receive only the bare value (or batch of values); the
  original windowed wrapper is passed to the output handler together with the
  method's result.
  """

  def __init__(self,
               output_handler,  # type: OutputHandler
               signature  # type: DoFnSignature
              ):
    # type: (...) -> None
    super().__init__(output_handler, signature)
    # Keep direct references to the user methods so each invocation skips the
    # signature lookup.
    self.process_batch_method = signature.process_batch_method.method_value
    self.process_method = signature.process_method.method_value

  def invoke_process(self,
                     windowed_value,  # type: WindowedValue
                     restriction=None,
                     watermark_estimator_state=None,
                     additional_args=None,
                     additional_kwargs=None
                    ):
    # type: (...) -> Iterable[SplitResultResidual]

    """Calls process() with the element's value and forwards the outputs.

    Every argument other than ``windowed_value`` is ignored. Always returns an
    empty list.
    """
    handle_outputs = self.output_handler.handle_process_outputs
    element_outputs = self.process_method(windowed_value.value)
    handle_outputs(windowed_value, element_outputs)
    return []

  def invoke_process_batch(self,
                     windowed_batch,  # type: WindowedBatch
                     restriction=None,
                     watermark_estimator_state=None,
                     additional_args=None,
                     additional_kwargs=None
                    ):
    # type: (...) -> None

    """Calls process_batch() with the batch's values and forwards the outputs.

    Every argument other than ``windowed_batch`` is ignored.
    """
    handle_outputs = self.output_handler.handle_process_batch_outputs
    batch_outputs = self.process_batch_method(windowed_batch.values)
    handle_outputs(windowed_batch, batch_outputs)


def _get_arg_placeholders(
    method: MethodWrapper,
    input_args: Optional[List[Any]],
    input_kwargs: Optional[Dict[str, Any]]):
  """Builds the positional argument template for invoking ``method``.

  Special ``DoFn`` parameters (element, key, window, timestamp, pane info,
  state, timers, bundle finalizer, bundle/setup contexts) become placeholder
  entries; every other position is filled from ``input_args`` in order.

  Args:
    method: wrapper of the method to be invoked; its ``args`` and ``defaults``
      drive the layout.
    input_args: positional arguments supplied by the caller, or None. Any
      sequence is accepted; it is copied into a list.
    input_kwargs: keyword arguments supplied by the caller, or None.

  Returns:
    A 3-tuple ``(placeholders, args_with_placeholders, input_kwargs)`` where
    ``placeholders`` is a list of ``(index, param)`` pairs locating the
    placeholder entries inside ``args_with_placeholders``, and
    ``input_kwargs`` is the given mapping (or a new empty dict).

  Raises:
    ValueError: if a side input parameter receives a value neither
      positionally nor by keyword.
  """
  input_args = list(input_args) if input_args else []
  input_kwargs = input_kwargs if input_kwargs else {}

  arg_names = method.args
  default_arg_values = method.defaults

  # Create placeholder for element parameter of DoFn.process() method.
  # Not to be confused with ArgumentPlaceHolder, which may be passed in
  # input_args and is a placeholder for side-inputs.
  class ArgPlaceholder(object):
    def __init__(self, placeholder):
      self.placeholder = placeholder

  if all(core.DoFn.ElementParam != arg for arg in default_arg_values):
    # TODO(https://github.com/apache/beam/issues/19631): Handle cases in which
    #   len(arg_names) == len(default_arg_values).
    # The first non-default argument is the element itself.
    args_to_pick = len(arg_names) - len(default_arg_values) - 1
    # Positional argument values for process(), with placeholders for special
    # values such as the element, timestamp, etc.
    args_with_placeholders = ([ArgPlaceholder(core.DoFn.ElementParam)] +
                              input_args[:args_to_pick])
  else:
    args_to_pick = len(arg_names) - len(default_arg_values)
    args_with_placeholders = input_args[:args_to_pick]

  # Params matched by equality, tested before the side input check.
  equality_params = (
      core.DoFn.ElementParam,
      core.DoFn.KeyParam,
      core.DoFn.WindowParam,
      core.DoFn.WindowedValueParam,
      core.DoFn.TimestampParam,
      core.DoFn.PaneInfoParam)

  # Fill the OtherPlaceholders for context, key, window or timestamp
  remaining_args_iter = iter(input_args[args_to_pick:])
  for a, d in zip(arg_names[-len(default_arg_values):], default_arg_values):
    if any(param == d for param in equality_params):
      args_with_placeholders.append(ArgPlaceholder(d))
    elif core.DoFn.SideInputParam == d:
      # If no more args are present then the value must be passed via kwarg
      try:
        args_with_placeholders.append(next(remaining_args_iter))
      except StopIteration:
        if a not in input_kwargs:
          raise ValueError("Value for sideinput %s not provided" % a)
    elif (isinstance(d, (core.DoFn.StateParam, core.DoFn.TimerParam)) or
          (isinstance(d, type) and core.DoFn.BundleFinalizerParam == d) or
          isinstance(d,
                     (core.DoFn.BundleContextParam,
                      core.DoFn.SetupContextParam))):
      args_with_placeholders.append(ArgPlaceholder(d))
    else:
      # If no more args are present then the value must be passed via kwarg
      try:
        args_with_placeholders.append(next(remaining_args_iter))
      except StopIteration:
        pass
  # Whatever positional values are left over go at the end.
  args_with_placeholders.extend(remaining_args_iter)

  # Stash the list of placeholder positions for performance
  placeholders = [(i, x.placeholder)
                  for (i, x) in enumerate(args_with_placeholders)
                  if isinstance(x, ArgPlaceholder)]

  return placeholders, args_with_placeholders, input_kwargs


class PerWindowInvoker(DoFnInvoker):
  """An invoker that processes elements considering windowing information."""

  def __init__(self,
               output_handler,  # type: OutputHandler
               signature,  # type: DoFnSignature
               context,  # type: DoFnContext
               side_inputs,  # type: Iterable[sideinputs.SideInputMap]
               input_args,
               input_kwargs,
               user_state_context,  # type: Optional[userstate.UserStateContext]
               bundle_finalizer_param  # type: Optional[core._BundleFinalizerParam]
              ):
    """Builds the argument templates used to call process/process_batch.

    Args:
      output_handler: forwarded to the base invoker together with signature.
      signature: the DoFnSignature whose process and process_batch methods
        are invoked by this object.
      context: the DoFnContext that is updated with the current element.
      side_inputs: side input maps; each one is indexed by window when the
        call arguments are assembled.
      input_args: positional arguments handed to _get_arg_placeholders.
      input_kwargs: keyword arguments handed to _get_arg_placeholders.
      user_state_context: state/timer accessor, or None.
      bundle_finalizer_param: value substituted for BundleFinalizerParam
        placeholders.
    """
    super().__init__(output_handler, signature)
    process_method_wrapper = signature.process_method
    process_defaults = process_method_wrapper.defaults

    self.side_inputs = side_inputs
    self.context = context
    self.process_method = process_method_wrapper.method_value

    # Per-window invocation is needed as soon as anything can differ between
    # windows: a side input that is not globally windowed, a WindowParam in
    # either process signature, or a stateful DoFn.
    self.has_windowed_inputs = (
        any(not si.is_globally_windowed() for si in side_inputs) or
        any(core.DoFn.WindowParam == d for d in process_defaults) or
        any(
            core.DoFn.WindowParam == d
            for d in signature.process_batch_method.defaults) or
        signature.is_stateful_dofn())

    self.user_state_context = user_state_context
    self.bundle_finalizer_param = bundle_finalizer_param
    self.is_splittable = signature.is_splittable_dofn()
    self.is_key_param_required = any(
        core.DoFn.KeyParam == d for d in process_defaults)

    # Populated only while an element is being processed.
    self.threadsafe_restriction_tracker = None  # type: Optional[ThreadsafeRestrictionTracker]
    self.threadsafe_watermark_estimator = None  # type: Optional[ThreadsafeWatermarkEstimator]
    self.current_windowed_value = None  # type: Optional[WindowedValue]

    if self.is_splittable:
      # Guards the state shared between processing and try_split /
      # current_element_progress.
      self.splitting_lock = threading.Lock()
      self.current_window_index = None
      self.stop_window_index = None

    # TODO(https://github.com/apache/beam/issues/28776): Remove caching after
    # fully rolling out.
    # When recalculate_window_args is true the window-dependent arguments are
    # rebuilt on every call. Otherwise they are computed once, stored back
    # into self.args_for_process / self.args_for_process_batch, and the
    # matching has_cached_window_args / has_cached_window_batch_args flag is
    # flipped so later calls reuse them directly.
    self.recalculate_window_args = (
        self.has_windowed_inputs or
        'disable_global_windowed_args_caching' in
        RuntimeValueProvider.experiments)
    self.has_cached_window_args = False
    self.has_cached_window_batch_args = False

    # Fill in every argument that is already known and remember where the
    # remaining placeholders sit, so the per-element path only has to patch
    # those positions.
    input_args = list(input_args)
    process_template = _get_arg_placeholders(
        signature.process_method, input_args, input_kwargs)
    (
        self.placeholders_for_process,
        self.args_for_process,
        self.kwargs_for_process) = process_template

    self.process_batch_method = signature.process_batch_method.method_value

    batch_template = _get_arg_placeholders(
        signature.process_batch_method, input_args, input_kwargs)
    (
        self.placeholders_for_process_batch,
        self.args_for_process_batch,
        self.kwargs_for_process_batch) = batch_template

  def invoke_process(self,
                     windowed_value,  # type: WindowedValue
                     restriction=None,
                     watermark_estimator_state=None,
                     additional_args=None,
                     additional_kwargs=None
                    ):
    # type: (...) -> Iterable[SplitResultResidual]
    """Invokes the process method for ``windowed_value``.

    The element is processed once per window when ``has_windowed_inputs`` is
    set and the value is in several windows, and exactly once otherwise.

    Args:
      windowed_value: the element to process.
      restriction: restriction to process for a splittable DoFn. When None
        and the DoFn is splittable, the initial restriction of the element
        is computed and used instead.
      watermark_estimator_state: state used to create the watermark estimator
        for a splittable DoFn.
      additional_args: extra positional values appended after the side input
        values when arguments are assembled; treated as empty when falsy.
      additional_kwargs: extra keyword arguments for the process call;
        treated as empty when falsy.

    Returns:
      The list of residuals produced by the per-window invocations. It is
      always empty for a DoFn that is not splittable.
    """
    if not additional_args:
      additional_args = []
    if not additional_kwargs:
      additional_kwargs = {}

    self.context.set_element(windowed_value)
    # Call for the process function for each window if has windowed side inputs
    # or if the process accesses the window parameter. We can just call it once
    # otherwise as none of the arguments are changing

    residuals = []
    if self.is_splittable:
      if restriction is None:
        # This may be a SDF invoked as an ordinary DoFn on runners that don't
        # understand SDF.  See, e.g. BEAM-11472.
        # In this case, processing the element is simply processing it against
        # the entire initial restriction.
        restriction = self.signature.initial_restriction_method.method_value(
            windowed_value.value)

      # Publish the element being processed under the lock so that try_split
      # observes a consistent set of fields.
      with self.splitting_lock:
        self.current_windowed_value = windowed_value
        self.restriction = restriction
        self.watermark_estimator_state = watermark_estimator_state
      try:
        if self.has_windowed_inputs and len(windowed_value.windows) > 1:
          for i, w in enumerate(windowed_value.windows):
            # A False result means the stop index was reached, so the
            # remaining windows are not processed here.
            if not self._should_process_window_for_sdf(
                windowed_value, additional_kwargs, i):
              break
            residual = self._invoke_process_per_window(
                WindowedValue(
                    windowed_value.value, windowed_value.timestamp, (w, )),
                additional_args,
                additional_kwargs)
            if residual:
              residuals.append(residual)
        else:
          if self._should_process_window_for_sdf(windowed_value,
                                                 additional_kwargs):
            residual = self._invoke_process_per_window(
                windowed_value, additional_args, additional_kwargs)
            if residual:
              residuals.append(residual)
      finally:
        # Clear the per-element state even when processing raised.
        # NOTE: stop_window_index is not reset here.
        with self.splitting_lock:
          self.current_windowed_value = None
          self.restriction = None
          self.watermark_estimator_state = None
          self.current_window_index = None
          self.threadsafe_restriction_tracker = None
          self.threadsafe_watermark_estimator = None
    elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
      # Explode the value into one single-window value per window.
      for w in windowed_value.windows:
        self._invoke_process_per_window(
            WindowedValue(
                windowed_value.value, windowed_value.timestamp, (w, )),
            additional_args,
            additional_kwargs)
    else:
      self._invoke_process_per_window(
          windowed_value, additional_args, additional_kwargs)
    return residuals

  def invoke_process_batch(self,
                     windowed_batch,  # type: WindowedBatch
                     additional_args=None,
                     additional_kwargs=None
                    ):
    # type: (...) -> None

    """Invokes the process_batch method for ``windowed_batch``.

    The batch must be a HomogeneousWindowedBatch. When windowed inputs are in
    use and the batch is not in exactly one window, it is re-wrapped into one
    single-window batch per window and each of those is processed in turn;
    otherwise the batch is processed as-is.

    Args:
      windowed_batch: the batch to process.
      additional_args: extra positional values; treated as empty when falsy.
      additional_kwargs: extra keyword arguments; treated as empty when falsy.
    """
    additional_args = additional_args or []
    additional_kwargs = additional_kwargs or {}

    assert isinstance(windowed_batch, HomogeneousWindowedBatch)

    needs_window_explosion = (
        self.has_windowed_inputs and len(windowed_batch.windows) != 1)
    if not needs_window_explosion:
      self._invoke_process_batch_per_window(
          windowed_batch, additional_args, additional_kwargs)
      return

    for window in windowed_batch.windows:
      single_window_batch = HomogeneousWindowedBatch.of(
          windowed_batch.values,
          windowed_batch.timestamp, (window, ),
          windowed_batch.pane_info)
      self._invoke_process_batch_per_window(
          single_window_batch, additional_args, additional_kwargs)

  def _should_process_window_for_sdf(
      self,
      windowed_value, # type: WindowedValue
      additional_kwargs,
      window_index=None, # type: Optional[int]
  ):
    # type: (...) -> bool
    """Prepares splittable-DoFn state for processing one window.

    Creates a fresh restriction tracker and watermark estimator from
    ``self.restriction`` and ``self.watermark_estimator_state``, publishes
    thread-safe wrappers of them under ``splitting_lock``, and injects them
    into ``additional_kwargs`` under the argument names declared by the
    process method.

    Args:
      windowed_value: the value being processed; its number of windows seeds
        ``stop_window_index`` when the first window is started.
      additional_kwargs: dict that is mutated in place to carry the
        restriction tracker view and, if declared, the watermark estimator.
      window_index: index of the window about to be processed, or None when
        the value is processed once without iterating over its windows.

    Returns:
      False if ``window_index`` has reached ``stop_window_index`` and the
      window must not be processed, True otherwise.

    Raises:
      ValueError: if the process method declares no restriction tracker
        parameter.
    """
    restriction_tracker = self.invoke_create_tracker(self.restriction)
    watermark_estimator = self.invoke_create_watermark_estimator(
        self.watermark_estimator_state)
    with self.splitting_lock:
      # Compare against None explicitly: 0 is a valid index (the first
      # window) and is the one that must (re)initialize stop_window_index
      # for the new element. A plain truthiness test skipped that window.
      if window_index is not None:
        self.current_window_index = window_index
        if window_index == 0:
          self.stop_window_index = len(windowed_value.windows)
        if window_index == self.stop_window_index:
          return False
      self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker(
          restriction_tracker)
      self.threadsafe_watermark_estimator = (
          ThreadsafeWatermarkEstimator(watermark_estimator))

    restriction_tracker_param = (
        self.signature.process_method.restriction_provider_arg_name)
    if not restriction_tracker_param:
      raise ValueError(
          'DoFn is splittable but DoFn does not have a '
          'RestrictionTrackerParam defined')
    additional_kwargs[restriction_tracker_param] = (
        RestrictionTrackerView(self.threadsafe_restriction_tracker))
    watermark_param = (
        self.signature.process_method.watermark_estimator_provider_arg_name)
    # When the watermark_estimator is a NoOpWatermarkEstimator, the system
    # will not add watermark_param into the DoFn param list.
    if watermark_param is not None:
      additional_kwargs[watermark_param] = self.threadsafe_watermark_estimator
    return True

  def _invoke_process_per_window(self,
                                 windowed_value,  # type: WindowedValue
                                 additional_args,
                                 additional_kwargs,
                                ):
    # type: (...) -> Optional[SplitResultResidual]
    """Calls the process method once for a value in at most one window.

    Assembles the positional and keyword arguments (reusing the cached ones
    when available), substitutes every placeholder with the value taken from
    ``windowed_value``, calls the process method and hands its result to the
    output handler.

    Args:
      windowed_value: the value to process.
      additional_args: values appended after the side input values before
        they are inserted into the argument template.
      additional_kwargs: keyword arguments merged into the call's kwargs.

    Returns:
      A SplitResultResidual when the DoFn is splittable and the restriction
      tracker reports a deferred status, otherwise None.

    Raises:
      ValueError: if a key is required but the value is not a 2-tuple, or if
        the restriction provider reports a negative size for the deferred
        restriction.
    """
    if self.has_cached_window_args:
      # NOTE(review): `window` is not bound on this path. Caching is only
      # enabled when has_windowed_inputs is false, which presumably means no
      # placeholder below needs the window - confirm.
      args_for_process, kwargs_for_process = (
          self.args_for_process, self.kwargs_for_process)
    else:
      if self.has_windowed_inputs:
        assert len(windowed_value.windows) <= 1
        window, = windowed_value.windows
      else:
        window = GlobalWindow()
      side_inputs = [si[window] for si in self.side_inputs]
      side_inputs.extend(additional_args)
      args_for_process, kwargs_for_process = util.insert_values_in_args(
          self.args_for_process, self.kwargs_for_process, side_inputs)
      if not self.recalculate_window_args:
        # Store the resolved arguments so later calls skip the work above.
        self.args_for_process, self.kwargs_for_process = (
            args_for_process, kwargs_for_process)
        self.has_cached_window_args = True

    # Extract key in the case of a stateful DoFn. Note that in the case of a
    # stateful DoFn, we set during __init__ self.has_windowed_inputs to be
    # True. Therefore, windows will be exploded coming into this method, and
    # we can rely on the window variable being set above.
    if self.user_state_context or self.is_key_param_required:
      try:
        key, unused_value = windowed_value.value
      except (TypeError, ValueError):
        raise ValueError((
            'Input value to a stateful DoFn or KeyParam must be a KV tuple; '
            'instead, got \'%s\'.') % (windowed_value.value, ))

    # Patch each placeholder position with the concrete value for this call.
    for i, p in self.placeholders_for_process:
      if core.DoFn.ElementParam == p:
        args_for_process[i] = windowed_value.value
      elif core.DoFn.KeyParam == p:
        args_for_process[i] = key
      elif core.DoFn.WindowParam == p:
        args_for_process[i] = window
      elif core.DoFn.WindowedValueParam == p:
        args_for_process[i] = windowed_value
      elif core.DoFn.TimestampParam == p:
        args_for_process[i] = windowed_value.timestamp
      elif core.DoFn.PaneInfoParam == p:
        args_for_process[i] = windowed_value.pane_info
      elif isinstance(p, core.DoFn.StateParam):
        assert self.user_state_context is not None
        args_for_process[i] = (
            self.user_state_context.get_state(p.state_spec, key, window))
      elif isinstance(p, core.DoFn.TimerParam):
        assert self.user_state_context is not None
        args_for_process[i] = (
            self.user_state_context.get_timer(
                p.timer_spec,
                key,
                window,
                windowed_value.timestamp,
                windowed_value.pane_info))
      elif core.DoFn.BundleFinalizerParam == p:
        args_for_process[i] = self.bundle_finalizer_param
      elif isinstance(p, core.DoFn.BundleContextParam):
        args_for_process[i] = self._bundle_context_values[p][1]
      elif isinstance(p, core.DoFn.SetupContextParam):
        args_for_process[i] = self._setup_context_values[p][1]

    kwargs_for_process = kwargs_for_process or {}

    if additional_kwargs:
      kwargs_for_process.update(additional_kwargs)

    self.output_handler.handle_process_outputs(
        windowed_value,
        self.process_method(*args_for_process, **kwargs_for_process),
        self.threadsafe_watermark_estimator)

    if self.is_splittable:
      assert self.threadsafe_restriction_tracker is not None
      self.threadsafe_restriction_tracker.check_done()
      deferred_status = self.threadsafe_restriction_tracker.deferred_status()
      if deferred_status:
        # Package the deferred remainder as a sized residual:
        # ((element, (restriction, estimator_state)), size).
        deferred_restriction, deferred_timestamp = deferred_status
        element = windowed_value.value
        size = self.signature.get_restriction_provider().restriction_size(
            element, deferred_restriction)
        if size < 0:
          raise ValueError('Expected size >= 0 but received %s.' % size)
        current_watermark = (
            self.threadsafe_watermark_estimator.current_watermark())
        estimator_state = (
            self.threadsafe_watermark_estimator.get_estimator_state())
        residual_value = ((element, (deferred_restriction, estimator_state)),
                          size)
        return SplitResultResidual(
            residual_value=windowed_value.with_value(residual_value),
            current_watermark=current_watermark,
            deferred_timestamp=deferred_timestamp)
    return None

  def _invoke_process_batch_per_window(
      self,
      windowed_batch: WindowedBatch,
      additional_args,
      additional_kwargs,
  ):
    # type: (...) -> None

    """Calls the process_batch method once for a batch in at most one window.

    Assembles the arguments (reusing the cached ones when available),
    substitutes every placeholder with the value taken from
    ``windowed_batch``, calls the process_batch method and hands its result
    to the output handler.

    Args:
      windowed_batch: the batch to process.
      additional_args: values appended after the side input values before
        they are inserted into the argument template.
      additional_kwargs: keyword arguments merged into the call's kwargs.

    Raises:
      NotImplementedError: if the process_batch signature uses a key, state
        or timer parameter.
    """
    if self.has_cached_window_batch_args:
      # NOTE(review): `window` is not bound on this path. Caching is only
      # enabled when has_windowed_inputs is false, which presumably means no
      # WindowParam placeholder exists - confirm.
      args_for_process_batch, kwargs_for_process_batch = (
          self.args_for_process_batch, self.kwargs_for_process_batch)
    else:
      if self.has_windowed_inputs:
        assert isinstance(windowed_batch, HomogeneousWindowedBatch)
        assert len(windowed_batch.windows) <= 1
        window, = windowed_batch.windows
      else:
        window = GlobalWindow()
      side_inputs = [si[window] for si in self.side_inputs]
      side_inputs.extend(additional_args)
      args_for_process_batch, kwargs_for_process_batch = (
          util.insert_values_in_args(
              self.args_for_process_batch,
              self.kwargs_for_process_batch,
              side_inputs,
          )
      )
      if not self.recalculate_window_args:
        # Store the resolved arguments so later calls skip the work above.
        self.args_for_process_batch, self.kwargs_for_process_batch = (
            args_for_process_batch, kwargs_for_process_batch)
        self.has_cached_window_batch_args = True

    # Patch each placeholder position with the concrete value for this call.
    for i, p in self.placeholders_for_process_batch:
      if core.DoFn.ElementParam == p:
        args_for_process_batch[i] = windowed_batch.values
      elif core.DoFn.KeyParam == p:
        raise NotImplementedError(
            'https://github.com/apache/beam/issues/21653: Per-key process_batch'
        )
      elif core.DoFn.WindowParam == p:
        args_for_process_batch[i] = window
      elif core.DoFn.TimestampParam == p:
        args_for_process_batch[i] = windowed_batch.timestamp
      elif core.DoFn.PaneInfoParam == p:
        assert isinstance(windowed_batch, HomogeneousWindowedBatch)
        args_for_process_batch[i] = windowed_batch.pane_info
      elif isinstance(p, core.DoFn.StateParam):
        raise NotImplementedError(
            "https://github.com/apache/beam/issues/21653: "
            "Per-key process_batch")
      elif isinstance(p, core.DoFn.TimerParam):
        raise NotImplementedError(
            "https://github.com/apache/beam/issues/21653: "
            "Per-key process_batch")
      elif isinstance(p, core.DoFn.BundleContextParam):
        args_for_process_batch[i] = self._bundle_context_values[p][1]
      elif isinstance(p, core.DoFn.SetupContextParam):
        args_for_process_batch[i] = self._setup_context_values[p][1]

    kwargs_for_process_batch = kwargs_for_process_batch or {}

    if additional_kwargs:
      kwargs_for_process_batch.update(additional_kwargs)

    self.output_handler.handle_process_batch_outputs(
        windowed_batch,
        self.process_batch_method(
            *args_for_process_batch, **kwargs_for_process_batch),
        self.threadsafe_watermark_estimator)

  @staticmethod
  def _try_split(fraction,
      window_index, # type: Optional[int]
      stop_window_index, # type: Optional[int]
      windowed_value, # type: WindowedValue
      restriction,
      watermark_estimator_state,
      restriction_provider, # type: RestrictionProvider
      restriction_tracker, # type: RestrictionTracker
      watermark_estimator, # type: WatermarkEstimator
                 ):
    # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual], Optional[int]]]

    """Try to split returning a primaries, residuals and a new stop index.

    For non-window observing splittable DoFns we split the current restriction
    and assign the primary and residual to all the windows.

    For window observing splittable DoFns, we:
    1) return a split at a window boundary if the fraction lies outside of the
       current window.
    2) attempt to split the current restriction, if successful then return
       the primary and residual for the current window and an additional
       primary and residual for any fully processed and fully unprocessed
       windows.
    3) fall back to returning a split at the window boundary if possible

    Args:
      fraction: the fraction of the remaining work at which to split. For
                window observing splittable DoFns it is interpreted relative
                to the remaining work scaled across all remaining windows.
      window_index: the current index of the window being processed or None
                    if the splittable DoFn is not window observing.
      stop_window_index: the current index to stop processing at or None
                         if the splittable DoFn is not window observing.
      windowed_value: the current windowed value
      restriction: the initial restriction when processing was started.
      watermark_estimator_state: the initial watermark estimator state when
                                 processing was started.
      restriction_provider: the DoFn's restriction provider
      restriction_tracker: the current restriction tracker
      watermark_estimator: the current watermark estimator

    Returns:
      A tuple containing (primaries, residuals, new_stop_index) or None if
      splitting was not possible. new_stop_index will only be set if the
      splittable DoFn is window observing otherwise it will be None.

    Raises:
      ValueError: if the restriction provider reports a negative size.
    """
    def compute_whole_window_split(to_index, from_index):
      # Builds a primary covering windows [0, to_index) and a residual
      # covering windows [from_index, stop_window_index), both carrying the
      # initial restriction. Either side is None when its range is empty.
      #
      # The size is computed from the element itself, matching every other
      # restriction_size call on this class, rather than from the enclosing
      # WindowedValue.
      restriction_size = restriction_provider.restriction_size(
          windowed_value.value, restriction)
      if restriction_size < 0:
        raise ValueError(
            'Expected size >= 0 but received %s.' % restriction_size)
      # The primary and residual both share the same value only differing
      # by the set of windows they are in.
      value = ((windowed_value.value, (restriction, watermark_estimator_state)),
               restriction_size)
      primary_restriction = SplitResultPrimary(
          primary_value=WindowedValue(
              value,
              windowed_value.timestamp,
              windowed_value.windows[:to_index])) if to_index > 0 else None
      # Don't report any updated watermarks for the residual since they have
      # not processed any part of the restriction.
      residual_restriction = SplitResultResidual(
          residual_value=WindowedValue(
              value,
              windowed_value.timestamp,
              windowed_value.windows[from_index:stop_window_index]),
          current_watermark=None,
          deferred_timestamp=None) if from_index < stop_window_index else None
      return (primary_restriction, residual_restriction)

    primary_restrictions = []
    residual_restrictions = []

    window_observing = window_index is not None
    # If we are processing each window separately and we aren't on the last
    # window then compute whether the split lies within the current window
    # or a future window.
    if window_observing and window_index != stop_window_index - 1:
      progress = restriction_tracker.current_progress()
      if not progress:
        # Assume no work has been completed for the current window if progress
        # is unavailable.
        from apache_beam.io.iobase import RestrictionProgress
        progress = RestrictionProgress(completed=0, remaining=1)

      scaled_progress = PerWindowInvoker._scale_progress(
          progress, window_index, stop_window_index)
      # Compute the fraction of the remainder relative to the scaled progress.
      # If the value is greater than or equal to progress.remaining_work then we
      # should split at the closest window boundary.
      fraction_of_remainder = scaled_progress.remaining_work * fraction
      if fraction_of_remainder >= progress.remaining_work:
        # The fraction is outside of the current window and hence we will
        # split at the closest window boundary. Favor a split and return the
        # last window if we would have rounded up to the end of the window
        # based upon the fraction.
        new_stop_window_index = min(
            stop_window_index - 1,
            window_index + max(
                1,
                int(
                    round((
                        progress.completed_work +
                        scaled_progress.remaining_work * fraction) /
                          progress.total_work))))
        primary, residual = compute_whole_window_split(
            new_stop_window_index, new_stop_window_index)
        assert primary is not None
        assert residual is not None
        return ([primary], [residual], new_stop_window_index)
      else:
        # The fraction is within the current window being processed so compute
        # the updated fraction based upon the number of windows being processed.
        new_stop_window_index = window_index + 1
        fraction = fraction_of_remainder / progress.remaining_work
        # Attempt to split below, if we can't then we'll compute a split
        # using only window boundaries
    else:
      # We aren't splitting within multiple windows so we don't change our
      # stop index.
      new_stop_window_index = stop_window_index

    # Temporary workaround for [BEAM-7473]: get current_watermark before
    # split, in case watermark gets advanced before getting split results.
    # In worst case, current_watermark is always stale, which is ok.
    current_watermark = (watermark_estimator.current_watermark())
    current_estimator_state = (watermark_estimator.get_estimator_state())
    split = restriction_tracker.try_split(fraction)
    if split:
      primary, residual = split
      element = windowed_value.value
      primary_size = restriction_provider.restriction_size(
          windowed_value.value, primary)
      if primary_size < 0:
        raise ValueError('Expected size >= 0 but received %s.' % primary_size)
      residual_size = restriction_provider.restriction_size(
          windowed_value.value, residual)
      if residual_size < 0:
        raise ValueError('Expected size >= 0 but received %s.' % residual_size)
      # We use the watermark estimator state for the original process call
      # for the primary and the updated watermark estimator state for the
      # residual for the split.
      primary_split_value = ((element, (primary, watermark_estimator_state)),
                             primary_size)
      residual_split_value = ((element, (residual, current_estimator_state)),
                              residual_size)
      windows = (
          windowed_value.windows[window_index],
      ) if window_observing else windowed_value.windows
      primary_restrictions.append(
          SplitResultPrimary(
              primary_value=WindowedValue(
                  primary_split_value, windowed_value.timestamp, windows)))
      residual_restrictions.append(
          SplitResultResidual(
              residual_value=WindowedValue(
                  residual_split_value, windowed_value.timestamp, windows),
              current_watermark=current_watermark,
              deferred_timestamp=None))

      if window_observing:
        # Also account for the windows before the current one (fully
        # processed) and after it (not started).
        assert new_stop_window_index == window_index + 1
        primary, residual = compute_whole_window_split(
            window_index, window_index + 1)
        if primary:
          primary_restrictions.append(primary)
        if residual:
          residual_restrictions.append(residual)
      return (
          primary_restrictions, residual_restrictions, new_stop_window_index)
    elif new_stop_window_index and new_stop_window_index != stop_window_index:
      # If we failed to split but have a new stop index then return a split
      # at the window boundary.
      primary, residual = compute_whole_window_split(
          new_stop_window_index, new_stop_window_index)
      assert primary is not None
      assert residual is not None
      return ([primary], [residual], new_stop_window_index)
    else:
      return None

  def try_split(self, fraction):
    # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual]]]
    """Attempts to split the element that is currently being processed.

    Args:
      fraction: forwarded unchanged to ``_try_split``.

    Returns:
      The first two items of the ``_try_split`` result, i.e. a
      ``(primaries, residuals)`` tuple, or None when the DoFn is not
      splittable, no restriction tracker is active, or the split attempt
      produced nothing. On success ``stop_window_index`` is updated with the
      third item of the result.
    """
    if not self.is_splittable:
      return None

    with self.splitting_lock:
      active_tracker = self.threadsafe_restriction_tracker
      if not active_tracker:
        return None

      # Everything handed to _try_split is read while holding the lock, so
      # the split works on one consistent snapshot of the fields that are
      # swapped out during processing.
      split_result = PerWindowInvoker._try_split(
          fraction,
          self.current_window_index,
          self.stop_window_index,
          self.current_windowed_value,
          self.restriction,
          self.watermark_estimator_state,
          self.signature.get_restriction_provider(),
          active_tracker,
          self.threadsafe_watermark_estimator)
      if not split_result:
        return None

      primaries, residuals, self.stop_window_index = split_result
      return (primaries, residuals)

  @staticmethod
  def _scale_progress(progress, window_index, stop_window_index):
    """Extrapolates the progress of one window to all windows.

    The work of the current window (``progress.total_work``) is taken as the
    cost of every window: windows before ``window_index`` count as fully
    completed and windows after it, up to ``stop_window_index``, count as
    fully remaining.

    Args:
      progress: progress within the current window.
      window_index: index of the window currently being processed.
      stop_window_index: index at which processing stops (exclusive).

    Returns:
      A RestrictionProgress covering all windows up to the stop index.
    """
    work_per_window = progress.total_work
    windows_not_started = stop_window_index - (window_index + 1)
    completed_overall = (
        window_index * work_per_window + progress.completed_work)
    remaining_overall = (
        windows_not_started * work_per_window + progress.remaining_work)
    from apache_beam.io.iobase import RestrictionProgress
    return RestrictionProgress(
        completed=completed_overall, remaining=remaining_overall)

  def current_element_progress(self):
    # type: () -> Optional[RestrictionProgress]
    """Returns the progress of the element currently being processed.

    Returns:
      None when the DoFn is not splittable or no restriction tracker is
      active. Otherwise the tracker's progress, scaled across all windows up
      to ``stop_window_index`` when a window index is being tracked. A falsy
      tracker progress is returned unchanged.
    """
    if not self.is_splittable:
      return None

    # Snapshot the shared fields under the lock; the tracker itself is
    # queried outside of it.
    with self.splitting_lock:
      current_window_index = self.current_window_index
      stop_window_index = self.stop_window_index
      threadsafe_restriction_tracker = self.threadsafe_restriction_tracker

    if not threadsafe_restriction_tracker:
      return None

    progress = threadsafe_restriction_tracker.current_progress()
    # Only None means "no window is being tracked". Index 0 is the first
    # window and its progress must be scaled like any other, otherwise the
    # work of the remaining windows is not reported.
    if current_window_index is None or not progress:
      return progress

    # stop_window_index should always be set if current_window_index is set,
    # it is an error otherwise.
    assert stop_window_index
    return PerWindowInvoker._scale_progress(
        progress, current_window_index, stop_window_index)


class DoFnRunner:
  """For internal use only; no backwards-compatibility guarantees.

  A helper class for executing ParDo operations.
  """

  def __init__(self,
               fn,  # type: core.DoFn
               args,
               kwargs,
               side_inputs,  # type: Iterable[sideinputs.SideInputMap]
               windowing,
               tagged_receivers,  # type: Mapping[Optional[str], Receiver]
               step_name=None,  # type: Optional[str]
               logging_context=None,
               state=None,
               scoped_metrics_container=None,
               operation_name=None,
               transform_id=None,
               user_state_context=None,  # type: Optional[userstate.UserStateContext]
              ):
    """Creates the runner and the invoker it delegates to.

    Args:
      fn: user DoFn to invoke
      args: positional side input arguments (static and placeholder), if any
      kwargs: keyword side input arguments (static and placeholder), if any
      side_inputs: list of sideinput.SideInputMaps for deferred side inputs
      windowing: windowing properties of the output PCollection(s)
      tagged_receivers: a dict of tag name to Receiver objects
      step_name: the name of this step
      logging_context: DEPRECATED [BEAM-4728]
      state: handle for accessing DoFn state
      scoped_metrics_container: DEPRECATED
      operation_name: The system name assigned by the runner for this operation.
      transform_id: The PTransform Id in the pipeline proto for this DoFn.
      user_state_context: The UserStateContext instance for the current
                          Stateful DoFn.

    Raises:
      Exception: if ``fn`` is a stateful DoFn and no user state context was
        supplied.
    """
    # The side inputs are iterated more than once, so materialize them.
    materialized_side_inputs = list(side_inputs)

    self.step_name = step_name
    self.transform_id = transform_id
    self.context = DoFnContext(step_name, state=state)
    self.bundle_finalizer_param = DoFn.BundleFinalizerParam()
    self.execution_context = None  # type: Optional[ExecutionContext]

    signature = DoFnSignature(fn)

    # Optimize for the common case.
    main_receivers = tagged_receivers[None]

    # TODO(https://github.com/apache/beam/issues/18886): Remove if block after
    # output counter released.
    per_element_output_counter = None
    if 'outputs_per_element_counter' in RuntimeValueProvider.experiments:
      # TODO(BEAM-3955): Make step_name and operation_name less confused.
      counter_name = CounterName(
          'per-element-output-count', step_name=operation_name)
      per_element_output_counter = state._counter_factory.get_counter(
          counter_name, Counter.DATAFLOW_DISTRIBUTION).accumulator

    window_fn = windowing.windowfn
    batch_converter = getattr(fn, 'output_batch_converter', None)
    process_yields_batches = getattr(
        signature.process_method.method_value, '_beam_yields_batches', False)
    process_batch_yields_elements = getattr(
        signature.process_batch_method.method_value,
        '_beam_yields_elements',
        False)
    output_handler = _OutputHandler(
        window_fn,
        main_receivers,
        tagged_receivers,
        per_element_output_counter,
        batch_converter,
        process_yields_batches,
        process_batch_yields_elements,
    )

    if signature.is_stateful_dofn() and not user_state_context:
      raise Exception(
          'Requested execution of a stateful DoFn, but no user state context '
          'is available. This likely means that the current runner does not '
          'support the execution of stateful DoFns.')

    self.do_fn_invoker = DoFnInvoker.create_invoker(
        signature,
        output_handler,
        self.context,
        materialized_side_inputs,
        args,
        kwargs,
        user_state_context=user_state_context,
        bundle_finalizer_param=self.bundle_finalizer_param)

  def process(self, windowed_value):
    # type: (WindowedValue) -> Iterable[SplitResultResidual]
    """Processes one windowed value through the invoker.

    Any exception is passed to ``_reraise_augmented`` together with the
    value being processed.

    Returns:
      Whatever the invoker's ``invoke_process`` returns, or an empty list
      should ``_reraise_augmented`` return instead of raising.
    """
    try:
      residuals = self.do_fn_invoker.invoke_process(windowed_value)
    except BaseException as exc:
      self._reraise_augmented(exc, windowed_value)
      return []
    return residuals

  def _maybe_sample_exception(
      self, exc_info: Tuple, windowed_value: Optional[WindowedValue]) -> None:
    """Reports an exception to the output sampler, when one is configured.

    Nothing happens if there is no execution context or the execution
    context has no output sampler.

    Args:
      exc_info: exception info tuple handed to the sampler.
      windowed_value: the value being processed when the exception occurred,
        or None.
    """
    execution_context = self.execution_context
    sampler = (
        None if execution_context is None else
        execution_context.output_sampler)
    if sampler is not None:
      sampler.sample_exception(
          windowed_value,
          exc_info,
          self.transform_id,
          execution_context.instruction_id)

  def process_batch(self, windowed_batch):
    # type: (WindowedBatch) -> None
    """Processes one windowed batch through the invoker.

    Any exception is passed to ``_reraise_augmented`` (without an element).
    """
    invoker = self.do_fn_invoker
    try:
      invoker.invoke_process_batch(windowed_batch)
    except BaseException as exc:
      self._reraise_augmented(exc)

  def process_with_sized_restriction(self, windowed_value):
    # type: (WindowedValue) -> Iterable[SplitResultResidual]
    """Processes a value shaped as ((element, (restriction, state)), size).

    The size component is ignored; the element is processed against the
    unpacked restriction and watermark estimator state.

    Returns:
      Whatever the invoker's ``invoke_process`` returns.
    """
    sized_element, _ = windowed_value.value
    element, (restriction, estimator_state) = sized_element
    element_only = windowed_value.with_value(element)
    return self.do_fn_invoker.invoke_process(
        element_only,
        restriction=restriction,
        watermark_estimator_state=estimator_state)

  def try_split(self, fraction):
    # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual]]]
    """Delegates a split request to the invoker.

    The invoker is asserted to be a PerWindowInvoker.
    """
    invoker = self.do_fn_invoker
    assert isinstance(invoker, PerWindowInvoker)
    return invoker.try_split(fraction)

  def current_element_progress(self):
    # type: () -> Optional[RestrictionProgress]
    """Delegates a progress request to the invoker.

    The invoker is asserted to be a PerWindowInvoker.
    """
    invoker = self.do_fn_invoker
    assert isinstance(invoker, PerWindowInvoker)
    return invoker.current_element_progress()

  def process_user_timer(
      self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag):
    """Fires a user timer through the invoker.

    All arguments are forwarded unchanged to the invoker's
    ``invoke_user_timer``. Any exception is passed to ``_reraise_augmented``.
    """
    invoker = self.do_fn_invoker
    try:
      invoker.invoke_user_timer(
          timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag)
    except BaseException as exc:
      self._reraise_augmented(exc)

  def _invoke_bundle_method(self, bundle_method):
    """Calls ``bundle_method`` with no current element set on the context.

    Any exception, including one raised while clearing the element, is
    passed to ``_reraise_augmented``.
    """
    try:
      self.context.set_element(None)
      bundle_method()
    except BaseException as exc:
      self._reraise_augmented(exc)

  def _invoke_lifecycle_method(self, lifecycle_method):
    """Calls ``lifecycle_method`` with no current element set on the context.

    Any exception, including one raised while clearing the element, is
    passed to ``_reraise_augmented``.
    """
    try:
      self.context.set_element(None)
      lifecycle_method()
    except BaseException as exc:
      self._reraise_augmented(exc)

  def setup(self):
    # type: () -> None
    """Runs the invoker's setup hook as a lifecycle method."""
    setup_hook = self.do_fn_invoker.invoke_setup
    self._invoke_lifecycle_method(setup_hook)

  def start(self):
    # type: () -> None
    """Runs the invoker's start-bundle hook as a bundle method."""
    start_bundle_hook = self.do_fn_invoker.invoke_start_bundle
    self._invoke_bundle_method(start_bundle_hook)

  def finish(self):
    # type: () -> None
    """Runs the invoker's finish-bundle hook as a bundle method."""
    finish_bundle_hook = self.do_fn_invoker.invoke_finish_bundle
    self._invoke_bundle_method(finish_bundle_hook)

  def teardown(self):
    # type: () -> None
    """Runs the invoker's teardown hook as a lifecycle method."""
    teardown_hook = self.do_fn_invoker.invoke_teardown
    self._invoke_lifecycle_method(teardown_hook)

  def finalize(self):
    # type: () -> None
    """Calls ``finalize_bundle`` on this runner's bundle finalizer param."""
    finalizer = self.bundle_finalizer_param
    finalizer.finalize_bundle()

  def _reraise_augmented(self, exn, windowed_value=None):
    """Re-raises ``exn`` with the step name appended to its message.

    The exception is re-raised unchanged when it has already been tagged
    with a step or when this runner has no step name. Otherwise a new
    exception carrying the annotation is built, sampled via
    ``_maybe_sample_exception``, logged and raised.

    Args:
      exn: the exception to augment.
      windowed_value: the value being processed when ``exn`` was raised, if
        any; only forwarded to ``_maybe_sample_exception``.

    Raises:
      An exception of the same type as ``exn`` with an augmented message, or
      a RuntimeError when such an exception cannot be constructed.
    """
    if getattr(exn, '_tagged_with_step', False) or not self.step_name:
      raise exn
    step_annotation = " [while running '%s']" % self.step_name
    # To emulate exception chaining (not available in Python 2).
    try:
      # Attempt to construct the same kind of exception
      # with an augmented message.
      new_exn = type(exn)(exn.args[0] + step_annotation, *exn.args[1:])
      new_exn._tagged_with_step = True  # Could raise attribute error.
    except:  # pylint: disable=bare-except
      # If anything goes wrong, construct a RuntimeError whose message
      # records the original exception's type and message.
      new_exn = RuntimeError(
          traceback.format_exception_only(type(exn), exn)[-1].strip() +
          step_annotation)
      new_exn._tagged_with_step = True
    # The callers in this class invoke this method from inside an except
    # block, so sys.exc_info() describes the exception being handled there.
    exc_info = sys.exc_info()
    _, _, tb = exc_info

    # Attach that traceback to the replacement exception.
    new_exn = new_exn.with_traceback(tb)
    self._maybe_sample_exception(exc_info, windowed_value)
    _LOGGER.exception(new_exn)
    raise new_exn


class OutputHandler(object):
  """Interface for objects that handle the results of DoFn invocations.

  Both methods raise NotImplementedError here; subclasses are expected to
  override them.
  """
  def handle_process_outputs(
      self, windowed_input_element, results, watermark_estimator=None):
    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None

    """Handles ``results`` produced for ``windowed_input_element``."""
    raise NotImplementedError()

  def handle_process_batch_outputs(
      self, windowed_input_element, results, watermark_estimator=None):
    # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None

    """Handles ``results`` produced for a windowed input batch."""
    raise NotImplementedError()


class _OutputHandler(OutputHandler):
  """Processes output produced by DoFn method invocations."""

  def __init__(self,
               window_fn,
               main_receivers,  # type: Receiver
               tagged_receivers,  # type: Mapping[Optional[str], Receiver]
               per_element_output_counter,
               output_batch_converter, # type: Optional[BatchConverter]
               process_yields_batches, # type: bool
               process_batch_yields_elements, # type: bool
               ):
    """Initializes ``_OutputHandler``.

    Args:
      window_fn: a windowing function (WindowFn).
      main_receivers: main receiver object.
      tagged_receivers: a dict of tag name to Receiver objects.
      per_element_output_counter: per_element_output_counter of one work_item.
        Could be None if the experimental flag is turned off; it is only
        retained when it reports ``is_cythonized``.
      output_batch_converter: converter used to measure the length of emitted
        batches; may be None.
      process_yields_batches: whether the results given to
        ``handle_process_outputs`` are batches rather than elements.
      process_batch_yields_elements: whether the results given to
        ``handle_process_batch_outputs`` are elements rather than batches.
    """
    self.window_fn = window_fn
    self.main_receivers = main_receivers
    self.tagged_receivers = tagged_receivers
    if (per_element_output_counter is not None and
        per_element_output_counter.is_cythonized):
      self.per_element_output_counter = per_element_output_counter
    else:
      self.per_element_output_counter = None
    self.output_batch_converter = output_batch_converter
    self._process_yields_batches = process_yields_batches
    self._process_batch_yields_elements = process_batch_yields_elements

  def handle_process_outputs(
      self, windowed_input_element, results, watermark_estimator=None):
    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None

    """Dispatch the result of process computation to the appropriate receivers.

    A value wrapped in a TaggedOutput object will be unwrapped and
    then dispatched to the appropriate indexed output.
    """
    if results is None:
      results = []

    # TODO(https://github.com/apache/beam/issues/20404): Verify that the
    #  results object is a valid iterable type if
    #  performance_runtime_type_check is active, without harming performance
    output_element_count = 0
    for result in results:
      tag, result = self._handle_tagged_output(result)

      if not self._process_yields_batches:
        # process yields elements
        windowed_value = self._maybe_propagate_windowing_info(
            windowed_input_element, result)

        output_element_count += 1

        self._write_value_to_tag(tag, windowed_value, watermark_estimator)
      else:  # process yields batches
        self._verify_batch_output(result)

        if isinstance(result, WindowedBatch):
          assert isinstance(result, HomogeneousWindowedBatch)
          windowed_batch = result

          # Repeat the batch's windows once per window of a multi-window input.
          if (windowed_input_element is not None and
              len(windowed_input_element.windows) != 1):
            windowed_batch.windows *= len(windowed_input_element.windows)
        else:
          # A bare batch takes its windowing info from the input element.
          windowed_batch = (
              HomogeneousWindowedBatch.from_batch_and_windowed_value(
                  batch=result, windowed_value=windowed_input_element))

        output_element_count += self.output_batch_converter.get_length(
            windowed_batch.values)

        self._write_batch_to_tag(tag, windowed_batch, watermark_estimator)

    # TODO(https://github.com/apache/beam/issues/18886): Remove if block after
    # output counter released. Only enable per_element_output_counter when
    # counter cythonized
    if self.per_element_output_counter is not None:
      self.per_element_output_counter.add_input(output_element_count)

  def handle_process_batch_outputs(
      self, windowed_input_batch, results, watermark_estimator=None):
    # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None

    """Dispatch the result of process_batch computation to the appropriate
    receivers.

    A value wrapped in a TaggedOutput object will be unwrapped and
    then dispatched to the appropriate indexed output.
    """
    if results is None:
      results = []

    output_element_count = 0
    for result in results:
      tag, result = self._handle_tagged_output(result)

      if not self._process_batch_yields_elements:
        # process_batch yields batches
        assert self.output_batch_converter is not None

        self._verify_batch_output(result)

        if isinstance(result, WindowedBatch):
          assert isinstance(result, HomogeneousWindowedBatch)
          windowed_batch = result

          # Repeat the batch's windows once per window of a multi-window input.
          if (windowed_input_batch is not None and
              len(windowed_input_batch.windows) != 1):
            windowed_batch.windows *= len(windowed_input_batch.windows)
        else:
          # A bare batch takes its windowing info from the input batch.
          windowed_batch = windowed_input_batch.with_values(result)

        output_element_count += self.output_batch_converter.get_length(
            windowed_batch.values)

        self._write_batch_to_tag(tag, windowed_batch, watermark_estimator)
      else:  # process_batch yields elements
        assert isinstance(windowed_input_batch, HomogeneousWindowedBatch)

        windowed_value = self._maybe_propagate_windowing_info(
            windowed_input_batch.as_empty_windowed_value(), result)

        output_element_count += 1

        self._write_value_to_tag(tag, windowed_value, watermark_estimator)

    # TODO(https://github.com/apache/beam/issues/18886): Remove if block after
    # output counter released. Only enable per_element_output_counter when
    # counter cythonized
    if self.per_element_output_counter is not None:
      self.per_element_output_counter.add_input(output_element_count)

  def _maybe_propagate_windowing_info(self, windowed_input_element, result):
    # type: (WindowedValue, Any) -> WindowedValue

    """Turns ``result`` into a WindowedValue.

    A WindowedValue is returned as is. A TimestampedValue is wrapped using
    windows assigned by ``self.window_fn``. In both cases the output's windows
    are repeated once per input window when the input does not have exactly
    one window. Any other result reuses the windowing information of
    ``windowed_input_element`` via ``with_value``.
    """
    if isinstance(result, WindowedValue):
      windowed_value = result
      if (windowed_input_element is not None and
          len(windowed_input_element.windows) != 1):
        windowed_value.windows *= len(windowed_input_element.windows)
      return windowed_value

    elif isinstance(result, TimestampedValue):
      assign_context = WindowFn.AssignContext(result.timestamp, result.value)
      windowed_value = WindowedValue(
          result.value, result.timestamp, self.window_fn.assign(assign_context))
      # Same guard as the WindowedValue branch above: a TimestampedValue
      # carries everything needed, so a missing input element must not fail.
      if (windowed_input_element is not None and
          len(windowed_input_element.windows) != 1):
        windowed_value.windows *= len(windowed_input_element.windows)
      return windowed_value

    else:
      return windowed_input_element.with_value(result)

  def _handle_tagged_output(self, result):
    """Splits ``result`` into a ``(tag, value)`` pair.

    Returns ``(None, result)`` when ``result`` is not a TaggedOutput.

    Raises:
      TypeError: if a TaggedOutput carries a tag that is not a string.
    """
    if isinstance(result, TaggedOutput):
      tag = result.tag
      if not isinstance(tag, str):
        raise TypeError('In %s, tag %s is not a string' % (self, tag))
      return tag, result.value
    return None, result

  def _write_value_to_tag(self, tag, windowed_value, watermark_estimator):
    """Sends ``windowed_value`` to the receiver for ``tag``.

    A ``tag`` of None selects the main receivers. When a watermark estimator
    is given it first observes the value's timestamp.
    """
    if watermark_estimator is not None:
      watermark_estimator.observe_timestamp(windowed_value.timestamp)

    if tag is None:
      self.main_receivers.receive(windowed_value)
    else:
      self.tagged_receivers[tag].receive(windowed_value)

  def _write_batch_to_tag(self, tag, windowed_batch, watermark_estimator):
    """Sends ``windowed_batch`` to the receiver for ``tag``.

    A ``tag`` of None selects the main receivers. When a watermark estimator
    is given it first observes every timestamp of the batch.
    """
    if watermark_estimator is not None:
      for timestamp in windowed_batch.timestamps:
        watermark_estimator.observe_timestamp(timestamp)

    if tag is None:
      self.main_receivers.receive_batch(windowed_batch)
    else:
      self.tagged_receivers[tag].receive_batch(windowed_batch)

  def _verify_batch_output(self, result):
    """Raises TypeError if ``result`` is an element-style value, not a batch."""
    if isinstance(result, (WindowedValue, TimestampedValue)):
      raise TypeError(
          f"Received {type(result).__name__} from DoFn that was "
          "expected to produce a batch.")

  def start_bundle_outputs(self, results):
    """Validate that start_bundle does not output any elements"""
    if results is None:
      return
    raise RuntimeError(
        'Start Bundle should not output any elements but got %s' % results)

  def finish_bundle_outputs(self, results):
    """Dispatch the result of finish_bundle to the appropriate receivers.

    A value wrapped in a TaggedOutput object will be unwrapped and
    then dispatched to the appropriate indexed output.

    Raises:
      TypeError: if a TaggedOutput carries a tag that is not a string.
      RuntimeError: if an (unwrapped) result is not a WindowedValue.
    """
    if results is None:
      return

    for result in results:
      # Share the unwrapping and validation used for process outputs.
      tag, result = self._handle_tagged_output(result)

      if not isinstance(result, WindowedValue):
        raise RuntimeError('Finish Bundle should only output WindowedValue ' +\
                           'type but got %s' % type(result))
      windowed_value = result

      # No watermark estimator is involved for finish_bundle outputs.
      self._write_value_to_tag(tag, windowed_value, None)


class _NoContext(WindowFn.AssignContext):
  """An uninspectable WindowFn.AssignContext.

  Holds a value and, optionally, a timestamp. Reading a timestamp that was
  never supplied, or reading ``existing_windows``, raises ValueError.
  """
  # Sentinel marking "no timestamp supplied".
  NO_VALUE = object()

  def __init__(self, value, timestamp=NO_VALUE):
    self._timestamp = timestamp
    self.value = value

  @property
  def timestamp(self):
    """The timestamp given at construction; ValueError if none was given."""
    if self._timestamp is not self.NO_VALUE:
      return self._timestamp
    raise ValueError('No timestamp in this context.')

  @property
  def existing_windows(self):
    """Always raises ValueError: this context has no existing windows."""
    raise ValueError('No existing_windows in this context.')


class DoFnState(object):
  """For internal use only; no backwards-compatibility guarantees.

  Keeps track of state that DoFns want, currently, user counters.
  """
  def __init__(self, counter_factory):
    """Stores the counter factory and starts with an empty step name."""
    self._counter_factory = counter_factory
    self.step_name = ''

  def counter_for(self, aggregator):
    """Returns the factory's counter for ``aggregator`` under the step name."""
    factory = self._counter_factory
    return factory.get_aggregator_counter(self.step_name, aggregator)


# TODO(robertwb): Replace core.DoFnContext with this.
class DoFnContext(object):
  """For internal use only; no backwards-compatibility guarantees.

  Holds a label, an optional state object and the windowed value currently
  set, and exposes that value's element, timestamp and windows.
  """
  def __init__(self, label, element=None, state=None):
    """Initializes the context.

    Args:
      label: stored as ``self.label``.
      element: optional windowed value to set as the current one.
      state: stored as ``self.state``.
    """
    self.label = label
    self.state = state
    # Always define the attribute so that the accessors below report their
    # own "not accessible" error when no element was supplied, instead of a
    # generic missing-attribute error.
    self.windowed_value = None
    if element is not None:
      self.set_element(element)

  def set_element(self, windowed_value):
    # type: (Optional[WindowedValue]) -> None

    """Sets (or, with None, clears) the current windowed value."""
    self.windowed_value = windowed_value

  @property
  def element(self):
    """Value of the current windowed value; AttributeError if none is set."""
    if self.windowed_value is None:
      raise AttributeError('element not accessible in this context')
    return self.windowed_value.value

  @property
  def timestamp(self):
    """Timestamp of the current windowed value; AttributeError if none is set."""
    if self.windowed_value is None:
      raise AttributeError('timestamp not accessible in this context')
    return self.windowed_value.timestamp

  @property
  def windows(self):
    """Windows of the current windowed value; AttributeError if none is set."""
    if self.windowed_value is None:
      raise AttributeError('windows not accessible in this context')
    return self.windowed_value.windows
