in sample_info/scripts/synthetic_example_make_informativeness_video.py [0:0]
def main():
    """Render a video of per-example weight stability over training time.

    Builds the synthetic dataset and uses its first half as the training set.
    It then computes the model's Jacobians and NTK at initialization on that
    set. For every step ``t`` in ``ts`` it calls ``weight_stability`` and plots
    the resulting scores to ``<output_dir>/weight-XXXX.png``. Finally, ffmpeg
    stitches the frames into ``<output_dir>/movie.webm``.
    """
    import subprocess  # local import: only needed for the final ffmpeg call

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str,
                        default='sample_info/configs/1hidden-mlp-n1024-binary-mnist.json')
    parser.add_argument('--device', '-d', default='cuda', help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output_dir', type=str, default='sample_info/plots/synthetic-data',
                        help='directory where the per-step frames and movie.webm are written')
    # hyper-parameters
    parser.add_argument('--model_class', '-m', type=str, default='ClassifierL2')
    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate')
    args = parser.parse_args()
    print(args)

    # Build data: the first half of the synthetic samples forms the training set.
    # Labels are reshaped to a column of shape (N, 1).
    data_X, data_Y = get_synthetic_data(args.seed)
    half = len(data_X) // 2
    train_data = TensorDataset(torch.tensor(data_X[:half]).float(),
                               torch.tensor(data_Y[:half]).long().reshape((-1, 1)))

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)
    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        device=args.device)

    # Quantities at initialization, consumed by weight_stability below.
    jacobian_estimator = JacobianEstimator(projection='none')
    jacobians = jacobian_estimator.compute_jacobian(model=model, dataset=train_data,
                                                    output_key='pred', cpu=False)
    init_preds = utils.apply_on_dataset(model=model, dataset=train_data, cpu=False)['pred']
    init_params = dict(model.named_parameters())
    ntk = compute_ntk(jacobians=jacobians)

    Y = [torch.tensor([y]) for (x, y) in train_data]
    Y = torch.stack(Y).float().to(ntk.device)

    n_train = len(train_data)
    eta = args.lr / n_train  # per-example learning rate, as in the original call
    utils.make_path(args.output_dir)  # create the frame directory once, not per frame

    ts = range(0, 1001, 20)
    for idx, t in tqdm(enumerate(ts), desc='main loop', total=len(ts)):
        _, q = weight_stability(t=t,
                                n=n_train,
                                eta=eta,
                                init_params=init_params,
                                jacobians=jacobians,
                                ntk=ntk,
                                init_preds=init_preds,
                                Y=Y,
                                continuous=False,
                                return_change_vectors=False,
                                scale_by_hessian=False)
        fig, _ = plot(q, data_X=data_X, data_Y=data_Y, half=half, t=t)
        # The frame index must match the `weight-%04d.png` pattern given to ffmpeg below.
        file_path = os.path.join(args.output_dir, f'weight-{idx:04d}.png')
        fig.savefig(file_path)
        plt.close(fig)  # close this specific figure, not just the "current" one

    # Save the video.
    # `-y` overwrites an existing movie.webm. Without it, ffmpeg blocks on an
    # interactive prompt when the script is re-run. `cwd=` avoids changing the
    # process-wide working directory. As with the original os.system call, a
    # failing ffmpeg does not abort the script; it is only reported.
    try:
        result = subprocess.run(
            ['ffmpeg', '-y', '-r', '2', '-i', 'weight-%04d.png', 'movie.webm'],
            cwd=args.output_dir)
    except OSError as e:
        print(f'could not run ffmpeg: {e}')
    else:
        if result.returncode != 0:
            print(f'ffmpeg exited with code {result.returncode}')