def run()

in run.py [0:0]
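
Top-level training driver: builds the full hyperparameter set as argparse defaults, applies the caller's args_override dict (the real command line is ignored), sets up logging, data, model and optimizer, then runs the per-epoch train/evaluate loop with optional tail averaging, periodic model saving and full checkpoint/resume support.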


def run(args_override={}):
    run_dir = "runs"
    disable_cuda = False
    checkpoint_dir = "/checkpoint/{}/checkpoints".format(os.environ["USER"])
    default_momentum = 0.9
    default_lr = 0.1 
    default_decay = 0.0001 
    default_epochs = 300
    default_batch_size = 128 
    default_tail_average = 0.0
    default_tail_average_all = False

    default_half_precision = False
    default_method = "sgd" #"svrg" #"sgd"

    default_log_diagnostics = False
    default_log_diagnostics_every_epoch = False
    default_log_fast_diagnostics = False

    default_logfname = "log"
    default_log_interval = 20
    default_transform_locking = True

    default_per_block = False
    default_dropout = False
    default_batchnorm = True
    default_vr_from_epoch = 1 # 1 is first epoch.

    default_calculate_train_loss_each_epoch = False
    default_save_model = False # Saving every 10 epochs
    default_resume = False
    default_resume_from = ""

    # When enabled, a checkpoint is written every epoch and the run always resumes from
    # the latest checkpoint if one exists.
    default_full_checkpointing = False

    default_second_lambda = 0.5 
    default_inner_steps = 10
    default_clamping = 1000.0
    default_vr_bn_at_recalibration = True 
    default_variance_reg = 0.01
    default_lr_reduction = "150-225" 
    default_L = 1.0
    default_architecture = "default"
    default_problem = "cifar10"

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch optimization testbed')
    parser.add_argument('--problem', type=str, default=default_problem,
        help='Problem instance (default: ' + default_problem + ')')
    parser.add_argument('--method', type=str, default=default_method,
        help='Optimization method (default: ' + default_method + ')')
    parser.add_argument('--batch-size', type=int,
        default=default_batch_size, metavar='M',
        help='minibatch size (default: ' + str(default_batch_size) + ')')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=default_epochs, metavar='N',
        help='number of epochs to train (default: ' + str(default_epochs) + ')')
    parser.add_argument('--lr', type=float, default=default_lr, metavar='LR',
        help='learning rate (default: ' + str(default_lr) + ')')
    parser.add_argument('--momentum', type=float, default=default_momentum,
        metavar='M',
        help='SGD momentum (default: ' + str(default_momentum) + ')')
    parser.add_argument('--decay', type=float, default=default_decay,
        metavar='M',
        help='SGD weight decay (default: ' + str(default_decay) + ')')
    parser.add_argument('--L', type=float, default=default_L,
        metavar='L',
        help='SGD L estimate (default: ' + str(default_L) + ')')

    parser.add_argument('--tail_average', type=float, default=default_tail_average,
        help='Use tail averaging of iterates every epoch, with the given tail fraction (default: ' + str(default_tail_average) + ')')
    parser.add_argument('--tail_average_all', type=str2bool, default=default_tail_average_all,
        help='Apply tail averaging to the whole run rather than only after the first lr reduction (default: ' + str(default_tail_average_all) + ')')

    parser.add_argument('--clamping', type=float, default=default_clamping,
        metavar='C', help='APS clamping (default: ' + str(default_clamping) + ')')
    parser.add_argument('--inner_steps', type=int, default=default_inner_steps, metavar='N',
        help='Inner steps for implicit methods (default: ' + str(default_inner_steps) + ')')
    parser.add_argument('--vr_from_epoch', type=int, default=default_vr_from_epoch,
        help='Start VR (if in use) at this epoch (default: ' + str(default_vr_from_epoch) + ')')

    parser.add_argument('--no-cuda', action='store_true', default=disable_cuda,
        help='disables CUDA training')
    parser.add_argument('--half_precision', type=str2bool, default=default_half_precision,
        help='Use half precision (default: ' + str(default_half_precision) + ')')

    parser.add_argument('--second_lambda', type=float, default=default_second_lambda,
        metavar='D',
        help='A second linear interpolation factor used by some algorithms (default: '
        + str(default_second_lambda) + ')')
    parser.add_argument('--variance_reg', type=float, default=default_variance_reg,
        metavar='D',
        help='Added to the variance in reparam to prevent divide by 0 problems (default: '
        + str(default_variance_reg) + ')')

    parser.add_argument('--architecture', type=str, default=default_architecture,
        help='architecture (default: ' + default_architecture + ')')

    parser.add_argument('--seed', type=int, default=1, metavar='S',
        help='random seed (default: 1)')
    parser.add_argument('--dropout', type=str2bool, default=default_dropout,
        help='Use dropout (default: ' + str(default_dropout) + ')')
    parser.add_argument('--batchnorm', type=str2bool, default=default_batchnorm,
        help='Use batchnorm (default: ' + str(default_batchnorm) + ')')

    parser.add_argument('--transform_locking', type=str2bool, default=default_transform_locking,
        help='Transform locking (default: ' + str(default_transform_locking) + ')')

    parser.add_argument('--log_diagnostics', type=str2bool, default=default_log_diagnostics,
        help='produce and log expensive diagnostics (default: ' + str(default_log_diagnostics) + ')')

    parser.add_argument('--log_diagnostics_every_epoch', type=str2bool, default=default_log_diagnostics_every_epoch,
        help='do full diagnostics every epoch instead of every 10')

    parser.add_argument('--log_diagnostics_deciles', type=str2bool, default=False,
        help='full diagnostics at every 10%% of the epoch')  # '%%' so argparse's help formatter does not choke on a bare '%'

    parser.add_argument('--log_fast_diagnostics', type=str2bool, default=default_log_fast_diagnostics,
        help='produce and log cheap diagnostics (default: ' + str(default_log_fast_diagnostics) + ')')

    parser.add_argument('--logfname', type=str, default=default_logfname,
        help='Prefix for diagnostic log files (default: ' + str(default_logfname) + ')')

    parser.add_argument('--save_model', type=str2bool, default=default_save_model,
        help='Save model every 10 epochs (default: ' + str(default_save_model) + ')')
    parser.add_argument('--resume', type=str2bool, default=default_resume,
        help='Resume from resume_from (default: ' + str(default_resume) + ')')
    parser.add_argument('--resume_from', type=str, default=default_resume_from,
        help='Path to a saved model (default: ' + str(default_resume_from) + ')')
    parser.add_argument('--full_checkpointing', type=str2bool, default=default_full_checkpointing,
        help='Writeout and resume from checkpoints (default: ' + str(default_full_checkpointing) + ')')
    parser.add_argument('--calculate_train_loss_each_epoch', type=str2bool, default=default_calculate_train_loss_each_epoch,
        help='Do a second pass after each epoch to calculate the training error rate/loss (default: ' + str(default_calculate_train_loss_each_epoch) + ')')


    parser.add_argument('--vr_bn_at_recalibration', type=str2bool, default=default_vr_bn_at_recalibration,
        help='Use batch norm on the recalibration pass (default: ' + str(default_vr_bn_at_recalibration) + ')')

    parser.add_argument('--lr_reduction', type=str, default=default_lr_reduction,
        help='Use lr reduction specified (default: ' + str(default_lr_reduction) + ')')
    parser.add_argument('--log_interval', type=int, default=default_log_interval, metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--per_block', type=str2bool, default=default_per_block,
        help='Use per block learning rates (default: ' + str(default_per_block) + ')')
    args = parser.parse_args([])  # Ignore the real command line; settings come from the args_override dict only

    # Apply any overrides passed in by the caller.
    args.__dict__.update(args_override)

    if isinstance(args, dict):
        args = Struct(**args)

    # Variance-reduced methods need the minibatch index passed to step(); "scsg" is
    # handled separately via train_scsg in the epoch loop below.
    args.opt_vr = opt_vr = (args.method in ["saga", "svrg", "pointsaga", "recompute_svrg", "online_svrg"])

    run_name = (args.problem + "-" + args.architecture + "-" +
                args.method + "-lr" + str(args.lr) +
                "-m" + str(args.momentum) + "-" + "d" + str(args.decay) +
                "-epochs" + str(args.epochs) + "bs" +
                str(args.batch_size) +
                "reduct_" + args.lr_reduction)

    if not args.batchnorm:
        run_name += "_nobn"

    if args.dropout:
        run_name += "_dropout"

    if args.opt_vr and args.vr_from_epoch != 1:
        run_name += "_vr_from_" + str(args.vr_from_epoch)

    if not args.vr_bn_at_recalibration:
        run_name += "_bn_recal_" + str(args.vr_bn_at_recalibration)

    if args.resume:
        run_name += "_resume"

    if args.seed != 1:
        run_name += "seed_" + str(args.seed)

    if args.half_precision:
        run_name += "_half"

    if args.tail_average > 0:
        run_name += "_tavg_" + str(args.tail_average)
        if args.tail_average_all:
            run_name += "_tall"

    run_name = run_name.strip().replace('.', '_')

    # SETUP LOGGING
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s | %(message)s')
    ch.setFormatter(formatter)
    root.addHandler(ch)

    ############
    logging.info("Run " + run_name)
    logging.info("#########")
    logging.info(args)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    logging.info("Using CUDA: {} CUDA AVAIL: {} #DEVICES: {}".format(
        args.cuda, torch.cuda.is_available(), torch.cuda.device_count()))

    cudnn.benchmark = True

    logging.info("Loading data")
    train_loader, test_loader, model, train_dataset = problems.load(args)

    if hasattr(model, "sampler") and hasattr(model.sampler, "reorder"):
        logging.info("NOTE: Consistent batch sampling in use")

    if args.cuda:
        logging.info("model.cuda")
        model.cuda()
        logging.info("")

    if args.half_precision:
        logging.info("Using half precision")
        model = model.half()

    if args.resume:
        # Load
        model.load_state_dict(torch.load(args.resume_from, map_location=lambda storage, loc: storage))
        model.cuda()
        logging.info("Resuming from file: {}".format(args.resume_from))

    checkpoint_resume = False
    if args.full_checkpointing:
        # Look for and load checkpoint model
        checkpoint_model_path = checkpoint_dir + "/" + run_name + "_checkpoint.model"
        checkpoint_runinfo_path = checkpoint_dir + "/" + run_name + "_info.pkl"
        if os.path.exists(checkpoint_model_path):
            checkpoint_resume = True
            logging.info("Checkpoint found: {}".format(checkpoint_model_path))
            model.load_state_dict(torch.load(checkpoint_model_path, map_location=lambda storage, loc: storage))
            model.cuda()
            with open(checkpoint_runinfo_path, 'rb') as fcheckpoint:
                runinfo = pickle.load(fcheckpoint)

            if runinfo["epoch"] >= args.epochs:
                logging.error("runinfo['epoch']: {} >= args.epochs, checkpoint is past/at end of run".format(runinfo["epoch"]))
                return
        else:
            logging.info("No checkpoint exists, starting a fresh run")


    ############################
    # Log some summary statistics for the model
    logging.info("Model statistics:")
    nparams = 0
    group_idx = 0
    for param in model.parameters():
        group_size = 1
        for g in param.size():
            group_size *= g
        nparams += group_size
        group_idx += 1
    train_nbatches = len(train_loader)
    logging.info("total parameters: {:,}".format(nparams))
    logging.info("minibatch size: {}".format(args.batch_size))
    logging.info("Rough training dataset size: {:,} number of minibatches: {}".format(
        len(train_loader)*args.batch_size, train_nbatches))
    logging.info("Rough test dataset size: {:,} number of test minibatches: {}".format(
        len(test_loader)*args.batch_size, len(test_loader)))

    # Averaging fraction calculation
    ntail_batches = int(train_nbatches*args.tail_average)
    if ntail_batches == 0:
        ntail_batches = 1
    ntail_from = train_nbatches - ntail_batches
    logging.info("Tail averaging fraction {:.2f} will average {} batches, from batch: {}, tail_average_all: {}".format(
        args.tail_average, ntail_batches, ntail_from, args.tail_average_all
    ))

    logging.info("Creating optimizer")
    optimizer = optimizers.optimizer(model, args)
    criterion = nn.CrossEntropyLoss()

    def train(epoch):
        model.train()
        interval = timer()
        start = timer()
        start_time = time.time()

        time_cuda = 0.0
        time_variable = 0.0
        time_forward = 0.0
        time_backward = 0.0
        time_step = 0.0
        time_load = 0.0

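        # Tail averaging: maintain a running mean of the final `tail_average` fraction of
        # minibatch iterates in this epoch, stored per parameter in optimizer.state under
        # the 'tail_average' key.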
        if args.tail_average > 0.0:
            averaged_so_far = 0
            # create/zero tail_average storage
            for group in optimizer.param_groups:
                for p in group['params']:
                    param_state = optimizer.state[p]
                    if 'tail_average' not in param_state:
                        param_state['tail_average'] = p.data.clone().double().zero_()

        load_start_time = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            time_load += time.time() - load_start_time

            cuda_time = time.time()
            if args.cuda:
                data, target = data.cuda(), target.cuda(non_blocking=True)
            if args.half_precision:
                data = data.half()
            variable_time = time.time()
            time_cuda += variable_time - cuda_time
            data, target = Variable(data), Variable(target)
            time_variable += time.time() - variable_time

            def eval_closure():
                nonlocal time_forward
                nonlocal time_backward
                closure_time = time.time()
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                eval_time = time.time()
                time_forward += eval_time - closure_time
                loss.backward()
                time_backward += time.time() - eval_time
                return loss

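            # Dispatch on the optimizer's interface: step_preds-style optimizers take a
            # closure returning log-probabilities plus the targets, variance-reduced
            # methods also need the batch index, and everything else takes the plain closure.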
            step_start_time = time.time()
            if hasattr(optimizer, "step_preds"):
                def partial_closure():
                    optimizer.zero_grad()
                    output = model(data)
                    logprobs = log_softmax(output)
                    return logprobs
                loss = optimizer.step_preds(partial_closure, target)
            elif opt_vr:
                loss = optimizer.step(batch_idx, closure=eval_closure)
            else:
                loss = optimizer.step(closure=eval_closure)

            time_step += time.time() - step_start_time

            if args.log_diagnostics and epoch >= args.vr_from_epoch:
                if args.method == "svrg":
                    in_run_diagnostics(epoch, batch_idx, args, train_loader, optimizer, model, criterion)


            # Accumulate tail average
            if args.tail_average > 0.0:
                if batch_idx >= ntail_from:
                    averaged_so_far += 1
                    for group in optimizer.param_groups:
                        for p in group['params']:
                            param_state = optimizer.state[p]
                            tail = param_state['tail_average']
                            # Running mean calculation
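                            # tail <- tail + (p - tail)/k is the incremental form of the
                            # mean of the k iterates accumulated so far this epoch.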
                            tail.add_(1.0/averaged_so_far, p.data.double() - tail)

            if batch_idx % args.log_interval == 0:
                mid = timer()
                percent_done = 100. * batch_idx / len(train_loader)
                if percent_done > 0:
                    time_estimate = math.ceil((mid - start)*(100/percent_done))
                    time_estimate = str(datetime.timedelta(seconds=time_estimate))
                    inst_estimate =  math.ceil((mid - interval)*(len(train_loader)/args.log_interval))
                    inst_estimate = str(datetime.timedelta(seconds=inst_estimate))
                else:
                    time_estimate = "unknown"
                    inst_estimate = "unknown"
                logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, time: {} inst: {}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data.item(), time_estimate, inst_estimate))

                if False:  # flip to True for a detailed per-phase timing breakdown
                    since_start = time.time()
                    logging.info("load: {:.3f}, cuda: {:.3f}, variable: {:.3f}, forward: {:.3f}, backward: {:.3f}, step: {:.3f}, step-clo: {:.3f}, sum: {}, actual: {}".format(
                        time_load, time_cuda, time_variable, time_forward, time_backward, time_step, time_step - time_forward - time_backward,
                        time_load + time_cuda + time_variable + time_step, since_start - start_time
                    ))
                    time_cuda = 0.0
                    time_variable = 0.0
                    time_forward = 0.0
                    time_backward = 0.0
                    time_step = 0.0
                    time_load = 0.0
                    interval = timer()

            load_start_time = time.time()

        if args.tail_average > 0.0:
            if averaged_so_far != ntail_batches:
                raise Exception("Off by one: {}, {}".format(averaged_so_far, ntail_batches))
            current_lr = optimizer.param_groups[0]['lr']

            if args.tail_average_all or args.lr != current_lr:
                logging.info("Setting x to tail average ({}), current_lr: {}".format(
                    args.tail_average, current_lr))

                for group in optimizer.param_groups:
                    for p in group['params']:
                        param_state = optimizer.state[p]
                        tail = param_state['tail_average']
                        p.data.zero_().add_(tail.type_as(p.data))

        if hasattr(model, "sampler") and hasattr(model.sampler, "reorder"):
            model.sampler.reorder()
        if hasattr(train_dataset, "retransform"):
            logging.info("retransform")
            train_dataset.retransform()

    def loss_stats(epoch, loader, setname):
        model.eval()
        loss = 0.0
        correct = 0.0
        for data, target in loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            if args.half_precision:
                data = data.half()
            data, target = Variable(data, volatile=True), Variable(target)
            output = model(data)
            loss += criterion(output, target).data.item()
            pred = output.data.max(1)[1] # index of the max log-probability
            correct += pred.eq(target.data).cpu().sum().float().item()

        loss /= len(loader) # loss function already averages over batch size
        error_rate = 100.0 * correct / len(loader.dataset)  # despite the name, this is accuracy in percent
        logging.info('{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
            setname, loss, correct, len(loader.dataset),
            error_rate))

        return (loss, error_rate)

    # Create directory for saving model if needed
    problem_dir = run_dir + "/" + args.problem
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)
    if not os.path.exists(problem_dir):
        os.makedirs(problem_dir)

    if not checkpoint_resume:
        runinfo = vars(args)
        runinfo["train_losses"] = []
        runinfo["train_errors"] = []
        runinfo["test_losses"] = []
        runinfo["test_errors"] = []
        runinfo["nparams"] = nparams
        runinfo["ndatapoints"] = len(train_loader)*args.batch_size
        runinfo["nminibatches"] = len(train_loader)
        runinfo["epoch"] = 0

    else:
        # When resuming
        if hasattr(optimizer, "recalibrate"):
            logging.info("Recalibrate for restart, epoch: {}".format(runinfo["epoch"]))
            seed = runinfo["seed"] + 1031*runinfo["epoch"]
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            recalibrate(runinfo["epoch"], args, train_loader, test_loader, model, train_dataset, optimizer, criterion)


    for epoch in range(runinfo["epoch"]+1, args.epochs + 1):
        runinfo["epoch"] = epoch
        logging.info("Starting epoch {}".format(epoch))

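        # Re-seed deterministically as a function of the base seed and the epoch number, so
        # a run resumed from a checkpoint can reproduce the same per-epoch random state
        # (the resume path above uses the same formula).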
        seed = runinfo["seed"] + 1031*epoch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        if epoch == 1 and hasattr(optimizer, "recalibrate"):
            recalibrate(epoch, args, train_loader, test_loader, model, train_dataset, optimizer, criterion)

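        # Schedules are named by their behaviour: e.g. "150-225" multiplies the base lr by
        # 0.1 at epoch 150 and again at epoch 225, "every30" decays 10x every 30 epochs,
        # and the "up*" variants hold the base lr for an initial warm-up window before
        # switching to 3x the base lr with the usual drops.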
        if args.lr_reduction == "default":
            lr = args.lr * (0.1 ** (epoch // 75))
        elif args.lr_reduction == "none" or args.lr_reduction == "False":
            lr = args.lr
        elif args.lr_reduction == "150":
            lr = args.lr * (0.1 ** (epoch // 150))
        elif args.lr_reduction == "150-225":
            lr = args.lr * (0.1 ** (epoch // 150)) * (0.1 ** (epoch // 225))
        elif args.lr_reduction == "up5x-20-down150":
            if epoch < 20:
                lr = args.lr
            else:
                lr = 3.0 * args.lr * (0.1 ** (epoch // 150))
        elif args.lr_reduction == "up30-150-225":
            if epoch < 30:
                lr = args.lr
            else:
                lr = 3.0 * args.lr * (0.1 ** (epoch // 150)) * (0.1 ** (epoch // 225))
        elif args.lr_reduction == "every30":
            lr = args.lr * (0.1 ** (epoch // 30))
        else:
            raise ValueError("Lr scheme not recognised: {}".format(args.lr_reduction))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        logging.info("Learning rate: {}".format(lr))

        start = timer()
        if args.method == "scsg":
            train_scsg(epoch, args, train_loader, test_loader, model, train_dataset, optimizer, criterion)
        else:
            train(epoch)
        end = timer()

        logging.info("Epoch took: {}".format(end-start))
        logging.info("")

        if args.calculate_train_loss_each_epoch:
            (train_loss, train_err) = loss_stats(epoch, train_loader, "Train") #test(epoch)
        else:
            train_loss = 0
            train_err = 0

        runinfo["train_losses"].append(train_loss)
        runinfo["train_errors"].append(train_err)

        (test_loss, test_err) = loss_stats(epoch, test_loader, "Test") #test(epoch)
        runinfo["test_losses"].append(test_loss)
        runinfo["test_errors"].append(test_err)

        logging.info("")

        if args.log_fast_diagnostics and hasattr(optimizer, "store_old_table"):
            logging.info("Storing old table")
            optimizer.store_old_table()
        if hasattr(optimizer, "recalibrate"):
            recalibrate(epoch+1, args, train_loader, test_loader, model, train_dataset, optimizer, criterion)
        if False: # Only works for recompute_svrg I think
            batchnorm_diagnostics(epoch, args, train_loader, optimizer, model)

        if epoch >= args.vr_from_epoch and args.log_fast_diagnostics and hasattr(optimizer, "epoch_diagnostics"):
            optimizer.epoch_diagnostics(train_loss, train_err, test_loss, test_err)

        # Occasionally save out the model.
        if args.save_model and epoch % 10 == 0:
            logging.info("Saving model ...")
            model_dir = problem_dir + "/model_" + run_name
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            model_fname = "{}/epoch_{}.model".format(model_dir, epoch)
            torch.save(model.state_dict(), model_fname)
            logging.info("Saved model {}".format(model_fname))

            out_fname = problem_dir + "/" + run_name + '_partial.pkl'
            with open(out_fname, 'wb') as output:
                pickle.dump(runinfo, output)
            print("Saved partial: {}".format(out_fname))

        if args.full_checkpointing:
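            # Write to a temporary file and rename it into place; the rename is atomic on
            # POSIX filesystems, so an interrupted save never corrupts the checkpoint.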
            checkpoint_model_path_tmp = checkpoint_model_path + ".tmp"
            logging.info("Saving checkpoint model ...")
            torch.save(model.state_dict(), checkpoint_model_path_tmp)
            os.rename(checkpoint_model_path_tmp, checkpoint_model_path)
            logging.info("Saved {}".format(checkpoint_model_path))

            checkpoint_runinfo_path_tmp = checkpoint_runinfo_path + ".tmp"
            with open(checkpoint_runinfo_path_tmp, 'wb') as output:
                pickle.dump(runinfo, output)
            os.rename(checkpoint_runinfo_path_tmp, checkpoint_runinfo_path)
            print("Saved runinfo: {}".format(checkpoint_runinfo_path))

    if args.method == "reparm":
        optimizer.print_diagnostics()

    out_fname = problem_dir + "/" + run_name + '_final.pkl'
    with open(out_fname, 'wb') as output:
        pickle.dump(runinfo, output)
    print("Saved {}".format(out_fname))
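
A minimal usage sketch (not part of run.py itself; the override keys must match the argparse destinations defined above, and the values are purely illustrative):

    from run import run

    # Override a handful of defaults; anything not listed keeps the value set inside run().
    run(args_override={"problem": "cifar10", "method": "svrg", "lr": 0.05, "epochs": 150})

The str2bool converter used by the boolean flags is defined elsewhere in the repo; a typical implementation (an assumption, not necessarily the exact one used here) looks like:

    import argparse

    def str2bool(v):
        # Accept the usual spellings of true/false so string-valued overrides behave as expected.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "1"):
            return True
        if v.lower() in ("no", "false", "f", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected, got: {}".format(v))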