in tensorflow_federated/python/learning/templates/client_works.py [0:0]
def __init__(self, initialize_fn, next_fn):
  """Constructs the process after validating its functions' type signatures.

  Args:
    initialize_fn: A computation producing the process state. Its result must
      be a federated type placed at SERVER.
    next_fn: A computation taking exactly three arguments: the state, a value
      placed at CLIENTS, and a sequence of client data placed at CLIENTS.
      Every leaf of its parameter and result types must be federated. Its
      result's `result` attribute must be placed at CLIENTS with the
      `ClientResult` container. Its `measurements` attribute must be placed
      at SERVER.

  Raises:
    TemplateNotFederatedError: If `initialize_fn` does not return a federated
      type, or if any leaf of `next_fn`'s parameter or result type is not
      federated.
    TemplatePlacementError: If the state is not placed at SERVER, if the
      second or third argument of `next_fn` is not placed at CLIENTS, if the
      `result` of `next_fn` is not placed at CLIENTS, or if its
      `measurements` are not placed at SERVER.
    TemplateNextFnNumArgsError: If `next_fn` does not take exactly three
      arguments.
    ClientDataTypeError: If the third argument of `next_fn` is not a sequence.
    ClientResultTypeError: If the `result` of `next_fn` does not use the
      `ClientResult` container.

  The base-class constructor may raise further errors; for example, the
  comment below refers to TemplateStateNotAssignableError.
  """
  super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

  init_result_type = initialize_fn.type_signature.result
  if not init_result_type.is_federated():
    raise errors.TemplateNotFederatedError(
        f'Provided `initialize_fn` must return a federated type, but found '
        f'return type:\n{init_result_type}\nTip: If you '
        f'see a collection of federated types, try wrapping the returned '
        f'value in `tff.federated_zip` before returning.')

  next_types = (
      structure.flatten(next_fn.type_signature.parameter) +
      structure.flatten(next_fn.type_signature.result))
  non_federated_types = [t for t in next_types if not t.is_federated()]
  if non_federated_types:
    # `str.join` only accepts strings. Joining the type objects directly
    # would raise a TypeError that masks the intended error below.
    offending_types = '\n- '.join(str(t) for t in non_federated_types)
    raise errors.TemplateNotFederatedError(
        f'Provided `next_fn` must be a *federated* computation, that is, '
        f'operate on `tff.FederatedType`s, but found\n'
        f'next_fn with type signature:\n{next_fn.type_signature}\n'
        f'The non-federated types are:\n- {offending_types}.')

  if init_result_type.placement != placements.SERVER:
    raise errors.TemplatePlacementError(
        f'The state controlled by a `ClientWorkProcess` must be placed at '
        f'the SERVER, but found type: {init_result_type}.')
  # Note that state of next_fn being placed at SERVER is now ensured by the
  # assertions in base class which would otherwise raise
  # TemplateStateNotAssignableError.

  next_fn_param = next_fn.type_signature.parameter
  if not next_fn_param.is_struct():
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'the following input type which is not a Struct: {next_fn_param}.')
  if len(next_fn_param) != 3:
    next_param_str = '\n- '.join(str(t) for t in next_fn_param)
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'{len(next_fn_param)} input arguments:\n- {next_param_str}')

  second_next_param = next_fn_param[1]
  client_data_param = next_fn_param[2]
  # The federated check above only inspects flattened leaves. A top-level
  # argument may still be a struct of federated values, which has no
  # `placement` attribute, so it is guarded explicitly here.
  if (not second_next_param.is_federated() or
      second_next_param.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The second input argument of `next_fn` must be placed at CLIENTS '
        f'but found {second_next_param}.')
  if (not client_data_param.is_federated() or
      client_data_param.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The third input argument of `next_fn` must be placed at CLIENTS '
        f'but found {client_data_param}.')
  if not client_data_param.member.is_sequence():
    raise ClientDataTypeError(
        f'The third input argument of `next_fn` must be a sequence but found '
        f'{client_data_param}.')

  next_fn_result = next_fn.type_signature.result
  client_result_type = next_fn_result.result
  if (not client_result_type.is_federated() or
      client_result_type.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The "result" attribute of the return type of `next_fn` must be '
        f'placed at CLIENTS, but found {client_result_type}.')
  if (not client_result_type.member.is_struct_with_python() or
      client_result_type.member.python_container is not ClientResult):
    raise ClientResultTypeError(
        f'The "result" attribute of the return type of `next_fn` must have '
        f'the `ClientResult` container, but found {client_result_type}.')

  measurements_type = next_fn_result.measurements
  # Same leaf-vs-top-level gap as above: guard before reading `placement`.
  if (not measurements_type.is_federated() or
      measurements_type.placement != placements.SERVER):
    raise errors.TemplatePlacementError(
        f'The "measurements" attribute of return type of `next_fn` must be '
        f'placed at SERVER, but found {measurements_type}.')