in opacus/utils/module_utils.py [0:0]
def are_state_dict_equal(sd1: OrderedDict, sd2: OrderedDict):
    """
    Compares two state dicts, while logging discrepancies.

    Two state dicts are considered equal when they have the same set of keys
    and, for every key, the values have identical shapes and dtypes and are
    element-wise close according to ``torch.allclose`` (default tolerances).
    The first discrepancy found is logged via ``logger.error`` and ends the
    comparison.

    Args:
        sd1: First state dict; its values are expected to be tensors.
        sd2: Second state dict; its values are expected to be tensors.

    Returns:
        ``True`` if the state dicts match, ``False`` otherwise.
    """
    if len(sd1) != len(sd2):
        logger.error(f"Length mismatch: {len(sd1)} vs {len(sd2)}")
        return False
    for k1, v1 in sd1.items():
        # Equal lengths plus every sd1 key being present in sd2 implies the
        # two key sets are identical.
        if k1 not in sd2:
            # Log only the keys: formatting sd2 itself would dump every tensor.
            logger.error(f"Key missing: {k1} not in {list(sd2.keys())}")
            return False
        v2 = sd2[k1]
        # torch.allclose broadcasts its inputs, so e.g. shapes (1,) and (3,)
        # could compare as "close"; non-broadcastable shapes raise instead.
        if v1.shape != v2.shape:
            logger.error(
                f"Shape mismatch for key {k1}: "
                f"{tuple(v1.shape)} vs {tuple(v2.shape)}"
            )
            return False
        # torch.allclose raises on differing dtypes rather than returning False.
        if v1.dtype != v2.dtype:
            logger.error(f"Dtype mismatch for key {k1}: {v1.dtype} vs {v2.dtype}")
            return False
        if not torch.allclose(v1, v2):
            logger.error(f"Tensor mismatch: {v1} vs {v2}")
            return False
    return True