in make_data_radioactive.py [0:0]
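# Imports reconstructed for readability (an assumption: the excerpt starts at
# the function body). The repo-internal helpers used below (initialize_exp,
# build_model, getImagenetTransform, CenterCrop, RandomResizedCropFlip,
# get_optimizer, repeat_to, project_linf, roundPixel, numpyPixel, psnr) are
# assumed to come from this repository's own utility modules.
import json
from os.path import basename, join

import numpy as np
import torch
import torch.nn as nn
from torchvision.datasets.folder import default_loader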

def main(params):
    logger = initialize_exp(params)

    if params.img_list is None:
        # Comma-separated list of image paths given directly on the command line
        params.img_paths = [s.strip() for s in params.img_paths.split(",")]
    else:
        # Here img_paths holds a "start:end" range into the precomputed image list
        assert ":" in params.img_paths
        chunks = params.img_paths.split(":")
        assert len(chunks) == 2
        n_start, n_end = int(chunks[0]), int(chunks[1])
        img_list = torch.load(params.img_list)
        params.img_paths = [img_list[i] for i in range(n_start, n_end)]

    print("Image paths", params.img_paths)

    # Build model / cuda
    ckpt = torch.load(params.marking_network)
    params.num_classes = ckpt["params"]["num_classes"]
    params.architecture = ckpt["params"]["architecture"]
    print("Building %s model ..." % params.architecture)
    model = build_model(params)
    model.cuda()
    # Checkpoints may come from DataParallel training, hence the "module." prefix
    model.load_state_dict({k.replace("module.", ""): v for k, v in ckpt["model"].items()}, strict=False)
    model = model.eval()
    # Drop the classification head: the model now outputs penultimate features
    model.fc = nn.Sequential()

    loader = default_loader
    transform = getImagenetTransform("none", img_size=params.img_size, crop_size=params.crop_size)
    img_orig = [transform(loader(p)).unsqueeze(0) for p in params.img_paths]

    # Load carriers and select the direction for this class
    direction = torch.load(params.carrier_path).cuda()
    assert direction.dim() == 2
    direction = direction[params.carrier_id:params.carrier_id + 1]

    # rho = -1 flags the plain dot-product objective; with an angle, the cone
    # loss uses rho = 1 / cos^2(angle) = 1 + tan^2(angle)
    rho = -1
    if params.angle is not None:
        rho = 1 + np.tan(params.angle) ** 2

    img = [x.clone() for x in img_orig]

    # Load differentiable data augmentations
    center_da = CenterCrop(params.img_size, params.crop_size)
    random_da = RandomResizedCropFlip(params.crop_size)
    if params.data_augmentation == "center":
        data_augmentation = center_da
    elif params.data_augmentation == "random":
        data_augmentation = random_da

    # The images themselves are the optimization variables
    for i in range(len(img)):
        img[i].requires_grad = True

    optimizer, schedule = get_optimizer(img, params.optimizer)
    if schedule is not None:
        schedule = repeat_to(schedule, params.epochs)

    # Features of the original (unmarked) images, used as the reference point
    img_center = torch.cat([center_da(x, 0).cuda(non_blocking=True) for x in img_orig], dim=0)
    ft_orig = model(img_center).detach()
    if params.angle is not None:
        # Precomputed reference features (cluster-specific path from the original experiments)
        ft_orig = torch.load("/checkpoint/asablayrolles/radioactive_data/imagenet_ckpt_2/features/valid_resnet18_center.pth").cuda()
    for iteration in range(params.epochs):
        if schedule is not None:
            lr = schedule[iteration]
            logger.info("New learning rate: %f" % lr)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        # Differentiably augment images
        batch = []
        for x in img:
            aug_params = data_augmentation.sample_params(x)
            aug_img = data_augmentation(x, aug_params)
            batch.append(aug_img.cuda(non_blocking=True))
        batch = torch.cat(batch, dim=0)

        # Forward augmented images
        ft = model(batch)

        if params.angle is None:
            # Align the feature displacement with the carrier, penalize its L2 norm
            loss_ft = - torch.sum((ft - ft_orig) * direction)
            loss_ft_l2 = params.lambda_ft_l2 * torch.norm(ft - ft_orig, dim=1).sum()
        else:
            # Cone loss: keep the feature displacement within a cone of the
            # given angle around the carrier direction
            dot_product = torch.sum((ft - ft_orig) * direction)
            print("Dot product: ", dot_product.item())
            if params.half_cone:
                # Signed square: displacement opposite to the carrier is penalized
                loss_ft = - rho * dot_product * torch.abs(dot_product)
            else:
                loss_ft = - rho * (dot_product ** 2)
            loss_ft_l2 = torch.norm(ft - ft_orig) ** 2

        # Penalize pixel-space distortion of each image
        loss_norm = 0
        for i in range(len(img)):
            loss_norm += params.lambda_l2_img * torch.norm(img[i].cuda(non_blocking=True) - img_orig[i].cuda(non_blocking=True)) ** 2

        loss = loss_ft + loss_norm + loss_ft_l2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs = {
            "keyword": "iteration",
            "loss": loss.item(),
            "loss_ft": loss_ft.item(),
            "loss_norm": loss_norm.item(),
            "loss_ft_l2": loss_ft_l2.item(),
        }
        if params.angle is not None:
            logs["R"] = - (loss_ft + loss_ft_l2).item()
        if schedule is not None:
            logs["lr"] = schedule[iteration]
        logger.info("__log__:%s" % json.dumps(logs))

        # Project back into the L-infinity ball of the given radius around the
        # original image; every 10 iterations, round to valid integer pixel values
        for i in range(len(img)):
            img[i].data[0] = project_linf(img[i].data[0], img_orig[i][0], params.radius)
            if iteration % 10 == 0:
                img[i].data[0] = roundPixel(img[i].data[0])
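
    # Evaluate the marked images: PSNR against the originals and
    # feature-space statistics along the carrier direction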
    img_new = [numpyPixel(x.data[0]).astype(np.float32) for x in img]
    img_old = [numpyPixel(x[0]).astype(np.float32) for x in img_orig]

    img_totest = torch.cat([center_da(x, 0).cuda(non_blocking=True) for x in img])
    with torch.no_grad():
        ft_new = model(img_totest)
logger.info("__log__:%s" % json.dumps({
"keyword": "final",
"psnr": np.mean([psnr(x_new - x_old) for x_new, x_old in zip(img_new, img_old)]),
"ft_direction": torch.mv(ft_new - ft_orig, direction[0]).mean().item(),
"ft_norm": torch.norm(ft_new - ft_orig, dim=1).mean().item(),
"rho": rho,
"R": (rho * torch.dot(ft_new[0] - ft_orig[0], direction[0])**2 - torch.norm(ft_new - ft_orig)**2).item(),
}))

    # Save the marked images as .npy arrays in the dump path
    for i in range(len(img)):
        img_name = basename(params.img_paths[i])
        extension = ".%s" % (img_name.split(".")[-1])
        np.save(join(params.dump_path, img_name).replace(extension, ".npy"), img_new[i].astype(np.uint8))
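
# Hypothetical invocation (a sketch only: flag names are inferred from the
# params attributes used above, not taken from the actual argument parser):
#
#   python make_data_radioactive.py \
#     --marking_network marking_net.pth \
#     --carrier_path carriers.pth --carrier_id 0 \
#     --img_paths path/to/img0.jpeg,path/to/img1.jpeg \
#     --data_augmentation random --epochs 90 --radius 10 \
#     --dump_path ./marked_images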