def main()

in sample_info/archive/total_gradient.py [0:0]
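
This archived script trains a binary 4-vs-9 MNIST classifier from scratch for
several epoch horizons and records, for each training example, the norm of its
total gradient update.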

import argparse
import json

import torch

# NOTE: the project-level imports below are assumptions inferred from the
# names used in main(); the exact module paths in the repository may differ.
# import methods
# import metrics
# import training
# from <data utilities module> import (load_data_from_arguments, BinaryDatasetWrapper,
#                                      ReturnSampleIndexWrapper, get_loaders_from_datasets)
# from <results module> import process_results

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device', '-d', default='cuda', help='specifies the main device')
    parser.add_argument('--all_device_ids', nargs='+', type=str, default=None,
                        help="If given, this list specifies the devices to use for "
                             "multi-GPU training. The first device must match the "
                             "main device (args.device).")
    parser.add_argument('--batch_size', '-b', type=int, default=256)
    parser.add_argument('--epochs', '-e', type=int, default=400)
    parser.add_argument('--stopping_param', type=int, default=2**30)
    parser.add_argument('--save_iter', '-s', type=int, default=10)
    parser.add_argument('--vis_iter', '-v', type=int, default=10)
    parser.add_argument('--log_dir', '-l', type=str, default=None)
    parser.add_argument('--seed', type=int, default=42)

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='corrupt4_mnist')
    parser.add_argument('--data_augmentation', '-A', action='store_true',
                        help='enable data augmentation')
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation', action='store_true', default=False)

    # hyper-parameters
    parser.add_argument('--model_class', '-m', type=str, default='ClassifierL2WithGradCollector')

    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'])

    parser.add_argument('--output_dir', '-o', type=str, default='results/stability/mnist-4vs9-1000-samples/')
    args = parser.parse_args()
    print(args)

    # Load data
    # TODO: remove hard coding; the --dataset and --num_train_examples arguments
    #  are currently ignored in favor of the fixed values below.
    train_data, val_data, test_data, _ = load_data_from_arguments({'dataset': 'mnist',
                                                                   'num_train_examples': 10 * 500},
                                                                  build_loaders=False)
    # Reduce each split to a binary 4-vs-9 classification task.
    train_data = BinaryDatasetWrapper(train_data, which_labels=(4, 9))
    val_data = BinaryDatasetWrapper(val_data, which_labels=(4, 9))
    test_data = BinaryDatasetWrapper(test_data, which_labels=(4, 9))

    # Make each sample also return its index, so that gradient updates can be
    # attributed to individual examples.
    train_data = ReturnSampleIndexWrapper(train_data)
    val_data = ReturnSampleIndexWrapper(val_data)
    test_data = ReturnSampleIndexWrapper(test_data)

    # batch_size=2**30 exceeds the dataset sizes, so each loader yields a
    # single full batch.
    train_loader, val_loader, test_loader = get_loaders_from_datasets(train_data, val_data, test_data,
                                                                      batch_size=2 ** 30,
                                                                      shuffle_train=False, num_workers=0)

    # Optimization options
    optimization_args = {
        'optimizer': {
            'name': args.optimizer,
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }
    }
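
    # For reference, a minimal sketch of how such a dict typically maps onto a
    # torch optimizer (an assumption about training.train's internals; the
    # actual construction may differ):
    #   opt_cls = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}[args.optimizer]
    #   optimizer = opt_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)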

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    # Training horizons (in epochs) at which to measure the total gradient
    # updates: 100, 200, 300, 400. The --epochs argument is not used here.
    ts = range(100, 401, 100)

    for t in ts:
        # Train a fresh model from scratch for each horizon t.
        model_class = getattr(methods, args.model_class)

        # ReturnSampleIndexWrapper makes dataset[0] return ((x, y), idx),
        # so [0][0][0] is the input tensor of the first sample.
        model = model_class(input_shape=train_loader.dataset[0][0][0].shape,
                            architecture_args=architecture_args,
                            device=args.device,
                            seed=args.seed)

        metrics_list = [metrics.Accuracy(output_key='pred')]

        training.train(model=model,
                       train_loader=train_loader,
                       val_loader=val_loader,
                       epochs=t,
                       save_iter=args.save_iter,
                       vis_iter=2**30,  # NOTE: never visualize
                       optimization_args=optimization_args,
                       log_dir=args.log_dir,
                       args_to_log=args,
                       metrics=metrics_list,
                       device_ids=args.all_device_ids)

        # Per-example total gradient updates collected during training by the
        # grad-collector model class; vectors[i] is a dict of per-parameter
        # update tensors for example i.
        vectors = model._grad_updates

        norms = []
        for i in range(len(train_data)):
            grad_dict = vectors[i]
            # L2 norm of the full (concatenated) gradient update: accumulate
            # squared per-parameter norms, then take the square root.
            squared_norm = 0.0
            for v in grad_dict.values():
                squared_norm += torch.norm(v.flatten()).item() ** 2
            norms.append(squared_norm ** 0.5)
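        # Equivalent one-liner (a sketch; assumes every value in the dict is a
        # torch tensor):
        #   torch.cat([v.flatten() for v in vectors[i].values()]).norm().item()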

        quantities = norms

        meta = {
            'description': 'Total gradient update per example. The measure is the '
                           'L2 norm of the total gradient update.',
            'time': t,
            'continuous': False,
            'args': args
        }

        process_results(vectors=vectors, quantities=quantities, meta=meta,
                        exp_name=f'total-grad-t{t}', output_dir=args.output_dir, train_data=train_data.dataset)