def validate_correctness()

in crypten/debug/debug.py [0:0]


def validate_correctness(self, func, func_name, tolerance=0.05):
    """Wrap ``func`` so every call is checked against the same-named torch method.

    If a plain ``torch`` tensor has no attribute called ``func_name`` there is
    nothing to compare against and ``func`` is returned unchanged. Otherwise a
    wrapper is returned that calls ``func``, computes a reference result by
    calling ``func_name`` on the plaintext of ``self`` with plaintext
    arguments, and raises if the two disagree.

    Args:
        self: Object whose ``get_plain_text()`` yields the tensor on which the
            reference torch method is invoked.
        func: The callable being validated.
        func_name: Attribute name used to look up the reference method on the
            plaintext tensor.
        tolerance: An element passes when its relative error is at most
            ``tolerance`` or its absolute error is at most ``tolerance * 0.1``.

    Returns:
        ``func`` itself when torch has no counterpart, else the validating
        wrapper. The wrapper returns whatever ``func`` returned.

    Raises:
        ValueError: Raised by the wrapper when a non-tensor reference differs
            from the result, when the result and reference sizes differ, or
            when any element fails the tolerance test.
    """
    import functools

    import crypten
    import torch

    # No torch counterpart -> nothing to validate against.
    if not hasattr(torch.tensor([]), func_name):
        return func

    def _to_plain(value):
        """Decrypt ``value`` if it is an encrypted tensor; otherwise pass it through."""
        if crypten.is_encrypted_tensor(value):
            return value.get_plain_text()
        return value

    @functools.wraps(func)
    def validation_function(*args, **kwargs):
        # Validation mode is switched off for the duration of the check
        # (presumably so the calls made below are not themselves validated).
        with cfg.temp_override({"debug.validation_mode": False}):
            # Snapshot every plaintext input BEFORE running func. In-place
            # operations (e.g. ``add_``) mutate ``self``; decrypting afterwards
            # would apply the operation a second time in the reference and
            # report a spurious mismatch.
            plain_self = self.get_plain_text()
            plain_args = [_to_plain(arg) for arg in args]
            plain_kwargs = {key: _to_plain(value) for key, value in kwargs.items()}

            # Compute crypten result
            result_enc = func(*args, **kwargs)
            result = _to_plain(result_enc)

            # Compute torch result for corresponding function
            reference = getattr(plain_self, func_name)(*plain_args, **plain_kwargs)

            # Non-tensor references (property-like values) are compared
            # directly against the raw return value of func.
            if not torch.is_tensor(reference):
                if result_enc != reference:
                    raise ValueError(
                        f"Function {func_name} returned incorrect property value"
                    )
                return result_enc

            # Check sizes match
            if result.size() != reference.size():
                crypten_log(
                    f"Size mismatch: Expected {reference.size()} but got {result.size()}"
                )
                raise ValueError(f"Function {func_name} returned incorrect size")

            # Check that results match: each element must satisfy the relative
            # bound or the (tighter) absolute bound. Where both values are zero
            # the relative error is NaN, but the absolute check still passes.
            diff = (result - reference).abs_()
            norm_diff = diff.div(result.abs() + reference.abs()).abs_()
            test_passed = norm_diff.le(tolerance) + diff.le(tolerance * 0.1)
            test_passed = test_passed.gt(0).all().item() == 1
            if not test_passed:
                crypten_log(f"Function {func_name} returned incorrect values")
                crypten_log("Result %s" % result)
                crypten_log("Result - Reference = %s" % (result - reference))
                raise ValueError(f"Function {func_name} returned incorrect values")

        return result_enc

    return validation_function