def main(args)

in src/train.py


def main(args):
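    """Train the model described by args, evaluating on the validation split
    each epoch and checkpointing the best-performing weights."""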
    global HALT_filename, CHECKPOINT_tempfile

    # Create model directory & other aux folders for logging
    where_to_save = os.path.join(args.save_dir, args.dataset, args.model_name, args.image_model, args.experiment_name)
    checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
    suffix = '_'.join([args.dataset, args.model_name, str(args.seed)])
    checkpoint_filename = os.path.join(checkpoints_dir, '_'.join([suffix, 'checkpoint']))
    print(checkpoint_filename)
    logs_dir = os.path.join(where_to_save, 'logs')
    tb_logs = os.path.join(where_to_save, 'tb_logs', args.dataset,
                           args.model_name + '_' + str(args.seed))
    make_dir(where_to_save)
    make_dir(logs_dir)
    make_dir(checkpoints_dir)
    make_dir(tb_logs)

    # Create loggers
    # stdout logger
    stdout_logger = logging.getLogger('STDOUT')
    stdout_logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(threadName)s - %(levelname)s: %(message)s')
    fh_out = logging.FileHandler(os.path.join(logs_dir, 'train_{}.log'.format(suffix)))
    fh_out.setFormatter(formatter)
    stdout_logger.addHandler(fh_out)
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setFormatter(formatter)
    stdout_logger.addHandler(ch)
    # stderr logger
    stderr_logger = logging.getLogger('STDERR')
    fh_err = logging.FileHandler(os.path.join(logs_dir, 'train_{}.err'.format(suffix)), mode='w')
    fh_err.setFormatter(formatter)
    stderr_logger.addHandler(fh_err)
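    # redirect sys.stderr through the logger so uncaught exceptions land in the .err file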
    sl_stderr = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl_stderr

    # The HALT file marks job completion.
    # Remove any HALT file left over from a previous run.
    HALT_filename = os.path.join(where_to_save, 'HALT_{}'.format(suffix))
    if os.path.isfile(HALT_filename):
        os.remove(HALT_filename)

    # Remove any stale checkpoint temp file from a previous run
    CHECKPOINT_tempfile = checkpoint_filename + '.tmp.ckpt'
    if os.path.isfile(CHECKPOINT_tempfile):
        os.remove(CHECKPOINT_tempfile)

    # Create tensorboard visualizer
    if args.tensorboard:
        logger = Visualizer(tb_logs, name='visual_results', resume=args.resume)

    # Optionally resume from the last checkpoint of the current model
    checkpoint = None
    if args.resume:
        if os.path.isfile(checkpoint_filename + '.ckpt'):
            checkpoint = torch.load(checkpoint_filename + '.ckpt', map_location=map_loc)
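            # restore all args from the checkpoint, but keep the freshly requested num_epochs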
            num_epochs = args.num_epochs
            args = checkpoint['args']
            args.num_epochs = num_epochs

    # Build data loaders
    data_loaders = {}
    datasets = {}
    for split in ['train', 'val']:

        transforms_list = [transforms.Resize(args.image_size)]

        # Image pre-processing
        if split == 'train':
            transforms_list.append(transforms.RandomHorizontalFlip())
            transforms_list.append(transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)))
            transforms_list.append(transforms.RandomCrop(args.crop_size))

        else:
            transforms_list.append(transforms.CenterCrop(args.crop_size))
        transforms_list.append(transforms.ToTensor())
        transforms_list.append(transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))
        transform = transforms.Compose(transforms_list)

        # Load dataset path
        with open('../configs/datapaths.json') as f:
            datapaths = json.load(f)
        dataset_root = datapaths[args.dataset]
        data_loaders[split], datasets[split] = get_loader(
            dataset=args.dataset,
            dataset_root=dataset_root,
            split=split,
            transform=transform,
            batch_size=args.batch_size,
            include_eos=(args.decoder != 'ff'),
            shuffle=(split == 'train'),
            num_workers=args.num_workers,
            drop_last=(split == 'train'),
            shuffle_labels=args.shuffle_labels,
            seed=args.seed,
            checkpoint=checkpoint)
        stdout_logger.info('Dataset {} split contains {} images'.format(
            split, len(datasets[split])))

    vocab_size = len(datasets['val'].get_vocab())
    stdout_logger.info('Vocabulary size is {}'.format(vocab_size))

    # Build the model
    model = get_model(args, vocab_size)

    # collect trainable parameters: decoder (plus encoder head, if any), and the pretrained CNN backbone separately
    if model.image_encoder.last_module is not None:
        params = list(model.decoder.parameters()) + list(
            model.image_encoder.last_module.parameters())
    else:
        params = list(model.decoder.parameters())
    params_cnn = list(model.image_encoder.pretrained_net.parameters())

    n_p_cnn = sum(p.numel() for p in params_cnn if p.requires_grad)
    n_p = sum(p.numel() for p in params if p.requires_grad)
    total = n_p + n_p_cnn
    stdout_logger.info("CNN params: {}".format(n_p_cnn))
    stdout_logger.info("decoder params: {}".format(n_p))
    stdout_logger.info("total params: {}".format(total))

    # single optimizer: one parameter group for the decoder/head, optionally one for the CNN backbone
    if params_cnn is not None and args.finetune_after == 0:
        optimizer = torch.optim.Adam(
            [{
                'params': params
            }, {
                'params': params_cnn,
                'lr': args.learning_rate * args.scale_learning_rate_cnn
            }],
            lr=args.learning_rate,
            weight_decay=args.weight_decay)
        keep_cnn_gradients = True
        stdout_logger.info("Fine tuning image encoder")
    else:
        optimizer = torch.optim.Adam(params, lr=args.learning_rate)
        keep_cnn_gradients = False
        stdout_logger.info("Freezing image encoder")

    # early stopping and checkpoint
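    # es_best holds the best validation value seen per metric; epoch_best the epoch it occurred in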
    es_best = {'o_f1': 0, 'c_f1': 0, 'i_f1': 0, 'average': 0}
    epoch_best = {'o_f1': -1, 'c_f1': -1, 'i_f1': -1, 'average': -1}
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        model.load_state_dict(checkpoint['state_dict'])
        es_best = checkpoint['es_best']
        epoch_best = checkpoint['epoch_best']

    if device != 'cpu' and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.to(device)
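    # let cuDNN auto-tune convolution algorithms for the fixed crop size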
    cudnn.benchmark = True

    if not hasattr(args, 'current_epoch'):
        args.current_epoch = 0

    # Train the model
    decay_factor = 1.0
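    # resume the mid-epoch step counter and early-stopping patience from the checkpoint, if any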
    start_step = 0 if checkpoint is None else checkpoint['current_step']
    curr_pat = 0 if checkpoint is None else checkpoint['current_pat']
    for epoch in range(args.current_epoch, args.num_epochs):

        # reset the tensorboard logger at the start of each epoch
        if args.tensorboard:
            logger.reset()

        # learning rate decay schedule
        if args.decay_lr:
            frac = epoch // args.lr_decay_every
            decay_factor = args.lr_decay_rate**frac
            new_lr = args.learning_rate * decay_factor
            stdout_logger.info('Epoch %d. lr: %.5f' % (epoch, new_lr))
            set_lr(optimizer, decay_factor)

        if args.finetune_after != -1 and args.finetune_after < epoch \
                and not keep_cnn_gradients and params_cnn is not None:

            stdout_logger.info("Starting to fine tune CNN")
            # start with learning rates as they were (if decayed during training)
            optimizer = torch.optim.Adam(
                [{
                    'params': params
                }, {
                    'params': params_cnn,
                    'lr': decay_factor * args.learning_rate * args.scale_learning_rate_cnn
                }],
                lr=decay_factor * args.learning_rate)
            keep_cnn_gradients = True

        for split in ['train', 'val']:

            if split == 'train':
                model.train()
            else:
                model.eval()
            total_step = len(data_loaders[split])
            loader = iter(data_loaders[split])

            total_loss_dict = {
                'label_loss': [],
                'eos_loss': [],
                'cardinality_loss': [],
                'loss': [],
                'o_f1': [],
                'c_f1': [],
                'i_f1': [],
            }

            torch.cuda.synchronize()
            start = time.time()

            overall_error_counts = {
                'tp_c': 0,
                'fp_c': 0,
                'fn_c': 0,
                'tn_c': 0,
                'tp_all': 0,
                'fp_all': 0,
                'fn_all': 0
            }

            # resume mid-epoch from the checkpointed step when training; validation always starts at 0
            i = 0 if split == 'val' else start_step
            for info in loader:
                img_inputs, gt = info

                # pad each ground-truth label list with the pad value so all have length maxnumlabels
                gt = [
                    sublist + [vocab_size - 1] * (args.maxnumlabels - len(sublist))
                    for sublist in gt
                ]
                gt = torch.LongTensor(gt)

                # move to device
                img_inputs = img_inputs.to(device)
                gt = gt.to(device)

                loss_dict = {}

                if split == 'val':
                    with torch.no_grad():
                        # get label predictions (losses are not computed at validation time)
                        _, predictions = model(
                            img_inputs,
                            maxnumlabels=args.maxnumlabels,
                            compute_losses=False,
                            compute_predictions=True)

                        # convert model predictions and targets to k-hots
                        pred_k_hots = label2_k_hots(
                            predictions, vocab_size - 1, remove_eos=(args.decoder != 'ff'))
                        target_k_hots = label2_k_hots(
                            gt, vocab_size - 1, remove_eos=(args.decoder != 'ff'))

                        # update overall and per class error types
                        update_error_counts(overall_error_counts, pred_k_hots, target_k_hots)

                        # update per image error types
                        i_f1s = []
                        # use a separate index so the outer step counter `i` is not clobbered
                        for idx in range(pred_k_hots.size(0)):
                            image_error_counts = {
                                'tp_c': 0,
                                'fp_c': 0,
                                'fn_c': 0,
                                'tn_c': 0,
                                'tp_all': 0,
                                'fp_all': 0,
                                'fn_all': 0
                            }
                            update_error_counts(image_error_counts, pred_k_hots[idx].unsqueeze(0),
                                                target_k_hots[idx].unsqueeze(0))
                            image_metrics = compute_metrics(
                                image_error_counts, which_metrics=['f1'])
                            i_f1s.append(image_metrics['f1'])

                        loss_dict['i_f1'] = np.mean(i_f1s)
                        del predictions, pred_k_hots, target_k_hots, image_metrics

                else:
                    losses, _ = model(
                        img_inputs,
                        gt,
                        maxnumlabels=args.maxnumlabels,
                        keep_cnn_gradients=keep_cnn_gradients,
                        compute_losses=True)

                    # label loss
                    label_loss = losses['label_loss']
                    label_loss = label_loss.mean()
                    loss_dict['label_loss'] = label_loss.item()

                    # cardinality loss
                    if args.pred_cardinality != 'none':
                        cardinality_loss = losses['cardinality_loss']
                        cardinality_loss = cardinality_loss.mean()
                        loss_dict['cardinality_loss'] = cardinality_loss.item()
                    else:
                        cardinality_loss = 0

                    # eos loss
                    if args.perminv:
                        eos_loss = losses['eos_loss']
                        eos_loss = eos_loss.mean()
                        loss_dict['eos_loss'] = eos_loss.item()
                    else:
                        eos_loss = 0

                    # total loss
                    loss = args.loss_weight[0] * label_loss \
                           + args.loss_weight[1] * cardinality_loss \
                           + args.loss_weight[2] * eos_loss
                    loss_dict['loss'] = loss.item()

                    # optimizer step
                    model.zero_grad()
                    loss.backward()
                    optimizer.step()

                    del loss, losses
                del img_inputs

                for key in loss_dict.keys():
                    total_loss_dict[key].append(loss_dict[key])

                # Print log info
                if args.log_step != -1 and i % args.log_step == 0:
                    elapsed_time = time.time() - start
                    lossesstr = ""
                    for k in total_loss_dict.keys():
                        if len(total_loss_dict[k]) == 0:
                            continue
                        this_one = "%s: %.4f" % (k, np.mean(total_loss_dict[k][-args.log_step:]))
                        lossesstr += this_one + ', '
                    # print running averages over the last log_step iterations; the same
                    # values are also written to the tensorboard logs
                    strtoprint = 'Split: %s, Epoch [%d/%d], Step [%d/%d], Losses: %sTime: %.4f' % (
                        split, epoch, args.num_epochs, i, total_step, lossesstr, elapsed_time)
                    stdout_logger.info(strtoprint)
                    if args.tensorboard and split == 'train':
                        logger.scalar_summary(
                            mode=split + '_iter',
                            epoch=total_step * epoch + i,
                            **{
                                k: np.mean(v[-args.log_step:])
                                for k, v in total_loss_dict.items()
                                if v
                            })

                    torch.cuda.synchronize()
                    start = time.time()

                i += 1

            if split == 'train':
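                # advance the training loader's epoch counter and clear the mid-epoch resume offset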
                increase_loader_epoch(data_loaders['train'])
                start_step = 0

            if split == 'val':
                overall_metrics = compute_metrics(overall_error_counts, ['f1', 'c_f1'], weights=None)

                total_loss_dict['o_f1'] = overall_metrics['f1']
                total_loss_dict['c_f1'] = overall_metrics['c_f1']

                if args.tensorboard:
                    # 1. Log scalar values (scalar summary)
                    logger.scalar_summary(
                        mode=split,
                        epoch=epoch,
                        **{k: np.mean(v)
                           for k, v in total_loss_dict.items()
                           if v})

        # early stopping: checkpoint whenever the monitored validation metric improves
        metric_average = 0
        best_at_checkpoint_metric = False
        if args.metric_to_checkpoint != 'average':
            es_value = np.mean(total_loss_dict[args.metric_to_checkpoint])
            if es_value > es_best[args.metric_to_checkpoint]:
                es_best[args.metric_to_checkpoint] = es_value
                epoch_best[args.metric_to_checkpoint] = epoch
                best_at_checkpoint_metric = True
                save_checkpoint(model, optimizer, args, es_best, epoch_best, 0, 0,
                                '{}.best.{}'.format(checkpoint_filename, args.metric_to_checkpoint))
        else:
            for metric in ['o_f1', 'c_f1', 'i_f1']:
                es_value = np.mean(total_loss_dict[metric])
                metric_average += es_value
            metric_average /= 3
            if metric_average > es_best['average']:
                es_best['average'] = metric_average
                epoch_best['average'] = epoch
                if 'average' == args.metric_to_checkpoint:
                    best_at_checkpoint_metric = True
                    save_checkpoint(model, optimizer, args, es_best, epoch_best, 0, 0,
                                    '{}.best.average'.format(checkpoint_filename))

        if best_at_checkpoint_metric:
            curr_pat = 0
        else:
            curr_pat += 1

        args.current_epoch = epoch + 1  # epoch to start from if training is resumed
        save_checkpoint(model, optimizer, args, es_best, epoch_best, 0, curr_pat,
                        checkpoint_filename)
        stdout_logger.info('Saved checkpoint for epoch {}.'.format(epoch))

        if curr_pat > args.patience:
            break

    # Mark job as finished
    with open(HALT_filename, 'w') as f:
        for metric in es_best.keys():
            f.write('{}:{}\n'.format(metric, es_best[metric]))

    if args.tensorboard:
        logger.close()