# experiments/closest_augs.py
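
# Imports assumed from the full module (this excerpt omits the file header).
# The numpy/torch/omegaconf usages are visible below; `instantiate` is assumed
# to be hydra's config instantiation, and the remaining names
# (extract_features, individual_sort, train_net, test_net, test_corrupt_net)
# are project-local helpers whose import paths are not shown here.
import logging
import os

import hydra
import numpy as np
import omegaconf
import torch
from hydra.utils import instantiate

log = logging.getLogger(__name__)
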
def train(cfg):
    """Train a model on augmentations selected by their feature-space
    distance to a set of target corruptions, then evaluate clean error
    and corruption error."""
    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)
    log.info(cfg.pretty())  # OmegaConf's older pretty-print API
    model = instantiate(cfg.model).cuda()
    optimizer = instantiate(cfg.optim, model.parameters())
    lr_policy = instantiate(cfg.optim.lr_policy)
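
    # Fast path: if a precomputed sorted-transform file already exists, the
    # feature extractor is never used, so skip instantiating it.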
    if cfg.transform_file and os.path.exists(cfg.transform_file):
        log.info("Transforms found, loading feature extractor is unnecessary. Skipping.")
    else:
        feature_extractor = instantiate(cfg.ft)
        feature_extractor.train()
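
    # Per-augmentation features: skip entirely if transforms are already
    # cached, load a cached feature file if present, or extract features for
    # a random subset of cfg.num_images augmented images and cache them.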
    if cfg.transform_file and os.path.exists(cfg.transform_file):
        log.info("Transforms found, feature extraction is unnecessary. Skipping.")
    elif cfg.aug_feature_file and os.path.exists(cfg.aug_feature_file):
        log.info("Found feature file. Loading from {}".format(cfg.aug_feature_file))
        data = np.load(cfg.aug_feature_file)
        augmentation_features = data['features']
        indices = data['indices']
        transforms = data['transforms']
    else:
        ft_augmentation_dataset = instantiate(cfg.ft_augmentation)
        transforms = ft_augmentation_dataset.transform_list
        indices = np.random.choice(np.arange(len(ft_augmentation_dataset)),
                                   size=cfg.num_images,
                                   replace=False)
        ft_augmentation_dataset = ft_augmentation_dataset.serialize(indices)
        augmentation_features = extract_features(feature_extractor,
                                                 ft_augmentation_dataset,
                                                 cfg.ft_augmentation.batch_size,
                                                 cfg.data_loader,
                                                 average=True,
                                                 average_num=len(indices))
        if cfg.aug_feature_file:
            np.savez(cfg.aug_feature_file,
                     features=augmentation_features,
                     indices=indices,
                     transforms=transforms)
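
    # Sort the augmentations by distance to the target corruptions, either by
    # loading a cached transform file or by computing the distances now.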
    if cfg.transform_file and os.path.exists(cfg.transform_file):
        log.info("Found transform file. Loading from {}.".format(cfg.transform_file))
        sorted_transforms = np.load(cfg.transform_file)
    else:
        aug_strings = cfg.ft_corrupt.aug_string.split("--")
        distances = np.zeros((len(augmentation_features), len(aug_strings)))
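        # One pass per corruption: average the extractor's features over the
        # corrupted dataset, then take L2 distances from that average to each
        # augmentation's averaged features.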
        for i, aug in enumerate(aug_strings):
            with omegaconf.open_dict(cfg):
                ft_corrupt_dataset = instantiate(cfg.ft_corrupt, aug_string=aug)
            # Sample the evaluation subset once (i == 0) and reuse the same
            # indices for every corruption so distances are comparable.
            if cfg.num_corrupt_images and i == 0:
                indices = np.random.choice(np.arange(len(ft_corrupt_dataset)),
                                           size=cfg.num_corrupt_images,
                                           replace=False)
            ft_corrupt_dataset = ft_corrupt_dataset.serialize(indices)
            corruption_features = extract_features(feature_extractor,
                                                   ft_corrupt_dataset,
                                                   cfg.ft_corrupt.batch_size,
                                                   cfg.data_loader,
                                                   average=True)
            corruption_features = corruption_features.reshape(1, -1)
            dists = np.linalg.norm(augmentation_features - corruption_features, axis=-1)
            distances[:, i] = dists
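        # individual_sort (project-local helper) returns indices ordering the
        # augmentations by distance; indexing transforms with it yields a
        # closest-first list, consistent with the 'closest' slice below.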
        sorted_dist_args = individual_sort(distances)
        sorted_transforms = transforms[sorted_dist_args]
        if cfg.transform_file:
            np.save(cfg.transform_file, sorted_transforms)
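
    # Build the training set from the num_transforms augmentations that are
    # closest to, farthest from, or randomly sampled from the sorted list.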
    train_dataset = instantiate(cfg.train)
    if cfg.selection_type == 'closest':
        train_dataset.transform_list = sorted_transforms[cfg.offset:cfg.offset + cfg.num_transforms]
    elif cfg.selection_type == 'farthest':
        if cfg.offset != 0:
            train_dataset.transform_list = sorted_transforms[-cfg.offset - cfg.num_transforms:-cfg.offset]
        else:
            train_dataset.transform_list = sorted_transforms[-cfg.num_transforms:]
    else:
        train_dataset.transform_list = sorted_transforms[np.random.choice(np.arange(len(sorted_transforms)),
                                                                          size=cfg.num_transforms,
                                                                          replace=False)]
    test_dataset = instantiate(cfg.test)
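
    # Train, then report clean test error and corruption error (mCE).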
    train_net(model=model,
              optimizer=optimizer,
              train_dataset=train_dataset,
              batch_size=cfg.train.batch_size,
              max_epoch=cfg.optim.max_epoch,
              loader_params=cfg.data_loader,
              lr_policy=lr_policy,
              save_period=cfg.train.checkpoint_period,
              weights=cfg.train.weights)
    err = test_net(model=model,
                   test_dataset=test_dataset,
                   batch_size=cfg.test.batch_size,
                   loader_params=cfg.data_loader,
                   output_name='test_epoch')
    test_corrupt_net(model=model,
                     corrupt_cfg=cfg.corrupt,
                     batch_size=cfg.corrupt.batch_size,
                     loader_params=cfg.data_loader,
                     aug_string=cfg.corrupt.aug_string,
                     clean_err=err,
                     mCE_denom=cfg.corrupt.mCE_baseline_file)
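

# A minimal launcher sketch (hypothetical): the repo presumably invokes
# train() through a hydra entry point; the config_path and config_name below
# are assumptions, not the repo's actual values.
@hydra.main(config_path="conf", config_name="closest_augs")
def main(cfg):
    train(cfg)


if __name__ == "__main__":
    main()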