in tensorflow_federated/python/aggregators/robust.py [0:0]
def _make_wrapper(
    clipping_norm: Union[float, estimation_process.EstimationProcess],
    inner_agg_factory: factory.AggregationFactory,
    clipped_count_sum_factory: factory.UnweightedAggregationFactory,
    make_clip_fn: Callable[[factory.ValueType], computation_base.Computation],
    attribute_prefix: str) -> factory.AggregationFactory:
  """Constructs an aggregation factory that applies clip_fn before aggregation.

  The returned factory is a `WeightedAggregationFactory` if
  `inner_agg_factory` is weighted, and an `UnweightedAggregationFactory`
  otherwise. In each round, the norm reported by the norm process is
  broadcast to clients. The clip computation is applied to every client value
  together with that norm. The clipped values are aggregated by the inner
  process, and the per-client `was_clipped` outputs are aggregated by the
  process created from `clipped_count_sum_factory`.

  State and measurement names are formed by appending suffixes to
  `attribute_prefix`. For example, a prefix of `'clipp'` yields
  `'clipping_norm'` and `'clipped_count'`.

  * State keys: `prefix + 'ing_norm'`, `'inner_agg'`,
    `prefix + 'ed_count_agg'`.
  * Measurement keys:
    * `prefix + 'ing'`: the inner aggregation's measurements.
    * `prefix + 'ing_norm'`: the norm used for clipping in this round.
    * `prefix + 'ed_count'`: the aggregated clipped count.

  Args:
    clipping_norm: Either a float (for fixed norm) or an `EstimationProcess`
      (for adaptive norm) that specifies the norm over which the values should
      be clipped. A non-bool `int` is also accepted and converted to `float`.
    inner_agg_factory: A factory specifying the type of aggregation to be done
      after the clip operation (clipping or zeroing, as defined by
      `make_clip_fn`).
    clipped_count_sum_factory: A factory specifying the type of aggregation done
      for the `clipped_count` measurement.
    make_clip_fn: A callable that takes a value type and returns a
      tff.computation specifying the clip operation to apply before
      aggregation. The computation is mapped at clients over
      `(value, clipping_norm)` and must return a 3-tuple
      `(clipped_value, norm, was_clipped)`. Here `norm` is passed to the norm
      process's `next`, and `was_clipped` is passed to the clipped-count
      aggregation.
    attribute_prefix: A str for prefixing state and measurement names.

  Returns:
    An aggregation factory that applies clip_fn before aggregation.
  """
  py_typecheck.check_type(inner_agg_factory,
                          (factory.UnweightedAggregationFactory,
                           factory.WeightedAggregationFactory))
  py_typecheck.check_type(clipped_count_sum_factory,
                          factory.UnweightedAggregationFactory)
  # Under PEP 484's numeric tower, an annotation of `float` also admits `int`.
  # Accept a plain int norm (e.g. `1`) rather than rejecting it. `bool` is an
  # `int` subclass but is never a meaningful norm, so it still fails below.
  if isinstance(clipping_norm, int) and not isinstance(clipping_norm, bool):
    clipping_norm = float(clipping_norm)
  py_typecheck.check_type(clipping_norm,
                          (float, estimation_process.EstimationProcess))
  if isinstance(clipping_norm, float):
    clipping_norm_process = _constant_process(clipping_norm)
  else:
    clipping_norm_process = clipping_norm
  _check_norm_process(clipping_norm_process, 'clipping_norm_process')

  clipped_count_agg_process = clipped_count_sum_factory.create(
      computation_types.to_type(COUNT_TF_TYPE))

  def prefix(suffix):
    """Returns `suffix` prefixed with `attribute_prefix`."""
    return attribute_prefix + suffix

  def init_fn_impl(inner_agg_process):
    """Builds the zipped server state for the wrapped process."""
    # The key order here is relied on by `next_fn_impl`, which unpacks the
    # state positionally.
    state = collections.OrderedDict([
        (prefix('ing_norm'), clipping_norm_process.initialize()),
        ('inner_agg', inner_agg_process.initialize()),
        (prefix('ed_count_agg'), clipped_count_agg_process.initialize())
    ])
    return intrinsics.federated_zip(state)

  def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
    """Runs one round: clip at clients, then aggregate values and counts."""
    # Positional unpacking; must match the key order in `init_fn_impl`.
    clipping_norm_state, agg_state, clipped_count_state = state

    current_norm = clipping_norm_process.report(clipping_norm_state)
    clients_clipping_norm = intrinsics.federated_broadcast(current_norm)

    clipped_value, global_norm, was_clipped = intrinsics.federated_map(
        clip_fn, (value, clients_clipping_norm))

    new_clipping_norm_state = clipping_norm_process.next(
        clipping_norm_state, global_norm)

    if weight is None:
      agg_output = inner_agg_process.next(agg_state, clipped_value)
    else:
      agg_output = inner_agg_process.next(agg_state, clipped_value, weight)

    clipped_count_output = clipped_count_agg_process.next(
        clipped_count_state, was_clipped)

    new_state = collections.OrderedDict([
        (prefix('ing_norm'), new_clipping_norm_state),
        ('inner_agg', agg_output.state),
        (prefix('ed_count_agg'), clipped_count_output.state)
    ])
    # The reported norm is the one used for clipping in this round, taken
    # before the norm process is advanced.
    measurements = collections.OrderedDict([
        (prefix('ing'), agg_output.measurements),
        (prefix('ing_norm'), current_norm),
        (prefix('ed_count'), clipped_count_output.result)
    ])

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=agg_output.result,
        measurements=intrinsics.federated_zip(measurements))

  if isinstance(inner_agg_factory, factory.WeightedAggregationFactory):

    class WeightedRobustFactory(factory.WeightedAggregationFactory):
      """`WeightedAggregationFactory` factory for clipping large values."""

      def create(
          self, value_type: factory.ValueType, weight_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        _check_value_type(value_type)
        py_typecheck.check_type(weight_type, factory.ValueType.__args__)

        inner_agg_process = inner_agg_factory.create(value_type, weight_type)
        clip_fn = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_agg_process)

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type),
            computation_types.at_clients(weight_type))
        def next_fn(state, value, weight):
          return next_fn_impl(state, value, clip_fn, inner_agg_process, weight)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return WeightedRobustFactory()
  else:

    class UnweightedRobustFactory(factory.UnweightedAggregationFactory):
      """`UnweightedAggregationFactory` factory for clipping large values."""

      def create(
          self, value_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        _check_value_type(value_type)

        inner_agg_process = inner_agg_factory.create(value_type)
        clip_fn = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_agg_process)

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
          return next_fn_impl(state, value, clip_fn, inner_agg_process)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return UnweightedRobustFactory()