def are_state_dict_equal()

in opacus/utils/module_utils.py [0:0]


def are_state_dict_equal(sd1: OrderedDict, sd2: OrderedDict):
    """
    Compares two state dicts, while logging discrepancies.

    Two state dicts are considered equal when they have exactly the same set
    of keys and, for every key, the two values have the same shape and dtype
    and are element-wise close according to ``torch.allclose`` with its
    default tolerances. Values are expected to be tensors.

    Args:
        sd1: First state dict (mapping of names to tensors).
        sd2: Second state dict to compare against ``sd1``.

    Returns:
        ``True`` if the state dicts match, ``False`` otherwise. The first
        discrepancy found is logged at error level before ``False`` is
        returned.
    """
    if len(sd1) != len(sd2):
        logger.error(f"Length mismatch: {len(sd1)} vs {len(sd2)}")
        return False

    for k1, v1 in sd1.items():
        # Equal lengths plus every sd1 key being present in sd2 implies the
        # two key sets are identical.
        if k1 not in sd2:
            # Log only the keys: formatting the whole dict would dump every
            # tensor into the log.
            logger.error(f"Key missing: {k1} not in {list(sd2.keys())}")
            return False
        v2 = sd2[k1]
        # torch.allclose broadcasts its inputs, so e.g. shapes (1,) and (3,)
        # holding the same value would be reported as "close". Check shapes
        # explicitly; this also turns the RuntimeError raised for
        # non-broadcastable shapes into a logged False.
        if v1.shape != v2.shape:
            logger.error(
                f"Shape mismatch for {k1}: "
                f"{tuple(v1.shape)} vs {tuple(v2.shape)}"
            )
            return False
        # torch.allclose raises on mismatched dtypes; report it as a
        # discrepancy instead of crashing.
        if v1.dtype != v2.dtype:
            logger.error(f"Dtype mismatch for {k1}: {v1.dtype} vs {v2.dtype}")
            return False
        # Check that value tensors are element-wise close.
        if not torch.allclose(v1, v2):
            logger.error(f"Tensor mismatch: {v1} vs {v2}")
            return False
    return True