in fairscale/utils/testing.py [0:0]
def objects_are_equal(
    a: Any,
    b: Any,
    raise_exception: bool = False,
    dict_key: Optional[Any] = None,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
) -> bool:
    """
    Test that two objects are equal. Tensors are compared to ensure matching
    size, dtype, device and values.

    Dicts, lists and tuples are compared recursively element by element; sets
    are compared with set equality (order independent); anything else falls
    back to ``a == b``.

    Args:
        a, b: The objects to compare. Their types must be exactly identical
            (a subclass instance does not match its base class).
        raise_exception: If True, raise on the first mismatch found instead of
            returning False.
        dict_key: Key under which ``a``/``b`` were found in an enclosing dict.
            Only used to add context to error messages during recursion.
        rtol, atol: Relative/absolute tolerances forwarded to
            ``torch.testing.assert_allclose`` for tensor values. ``None`` keeps
            torch's dtype-based defaults; torch expects both or neither.

    Returns:
        True if the objects are equal, False otherwise (only reachable when
        ``raise_exception`` is False). For leaf values the raw result of
        ``a == b`` is returned.

    Raises:
        ValueError: When ``raise_exception`` is True and there is a type, key,
            length, set-content or scalar value mismatch.
        AssertionError, RuntimeError: When ``raise_exception`` is True and a
            tensor mismatch is detected.
    """
    if type(a) is not type(b):
        if raise_exception:
            raise ValueError(f"type mismatch {type(a)} vs. {type(b)}")
        return False
    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            if raise_exception:
                raise ValueError(f"keys mismatch {a.keys()} vs. {b.keys()}")
            return False
        for k in a.keys():
            if not objects_are_equal(a[k], b[k], raise_exception, k, rtol=rtol, atol=atol):
                return False
        return True
    elif isinstance(a, (list, tuple, set)):
        if len(a) != len(b):
            if raise_exception:
                raise ValueError(f"length mismatch {len(a)} vs. {len(b)}")
            return False
        if isinstance(a, set):
            # Set iteration order depends on insertion history, so pairing
            # elements with zip() can report equal sets as different. Set
            # equality is order independent.
            if a != b:
                if raise_exception:
                    raise ValueError(f"set mismatch {a} vs. {b}")
                return False
            return True
        # Sequence elements keep the enclosing dict key for error context.
        return all(
            objects_are_equal(x, y, raise_exception, dict_key, rtol=rtol, atol=atol) for x, y in zip(a, b)
        )
    elif torch.is_tensor(a):
        try:
            # assert_allclose doesn't strictly test shape, dtype and device
            shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
            if not shape_dtype_device_match:
                if raise_exception:
                    msg = f"sizes: {a.size()} vs. {b.size()}, "
                    msg += f"types: {a.dtype} vs. {b.dtype}, "
                    msg += f"device: {a.device} vs. {b.device}"
                    raise AssertionError(msg)
                return False
            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
            return True
        except (AssertionError, RuntimeError) as e:
            if not raise_exception:
                return False
            if dict_key is not None and isinstance(e, AssertionError):
                # Prefix the dict key so nested mismatches are easy to locate.
                # str(e) is safe even if the error carries no args.
                raise AssertionError(f"For dict key '{dict_key}': {e}") from None
            raise
    else:
        result = a == b
        # Only a plain ``False`` is treated as a mismatch to raise on; non-bool
        # results (e.g. elementwise comparisons) are returned unchanged.
        if raise_exception and result is False:
            prefix = f"For dict key '{dict_key}': " if dict_key is not None else ""
            raise ValueError(f"{prefix}value mismatch {a!r} vs. {b!r}")
        return result