in online_attacks/classifiers/madry/__init__.py [0:0]
def load_madry_model(dataname, weights_path, device="cuda"):
    """Load a Madry-challenge TF classifier and expose it as a PyTorch module.

    The dataset-specific ``Model`` class is imported lazily from the vendored
    ``madry_mnist`` / ``madry_cifar`` packages. It is wrapped with
    ``WrappedTfModel`` and ``TorchWrappedModel``, then exposed through a
    ``BPDAWrapper``. That wrapper's forward/backward convert tensors between
    the caller's layout and the layout the TF model is fed.

    Args:
        dataname: Either ``"mnist"`` or ``"cifar"``.
        weights_path: Forwarded to ``WrappedTfModel`` together with the model
            class (presumably the checkpoint location — confirm in
            ``WrappedTfModel``).
        device: Forwarded to ``TorchWrappedModel``; defaults to ``"cuda"``.

    Returns:
        The ``BPDAWrapper`` built around the wrapped TF model.

    Raises:
        ValueError: If ``dataname`` is not ``"mnist"`` or ``"cifar"``.
        ImportError: If the matching madry challenge package is not installed
            (install it with ``install_madry_challenge.sh``).
    """
    if dataname == "mnist":
        try:
            from .madry_mnist.model import Model
        except ImportError as err:
            # Without the package there is no Model to wrap; fail here with an
            # actionable message instead of a later, unrelated NameError.
            raise ImportError(
                "madry_mnist not found, please install madry challenge with install_madry_challenge.sh"
            ) from err
        print("madry_mnist found and imported")

        def _process_inputs_val(val):
            # Flatten each example to a 784-vector before feeding the TF model.
            return val.view(val.shape[0], 784)

        def _process_grads_val(val):
            # Reshape gradients back to (batch, 1, 28, 28) for the caller.
            return val.view(val.shape[0], 1, 28, 28)

    elif dataname == "cifar":
        try:
            from .madry_cifar.model import Model
        except ImportError as err:
            raise ImportError(
                "madry_cifar not found, please install madry challenge with install_madry_challenge.sh"
            ) from err
        print("madry_cifar found and imported")
        from functools import partial

        # The CIFAR model class is constructed in evaluation mode.
        Model = partial(Model, mode="eval")

        def _process_inputs_val(val):
            # Move axis 1 to the end (presumably NCHW -> NHWC) and scale by
            # 255 (assumes callers pass values in [0, 1] — TODO confirm).
            return 255.0 * val.permute(0, 2, 3, 1)

        def _process_grads_val(val):
            # Inverse of _process_inputs_val: restore the axis order and undo
            # the 255 scaling so gradients match the caller's input space.
            return val.permute(0, 3, 1, 2) / 255.0

    else:
        raise ValueError(
            f"Unknown dataname {dataname!r}; expected 'mnist' or 'cifar'"
        )

    def _wrap_forward(forward):
        # Preprocess inputs into the TF model's expected layout before forward.
        def new_forward(inputs_val):
            return forward(_process_inputs_val(inputs_val))

        return new_forward

    def _wrap_backward(backward):
        # NOTE(review): inputs_val and logits_grad_val are star-unpacked, so
        # BPDAWrapper is assumed to call backward(inputs_tuple, grads_tuple),
        # each tuple holding a single tensor — verify against BPDAWrapper.
        def new_backward(inputs_val, logits_grad_val):
            return _process_grads_val(
                backward(_process_inputs_val(*inputs_val), *logits_grad_val)
            )

        return new_backward

    ptmodel = TorchWrappedModel(WrappedTfModel(weights_path, Model), device)
    model = BPDAWrapper(
        forward=_wrap_forward(ptmodel.forward),
        backward=_wrap_backward(ptmodel.backward),
    )
    return model