in flsim/utils/fl/common.py [0:0]
def get_mismatched_param_max_difference(cls, models: List[nn.Module]) -> float:
    """Return the largest element-wise absolute difference between model states.

    Every entry of ``models[0].state_dict()`` (parameters and buffers) is
    compared with the same-named entry of every other model, and the overall
    maximum of ``|a - b|`` is returned.

    Args:
        models: Models to compare. Each later model must contain every
            ``state_dict`` key of ``models[0]``; extra keys in later models
            are ignored.

    Returns:
        The maximum absolute difference as a ``float``. ``0.0`` when fewer
        than two models are given or when all compared entries are equal or
        empty.

    Raises:
        KeyError: If a later model lacks a key present in ``models[0]``.
    """
    if len(models) <= 1:
        return 0.0

    def _comparable(tensor: torch.Tensor) -> torch.Tensor:
        # Bool tensors cannot be subtracted, and unsigned integer subtraction
        # wraps around (uint8: 1 - 2 == 255). Compare non-float state in
        # float64; floating/complex tensors keep their dtype so results for
        # ordinary parameters are unchanged.
        if tensor.is_floating_point() or tensor.is_complex():
            return tensor
        return tensor.double()

    reference, *others = [model.state_dict() for model in models]
    max_diff = 0.0
    for name, param in reference.items():
        # Zero-element tensors have nothing to compare; torch.max would raise.
        if param.numel() == 0:
            continue
        lhs = _comparable(param)
        for other_state in others:
            # Align devices so models living on different devices can be compared.
            rhs = _comparable(other_state[name].to(param.device))
            # NOTE(review): differently-shaped but broadcastable tensors are
            # silently broadcast here rather than reported as a mismatch.
            param_diff = torch.max(torch.abs(lhs - rhs)).item()
            # Builtin max keeps the current value unless param_diff is strictly
            # greater (same semantics as the original explicit comparison).
            max_diff = max(max_diff, param_diff)
    return float(max_diff)