def load_madry_model()

in online_attacks/classifiers/madry/__init__.py [0:0]


def load_madry_model(dataname, weights_path, device="cuda"):
    """Build a BPDA-wrapped Madry challenge model for ``dataname``.

    The dataset-specific ``Model`` class is imported lazily from the optional
    ``madry_mnist`` / ``madry_cifar`` sub-packages, wrapped with
    ``WrappedTfModel`` and ``TorchWrappedModel``, and finally exposed through
    a ``BPDAWrapper`` whose forward/backward callables convert tensors
    between the caller's layout and the layout handed to the wrapped model.

    Args:
        dataname: Either ``"mnist"`` or ``"cifar"``.
        weights_path: Passed unchanged to ``WrappedTfModel``.
        device: Passed unchanged to ``TorchWrappedModel``.

    Returns:
        The ``BPDAWrapper`` instance wrapping the loaded model.

    Raises:
        ValueError: If ``dataname`` is not one of the supported names.
        ImportError: If the required madry challenge sub-package cannot be
            imported (install it with ``install_madry_challenge.sh``).
    """
    if dataname == "mnist":
        try:
            from .madry_mnist.model import Model
        except ImportError as err:
            # Fail here with the install hint instead of continuing with an
            # unbound ``Model`` and crashing later with an unrelated error.
            raise ImportError(
                "madry_mnist not found, please install madry challenge with install_madry_challenge.sh"
            ) from err
        print("madry_mnist found and imported")

        def _process_inputs_val(val):
            # Flatten each sample to a 784-long vector.
            return val.view(val.shape[0], 784)

        def _process_grads_val(val):
            # Restore the (batch, 1, 28, 28) image shape for the gradients.
            return val.view(val.shape[0], 1, 28, 28)

    elif dataname == "cifar":
        try:
            from .madry_cifar.model import Model
        except ImportError as err:
            raise ImportError(
                "madry_cifar not found, please install madry challenge with install_madry_challenge.sh"
            ) from err
        print("madry_cifar found and imported")

        from functools import partial

        # Always construct the cifar model with mode="eval".
        Model = partial(Model, mode="eval")

        def _process_inputs_val(val):
            # Move axis 1 to the end and rescale by 255
            # (presumably NCHW in [0, 1] -> NHWC in [0, 255]; verify).
            return 255.0 * val.permute(0, 2, 3, 1)

        def _process_grads_val(val):
            # Move the last axis back to position 1.
            # NOTE(review): gradients are divided by 255 although the forward
            # pass multiplies by 255 -- confirm this scaling is intended.
            return val.permute(0, 3, 1, 2) / 255.0

    else:
        raise ValueError(
            f"Unsupported dataname {dataname!r}; expected 'mnist' or 'cifar'"
        )

    def _wrap_forward(forward):
        """Return ``forward`` preceded by the dataset input conversion."""

        def new_forward(inputs_val):
            return forward(_process_inputs_val(inputs_val))

        return new_forward

    def _wrap_backward(backward):
        """Return ``backward`` with inputs and resulting grads converted."""

        def new_backward(inputs_val, logits_grad_val):
            # Both arguments are star-unpacked, so they must be iterables
            # (presumably the tuples BPDAWrapper passes in; verify).
            return _process_grads_val(
                backward(_process_inputs_val(*inputs_val), *logits_grad_val)
            )

        return new_backward

    ptmodel = TorchWrappedModel(WrappedTfModel(weights_path, Model), device)
    model = BPDAWrapper(
        forward=_wrap_forward(ptmodel.forward),
        backward=_wrap_backward(ptmodel.backward),
    )

    return model