def get_mismatched_param_max_difference()

in flsim/utils/fl/common.py [0:0]


    def get_mismatched_param_max_difference(
        cls, models: List[nn.Module]
    ) -> float:
        """Return the largest element-wise absolute difference across models.

        Every entry of ``models[0].state_dict()`` (parameters and buffers) is
        compared against the entry with the same name in each of the other
        models' state dicts. The maximum ``|a - b|`` over all entries and
        elements is returned.

        Args:
            models: Models to compare. The first model's state dict supplies
                the entry names. Every other model must contain those names.

        Returns:
            The maximum absolute difference as a ``float``. Returns ``0.0``
            when fewer than two models are given, when all compared entries
            are identical, or when every compared entry is empty.

        Raises:
            KeyError: If a later model's state dict lacks an entry present in
                the first model's state dict.
        """
        if len(models) <= 1:
            return 0.0
        reference, *others = [model.state_dict() for model in models]
        # Start from a float so the return type is the same on every path,
        # including the "all identical" case.
        max_diff = 0.0
        for name, param in reference.items():
            for other in others:
                diff = param - other[name]
                # torch.max() raises on a zero-element tensor, and an empty
                # entry has no elements that could differ, so skip it.
                if diff.numel() == 0:
                    continue
                largest = torch.max(torch.abs(diff)).item()
                max_diff = max(max_diff, float(largest))
        return max_diff