def _make_wrapper()

in tensorflow_federated/python/aggregators/robust.py [0:0]


def _make_wrapper(
    clipping_norm: Union[float, estimation_process.EstimationProcess],
    inner_agg_factory: factory.AggregationFactory,
    clipped_count_sum_factory: factory.UnweightedAggregationFactory,
    make_clip_fn: Callable[[factory.ValueType], computation_base.Computation],
    attribute_prefix: str) -> factory.AggregationFactory:
  """Wraps `inner_agg_factory` so that client values pass through a clip fn.

  The returned factory is weighted exactly when `inner_agg_factory` is
  weighted. Each process it creates keeps three pieces of server state: the
  norm process state, the inner aggregation state, and the state of the
  clipped-count aggregation. Each round it reports three measurements: the
  inner aggregation's measurements, the norm that was used, and the aggregated
  per-client clip indicators.

  State and measurement keys are built by appending fixed suffixes to
  `attribute_prefix` ('ing_norm', 'ed_count_agg', 'ing', 'ed_count'). For
  example, a prefix of 'clipp' produces 'clipping_norm' and 'clipped_count'.

  Args:
    clipping_norm: Either a float, which is wrapped in a constant process for a
      fixed norm, or an `EstimationProcess`, for an adaptive norm. It gives the
      norm that values are clipped against.
    inner_agg_factory: Factory for the aggregation applied to the values
      returned by the clip fn (i.e. after clipping or zeroing).
    clipped_count_sum_factory: Unweighted factory used to aggregate the
      per-client `was_clipped` outputs into the count measurement.
    make_clip_fn: Callable that maps a value type to a computation. The
      computation is applied to `(value, norm)` at clients and yields
      `(clipped_value, norm_of_value, was_clipped)`.
    attribute_prefix: String prepended to every state and measurement key
      suffix.

  Returns:
    A `WeightedAggregationFactory` when `inner_agg_factory` is weighted,
    otherwise an `UnweightedAggregationFactory`.
  """
  py_typecheck.check_type(inner_agg_factory,
                          (factory.UnweightedAggregationFactory,
                           factory.WeightedAggregationFactory))
  py_typecheck.check_type(clipped_count_sum_factory,
                          factory.UnweightedAggregationFactory)
  py_typecheck.check_type(clipping_norm,
                          (float, estimation_process.EstimationProcess))

  # A plain float is promoted to a process so both cases share one code path.
  norm_process = (
      _constant_process(clipping_norm)
      if isinstance(clipping_norm, float) else clipping_norm)
  _check_norm_process(norm_process, 'clipping_norm_process')

  count_agg_process = clipped_count_sum_factory.create(
      computation_types.to_type(COUNT_TF_TYPE))

  # Fully-qualified state / measurement keys.
  norm_key = attribute_prefix + 'ing_norm'
  count_state_key = attribute_prefix + 'ed_count_agg'
  inner_measurements_key = attribute_prefix + 'ing'
  count_key = attribute_prefix + 'ed_count'

  def _keyed_state(norm_state, inner_state, count_state):
    # The key order here fixes the field order of the federated state type.
    return collections.OrderedDict([
        (norm_key, norm_state),
        ('inner_agg', inner_state),
        (count_state_key, count_state),
    ])

  def init_fn_impl(inner_agg_process):
    return intrinsics.federated_zip(
        _keyed_state(norm_process.initialize(),
                     inner_agg_process.initialize(),
                     count_agg_process.initialize()))

  def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
    norm_state, inner_state, count_state = state

    # The norm in effect this round, broadcast so every client clips with it.
    norm = norm_process.report(norm_state)
    clipped, value_norm, was_clipped = intrinsics.federated_map(
        clip_fn, (value, intrinsics.federated_broadcast(norm)))

    # The norm process is advanced with the norms reported by clip_fn.
    next_norm_state = norm_process.next(norm_state, value_norm)

    # Weight is forwarded only for weighted inner aggregations.
    inner_args = ((inner_state, clipped) if weight is None else
                  (inner_state, clipped, weight))
    inner_output = inner_agg_process.next(*inner_args)

    count_output = count_agg_process.next(count_state, was_clipped)

    updated_state = _keyed_state(next_norm_state, inner_output.state,
                                 count_output.state)
    round_measurements = collections.OrderedDict([
        (inner_measurements_key, inner_output.measurements),
        (norm_key, norm),
        (count_key, count_output.result),
    ])

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(updated_state),
        result=inner_output.result,
        measurements=intrinsics.federated_zip(round_measurements))

  if not isinstance(inner_agg_factory, factory.WeightedAggregationFactory):

    class UnweightedRobustFactory(factory.UnweightedAggregationFactory):
      """Unweighted factory that clips values before aggregating them."""

      def create(
          self, value_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        _check_value_type(value_type)

        inner_process = inner_agg_factory.create(value_type)
        clip = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_process)

        # Parameter names here become the names in the computation's
        # parameter type, so they are kept as-is.
        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
          return next_fn_impl(state, value, clip, inner_process)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return UnweightedRobustFactory()

  class WeightedRobustFactory(factory.WeightedAggregationFactory):
    """Weighted factory that clips values before aggregating them."""

    def create(
        self, value_type: factory.ValueType, weight_type: factory.ValueType
    ) -> aggregation_process.AggregationProcess:
      _check_value_type(value_type)
      py_typecheck.check_type(weight_type, factory.ValueType.__args__)

      inner_process = inner_agg_factory.create(value_type, weight_type)
      clip = make_clip_fn(value_type)

      @computations.federated_computation()
      def init_fn():
        return init_fn_impl(inner_process)

      # Parameter names here become the names in the computation's
      # parameter type, so they are kept as-is.
      @computations.federated_computation(
          init_fn.type_signature.result,
          computation_types.at_clients(value_type),
          computation_types.at_clients(weight_type))
      def next_fn(state, value, weight):
        return next_fn_impl(state, value, clip, inner_process, weight=weight)

      return aggregation_process.AggregationProcess(init_fn, next_fn)

  return WeightedRobustFactory()