in experiments/feature_corrupt_error.py [0:0]
def train(cfg):
np.random.seed(cfg.rng_seed)
torch.manual_seed(cfg.rng_seed)
log.info(cfg.pretty())
model = instantiate(cfg.model).cuda()
optimizer = instantiate(cfg.optim, model.parameters())
train_dataset = instantiate(cfg.train)
test_dataset = instantiate(cfg.test)
lr_policy = instantiate(cfg.optim.lr_policy)
feature_extractor = instantiate(cfg.ft)
feature_extractor.train()
if cfg.aug_feature_file and os.path.exists(cfg.aug_feature_file):
log.info("Found feature file. Loading from {}".format(cfg.aug_feature_file))
data = np.load(cfg.aug_feature_file)
augmentation_features = data['features']
indices = data['indices']
else:
ft_augmentation_dataset = instantiate(cfg.ft_augmentation)
indices = np.random.choice(np.arange(len(ft_augmentation_dataset)), size=cfg.num_images, replace=False)
ft_augmentation_dataset = ft_augmentation_dataset.serialize(indices)
augmentation_features = extract_features(feature_extractor,
ft_augmentation_dataset,
cfg.ft_augmentation.batch_size,
cfg.data_loader,
average=True,
average_num=len(indices))
#nf, lf = augmentation_features.shape
#augmentation_features = np.mean(augmentation_features.reshape(len(indices), nf//len(indices), lf), axis=0)
if cfg.aug_feature_file:
np.savez(cfg.aug_feature_file, features=augmentation_features, indices=indices)
aug_strings = cfg.ft_corrupt.aug_string.split("--")
for aug in aug_strings:
with omegaconf.open_dict(cfg):
ft_corrupt_dataset = instantiate(cfg.ft_corrupt, aug_string=aug)
ft_corrupt_dataset = ft_corrupt_dataset.serialize(indices)
corruption_features = extract_features(feature_extractor,
ft_corrupt_dataset,
cfg.ft_corrupt.batch_size,
cfg.data_loader,
average=True,
average_num=len(indices))
nf, lf = corruption_features.shape
#corruption_features = np.mean(corruption_features.reshape(len(indices), nf//len(indices), lf), axis=0)
augmentation_features = augmentation_features.reshape(-1, 1, lf)
corruption_features = corruption_features.reshape(1, -1, lf)
mean_aug = np.mean(augmentation_features.reshape(-1,lf), axis=0)
mean_corr = np.mean(corruption_features.reshape(-1,lf), axis=0)
mmd = np.linalg.norm(mean_aug-mean_corr, axis=0)
msd = np.min(np.linalg.norm(augmentation_features.reshape(-1,lf)-mean_corr.reshape(1,lf),axis=1),axis=0)
stats = {"_type" : aug,
"mmd" : str(mmd),
"msd" : str(msd),
}
lu.log_json_stats(stats)
train_net(model=model,
optimizer=optimizer,
train_dataset=train_dataset,
batch_size=cfg.train.batch_size,
max_epoch=cfg.optim.max_epoch,
loader_params=cfg.data_loader,
lr_policy=lr_policy,
save_period=cfg.train.checkpoint_period,
weights=cfg.train.weights
)
err = test_net(model=model,
test_dataset=test_dataset,
batch_size=cfg.test.batch_size,
loader_params=cfg.data_loader,
output_name='test_epoch')
test_corrupt_net(model=model,
corrupt_cfg=cfg.corrupt,
batch_size=cfg.corrupt.batch_size,
loader_params=cfg.data_loader,
aug_string=cfg.corrupt.aug_string,
clean_err=err,
mCE_denom=cfg.corrupt.mCE_baseline_file)