def __init__()

in tensorflow_federated/python/core/backends/mapreduce/forms.py [0:0]


  def __init__(self,
               initialize,
               prepare,
               work,
               zero,
               accumulate,
               merge,
               report,
               secure_sum_bitwidth,
               secure_sum_max_input,
               secure_modular_sum_modulus,
               update,
               server_state_label=None,
               client_data_label=None):
    """Validates and stores the computations of a MapReduce-like process.

    Each computation is first passed to `_check_tensorflow_computation`. The
    type signatures are then checked against one another, following the path
    data takes through a round: `initialize` -> `prepare` -> `work` ->
    `zero`/`accumulate` -> `merge` -> `report` -> `update`, and finally from
    `update` back to `prepare`. Only after every check passes are the
    computations stored on the instance.

    Note: All the computations supplied here as arguments must be TensorFlow
    computations, i.e., instances of `tff.Computation` constructed by the
    `tff.tf_computation` decorator/wrapper.

    Args:
      initialize: Produces the initial server state. Its result type must be
        accepted by the parameter of `prepare`.
      prepare: Prepares the input for the clients from the server state. Its
        result type must be accepted by the second element of the parameter
        of `work`.
      work: The client-side work computation. It is checked to accept a
        2-tuple and to return a tuple of `WORK_RESULT_LEN` elements.
      zero: Produces the initial state for accumulators. Its result type must
        be accepted by the first parameter of `accumulate`.
      accumulate: Adds a client update to an accumulator. It takes two
        parameters; the second must accept element `WORK_UPDATE_INDEX` of the
        result of `work`, and its result must be assignable back to its own
        first parameter.
      merge: Merges pairs of accumulators. Both of its two parameters must
        accept the result of `accumulate`, and its result must be assignable
        back to its own first parameter.
      report: Produces the final server-side aggregate for the top level
        accumulator (the global update). Its parameter must accept the result
        of `merge`.
      secure_sum_bitwidth: The computation that produces the bitwidth for
        bitwidth-based secure sums.
      secure_sum_max_input: The computation that produces the maximum input
        for `max_input`-based secure sums.
      secure_modular_sum_modulus: The computation that produces the modulus
        for secure modular sums.
      update: Takes the server state together with the global update and the
        secure-aggregation results, and returns a 2-tuple whose first element
        (the new server state) must be assignable to the parameter of
        `prepare`.
      server_state_label: Optional string label for the server state.
      client_data_label: Optional string label for the client data.

    Raises:
      TypeError: If the Python or TFF types of the arguments are invalid or not
        compatible with each other.
      AssertionError: If the manner in which the given TensorFlow computations
        are represented by TFF does not match what this code is expecting (this
        is an internal error that requires code update).
    """
    labeled_computations = (
        ('initialize', initialize),
        ('prepare', prepare),
        ('work', work),
        ('zero', zero),
        ('accumulate', accumulate),
        ('merge', merge),
        ('report', report),
        ('secure_sum_bitwidth', secure_sum_bitwidth),
        ('secure_sum_max_input', secure_sum_max_input),
        ('secure_modular_sum_modulus', secure_modular_sum_modulus),
        ('update', update),
    )
    for name, computation in labeled_computations:
      _check_tensorflow_computation(name, computation)

    # initialize -> prepare: the server state must flow into `prepare`.
    prepare_arg_type = prepare.type_signature.parameter
    initial_state_type = initialize.type_signature.result
    state_flows_into_prepare = _is_assignable_from_or_both_none(
        prepare_arg_type, initial_state_type)
    if not state_flows_into_prepare:
      raise TypeError(
          'The `prepare` computation expects an argument of type {}, '
          'which does not match the result type {} of `initialize`.'.format(
              prepare_arg_type, initial_state_type))

    # prepare -> work: the prepared value is the second element of the
    # argument tuple of `work`.
    _check_accepts_tuple('work', work, 2)
    client_input_type = work.type_signature.parameter[1]
    prepared_type = prepare.type_signature.result
    prepared_flows_into_work = _is_assignable_from_or_both_none(
        client_input_type, prepared_type)
    if not prepared_flows_into_work:
      raise TypeError(
          'The `work` computation expects an argument tuple with type {} as '
          'the second element (the initial client state from the server), '
          'which does not match the result type {} of `prepare`.'.format(
              client_input_type, prepared_type))

    _check_returns_tuple('work', work, WORK_RESULT_LEN)

    # zero / work -> accumulate: the accumulator starts from `zero`, absorbs
    # the client update produced by `work`, and feeds back into itself.
    accumulate_signature = accumulate.type_signature
    py_typecheck.check_len(accumulate_signature.parameter, 2)
    accumulator_type = accumulate_signature.parameter[0]
    accumulator_type.check_assignable_from(zero.type_signature.result)
    update_param_type = accumulate_signature.parameter[1]
    client_update_type = work.type_signature.result[WORK_UPDATE_INDEX]
    update_flows_into_accumulate = _is_assignable_from_or_both_none(
        update_param_type, client_update_type)
    if not update_flows_into_accumulate:
      raise TypeError(
          'The `accumulate` computation expects a second argument of type {}, '
          'which does not match the expected {} as implied by the type '
          'signature of `work`.'.format(update_param_type,
                                        client_update_type))
    accumulated_type = accumulate_signature.result
    accumulator_type.check_assignable_from(accumulated_type)

    # accumulate -> merge: both operands of `merge` are accumulators, and the
    # merged value must itself be usable as the first operand.
    merge_signature = merge.type_signature
    py_typecheck.check_len(merge_signature.parameter, 2)
    for position in (0, 1):
      merge_signature.parameter[position].check_assignable_from(
          accumulated_type)
    merge_signature.parameter[0].check_assignable_from(merge_signature.result)

    # merge -> report.
    report.type_signature.parameter.check_assignable_from(
        merge_signature.result)

    # initialize / report / work -> update.
    work_result_type = work.type_signature.result
    aggregate_types = [
        report.type_signature.result,
        # Update takes in the post-summation values of secure aggregation.
        work_result_type[WORK_SECAGG_BITWIDTH_INDEX],
        work_result_type[WORK_SECAGG_MAX_INPUT_INDEX],
        work_result_type[WORK_SECAGG_MODULUS_INDEX],
    ]
    expected_update_parameter_type = computation_types.to_type(
        [initial_state_type, aggregate_types])
    actual_update_parameter_type = update.type_signature.parameter
    update_accepts_expected = _is_assignable_from_or_both_none(
        actual_update_parameter_type, expected_update_parameter_type)
    if not update_accepts_expected:
      raise TypeError(
          'The `update` computation expects an argument of type {}, '
          'which does not match the expected {} as implied by the type '
          'signatures of `initialize`, `report`, and `work`.'.format(
              actual_update_parameter_type, expected_update_parameter_type))

    _check_returns_tuple('update', update, 2)

    # update -> prepare: the new server state must be usable in the next
    # round.
    updated_state_type = update.type_signature.result[0]
    next_round_is_well_typed = prepare_arg_type.is_assignable_from(
        updated_state_type)
    if not next_round_is_well_typed:
      raise TypeError(
          'The `update` computation returns a result tuple whose first element '
          f'(the updated state type of the server) is type:\n'
          f'{updated_state_type}\n'
          f'which is not assignable to the state parameter type of `prepare`:\n'
          f'{prepare_arg_type}')

    # All cross-checks passed; keep the computations.
    self._initialize = initialize
    self._prepare = prepare
    self._work = work
    self._zero = zero
    self._accumulate = accumulate
    self._merge = merge
    self._report = report
    self._secure_sum_bitwidth = secure_sum_bitwidth
    self._secure_sum_max_input = secure_sum_max_input
    self._secure_modular_sum_modulus = secure_modular_sum_modulus
    self._update = update

    # Labels are optional, but when given they must be strings.
    if server_state_label is not None:
      py_typecheck.check_type(server_state_label, str)
    self._server_state_label = server_state_label

    if client_data_label is not None:
      py_typecheck.check_type(client_data_label, str)
    self._client_data_label = client_data_label