in crypten/debug/debug.py [0:0]
def validate_correctness(self, func, func_name, tolerance=0.05):
    """Wrap ``func`` so every call is checked against the matching torch method.

    The returned wrapper runs ``func`` (the CrypTen implementation of
    ``func_name``), decrypts any encrypted inputs/outputs, evaluates
    ``getattr(self.get_plain_text(), func_name)`` on the plaintext inputs and
    raises if the two results disagree.

    Args:
        self: Object exposing ``get_plain_text()``; the torch reference is
            computed on its plaintext.
        func: Callable implementing ``func_name``; invoked with the caller's
            arguments unchanged.
        func_name: Method name. If ``torch.Tensor`` has no attribute of this
            name, no validation is possible and ``func`` is returned as-is.
        tolerance: Maximum allowed relative error per element. An element also
            passes when its absolute error is at most ``tolerance * 0.1``.

    Returns:
        ``func`` itself when there is no torch counterpart; otherwise a wrapper
        (carrying ``func``'s metadata) that returns ``func``'s original,
        possibly encrypted, result after validation succeeds.

    Raises (from the wrapper):
        ValueError: if a non-tensor result differs from the reference, or if a
            tensor result differs from the reference in size or values.
    """
    import functools

    import crypten
    import torch

    if not hasattr(torch.tensor([]), func_name):
        return func

    def _decrypt(value):
        # Decrypt CrypTen tensors; pass every other value through unchanged.
        if crypten.is_encrypted_tensor(value):
            return value.get_plain_text()
        return value

    def _numeric(tensor):
        # torch forbids the `-` operator when either operand is a bool tensor,
        # and torch comparison ops return bool tensors, so compare as 0/1 floats.
        return tensor.float() if tensor.dtype == torch.bool else tensor

    @functools.wraps(func)
    def validation_function(*args, **kwargs):
        # Validation mode is switched off for the duration of the check,
        # presumably so the calls made below are not themselves re-validated.
        with cfg.temp_override({"debug.validation_mode": False}):
            # Compute crypten result
            result_enc = func(*args, **kwargs)
            result = _decrypt(result_enc)

            # Compute torch result for corresponding function on plaintext inputs
            plain_args = [_decrypt(arg) for arg in args]
            plain_kwargs = {key: _decrypt(value) for key, value in kwargs.items()}
            reference = getattr(self.get_plain_text(), func_name)(
                *plain_args, **plain_kwargs
            )

            # Non-tensor outputs (e.g. properties) must match exactly.
            if not torch.is_tensor(reference):
                if result_enc != reference:
                    raise ValueError(
                        f"Function {func_name} returned incorrect property value"
                    )
                return result_enc

            # Check sizes match
            if result.size() != reference.size():
                crypten_log(
                    f"Size mismatch: Expected {reference.size()} but got {result.size()}"
                )
                raise ValueError(f"Function {func_name} returned incorrect size")

            # Check that results match
            result_num = _numeric(result)
            reference_num = _numeric(reference)
            diff = (result_num - reference_num).abs()
            norm_diff = diff.div(result_num.abs() + reference_num.abs())
            # An element passes if its relative error is within tolerance or its
            # absolute error is tiny; the latter covers 0/0 = NaN relative error.
            passed = torch.logical_or(
                norm_diff.le(tolerance), diff.le(tolerance * 0.1)
            )
            if not passed.all().item():
                crypten_log(f"Function {func_name} returned incorrect values")
                crypten_log("Result %s" % result)
                crypten_log("Result - Reference = %s" % (result_num - reference_num))
                raise ValueError(f"Function {func_name} returned incorrect values")

            return result_enc

    return validation_function