def main()

in sample_info/scripts/synthetic_example_make_informativeness_video.py [0:0]


def main():
    """Render a video of per-sample ``weight_stability`` scores on synthetic data.

    Steps:
      1. Parse command-line options (config path, device, seed, model class, lr).
      2. Build a synthetic dataset with ``get_synthetic_data(seed)``; the first
         half of the points is used as the training set.
      3. Instantiate ``methods.<model_class>`` from the JSON architecture config.
      4. At initialization, compute the training-set Jacobians
         (``JacobianEstimator(projection='none')``), the initial predictions,
         the initial parameters and the NTK.
      5. For t = 0, 20, ..., 1000, call ``weight_stability`` (discrete time) and
         plot the returned ``q``. Each plot is saved as a numbered PNG frame.
         ``q`` is presumably a per-sample informativeness score (per the
         script name); confirm against ``weight_stability``.
      6. Stitch the frames into ``movie.webm`` with ffmpeg.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable is not on PATH.
    """
    # Local import: the file's top-level import block is outside this view.
    import subprocess

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str,
                        default='sample_info/configs/1hidden-mlp-n1024-binary-mnist.json')
    parser.add_argument('--device', '-d', default='cuda', help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)

    # hyper-parameters
    parser.add_argument('--model_class', '-m', type=str, default='ClassifierL2')

    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate')
    args = parser.parse_args()
    print(args)

    # All frames and the final video are written here. Create the directory
    # once, up front, instead of on every loop iteration.
    frames_dir = 'sample_info/plots/synthetic-data'
    utils.make_path(frames_dir)

    # Build data: only the first half of the synthetic points is used for
    # training. `half` is also passed to `plot` below.
    data_X, data_Y = get_synthetic_data(args.seed)
    half = len(data_X) // 2
    train_data = TensorDataset(torch.tensor(data_X[:half]).float(),
                               torch.tensor(data_Y[:half]).long().reshape((-1, 1)))

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)
    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        device=args.device)

    # Quantities at initialization, computed once and reused for every t.
    jacobian_estimator = JacobianEstimator(projection='none')
    jacobians = jacobian_estimator.compute_jacobian(model=model, dataset=train_data,
                                                    output_key='pred', cpu=False)
    init_preds = utils.apply_on_dataset(model=model, dataset=train_data, cpu=False)['pred']
    init_params = dict(model.named_parameters())
    ntk = compute_ntk(jacobians=jacobians)

    # Stack the training labels into a float tensor on the NTK's device.
    Y = [torch.tensor([y]) for (x, y) in train_data]
    Y = torch.stack(Y).float().to(ntk.device)

    n = len(train_data)
    # The step size passed on is lr / n. Presumably weight_stability assumes a
    # summed (not averaged) loss; verify against its definition.
    eta = args.lr / n

    ts = range(0, 1001, 20)
    for idx, t in tqdm(enumerate(ts), desc='main loop', total=len(ts)):
        _, q = weight_stability(t=t,
                                n=n,
                                eta=eta,
                                init_params=init_params,
                                jacobians=jacobians,
                                ntk=ntk,
                                init_preds=init_preds,
                                Y=Y,
                                continuous=False,
                                return_change_vectors=False,
                                scale_by_hessian=False)

        fig, _ = plot(q, data_X=data_X, data_Y=data_Y, half=half, t=t)
        fig.savefig(os.path.join(frames_dir, f'weight-{idx:04d}.png'))
        # Close exactly the figure we saved. A bare plt.close() only closes
        # the "current" figure, which need not be `fig`, and leaks memory.
        plt.close(fig)

    # Stitch frames into a video at 2 frames per second.
    # - cwd= runs ffmpeg inside frames_dir without changing this process's
    #   working directory.
    # - '-y' overwrites an existing movie.webm instead of blocking on ffmpeg's
    #   interactive prompt when the script is rerun.
    # - check=True reports a failed encode instead of ignoring the exit status.
    subprocess.run(['ffmpeg', '-y', '-r', '2', '-i', 'weight-%04d.png', 'movie.webm'],
                   cwd=frames_dir, check=True)