# train() — worker training loop, from training/train_nav.py

def train(rank, args, shared_model):
    """A3C/Hogwild-style worker process that trains a navigation model.

    Each worker (identified by `rank`) repeatedly:
      1. copies the latest weights from `shared_model`,
      2. computes a loss on one batch locally (on this worker's GPU),
      3. moves the local model to CPU and copies its gradients into
         `shared_model` via `ensure_shared_grads`,
      4. steps an optimizer that is bound to `shared_model`'s parameters.

    NOTE(review): this file uses the pre-0.4 PyTorch API (`Variable`,
    `loss.data[0]`); it will not run unmodified on modern PyTorch.

    Args:
        rank: worker index; selects the GPU (round-robin over `args.gpus`)
            and names this worker's log/checkpoint files.
        args: parsed command-line namespace (model type, paths, hyperparams).
        shared_model: model living in shared memory whose parameters are
            optimized; the local `model` is only a scratch copy.
    """
    # NOTE(review): `args.gpus.index(args.gpus[rank % len(args.gpus)])` is just
    # `rank % len(args.gpus)` unless `args.gpus` contains duplicates (in which
    # case `.index` returns the first occurrence) — presumably unintended
    # complexity; confirm before simplifying.
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    # Instantiate a local (non-shared) model matching the requested type.
    # The '+q' suffix means the model also consumes the question as input.
    if args.model_type == 'cnn':

        model_kwargs = {}
        model = NavCnnModel(**model_kwargs)

    elif args.model_type == 'cnn+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnModel(**model_kwargs)

    elif args.model_type == 'lstm':

        model_kwargs = {}
        model = NavCnnRnnModel(**model_kwargs)

    elif args.model_type == 'lstm-mult+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnRnnMultModel(**model_kwargs)

    elif args.model_type == 'lstm+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnRnnModel(**model_kwargs)

    elif args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        model = NavPlannerControllerModel(**model_kwargs)

    else:

        # NOTE(review): unknown model type silently kills the worker with no
        # message; an explicit error would be friendlier.
        exit()

    # Default loss (used by the 'cnn' branch); the lstm/pacman branches
    # replace it with masked criteria below.
    lossFn = torch.nn.CrossEntropyLoss().cuda()

    # The optimizer updates the SHARED model's parameters; gradients are
    # copied into it from the local model each step (see ensure_shared_grads).
    optim = torch.optim.Adamax(
        filter(lambda p: p.requires_grad, shared_model.parameters()),
        lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.model_type,
        'num_frames': 5,
        'map_resolution': args.map_resolution,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.cache,
        'overfit': args.overfit,
        'max_controller_actions': args.max_controller_actions,
        'max_actions': args.max_actions
    }

    # Per-worker JSON log file.
    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    # The pacman model reports two losses (planner + controller); all other
    # models report a single loss.
    if 'pacman' in args.model_type:

        metrics = NavMetric(
            info={'split': 'train',
                  'thread': rank},
            metric_names=['planner_loss', 'controller_loss'],
            log_json=args.output_log_path)

    else:

        metrics = NavMetric(
            info={'split': 'train',
                  'thread': rank},
            metric_names=['loss'],
            log_json=args.output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)

    print('train_loader has %d samples' % len(train_loader.dataset))
    logging.info('TRAIN: train loader has {} samples'.format(len(train_loader.dataset)))

    # t: global batch counter (drives print/log cadence); epoch: pass counter.
    t, epoch = 0, 0

    while epoch < int(args.max_epochs):

        # ---- CNN branch: per-frame action classification ------------------
        if 'cnn' in args.model_type:

            done = False
            # If the dataset couldn't fit all envs in memory at once, we page
            # through them: train on the loaded subset, then _load_envs the
            # next subset until pruned_env_set is exhausted.
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while done == False:

                for batch in train_loader:

                    t += 1

                    # Sync weights from the shared model, then train locally
                    # on GPU (the model was left on CPU by the previous
                    # ensure_shared_grads call).
                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, _, img_feats, _, actions_out, _ = batch

                    img_feats_var = Variable(img_feats.cuda())
                    if '+q' in args.model_type:
                        questions_var = Variable(questions.cuda())
                    actions_out_var = Variable(actions_out.cuda())

                    if '+q' in args.model_type:
                        scores = model(img_feats_var, questions_var)
                    else:
                        scores = model(img_feats_var)

                    loss = lossFn(scores, actions_out_var)

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update([loss.data[0]])

                    # backprop and update
                    loss.backward()

                    # Copy local grads into the shared model (on CPU), then
                    # step the optimizer bound to the shared parameters.
                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                    print('[CHECK][Cache:%d][Total:%d]' %
                          (len(train_loader.dataset.img_data_cache),
                           len(train_loader.dataset.env_list)))
                    logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                        len(train_loader.dataset.img_data_cache), len(train_loader.dataset.env_list)))

                # Page in the next chunk of environments; when none remain,
                # the epoch is done (optionally rewinding to env 0 so the
                # next epoch starts from the beginning when not caching).
                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True

        # ---- LSTM branch: sequence prediction with masked NLL -------------
        elif 'lstm' in args.model_type:

            # NOTE(review): reassigned on every epoch; harmless but could be
            # hoisted above the epoch loop.
            lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()
            total_times = []  # NOTE(review): appears unused in this view
            while done == False:

                start_time = time.time()
                for batch in train_loader:

                    t += 1


                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, _, img_feats, actions_in, actions_out, action_lengths, masks = batch

                    img_feats_var = Variable(img_feats.cuda())
                    if '+q' in args.model_type:
                        questions_var = Variable(questions.cuda())
                    actions_in_var = Variable(actions_in.cuda())
                    actions_out_var = Variable(actions_out.cuda())
                    action_lengths = action_lengths.cuda()
                    masks_var = Variable(masks.cuda())

                    # Sort the batch by descending sequence length, as
                    # required for packed-sequence RNN processing; apply the
                    # same permutation to every per-sample tensor.
                    action_lengths, perm_idx = action_lengths.sort(
                        0, descending=True)

                    img_feats_var = img_feats_var[perm_idx]
                    if '+q' in args.model_type:
                        questions_var = questions_var[perm_idx]
                    actions_in_var = actions_in_var[perm_idx]
                    actions_out_var = actions_out_var[perm_idx]
                    masks_var = masks_var[perm_idx]

                    if '+q' in args.model_type:
                        scores, hidden = model(img_feats_var, questions_var,
                                               actions_in_var,
                                               action_lengths.cpu().numpy())
                    else:
                        scores, hidden = model(img_feats_var, False,
                                               actions_in_var,
                                               action_lengths.cpu().numpy())

                    #block out masks
                    # Curriculum: only supervise the LAST `curriculum_length`
                    # steps of each trajectory (the portion nearest the goal),
                    # widening by 5 steps per epoch.
                    if args.curriculum:
                        curriculum_length = (epoch+1)*5
                        for i, action_length in enumerate(action_lengths):
                            if action_length - curriculum_length > 0:
                                masks_var[i, :action_length-curriculum_length] = 0

                    # Flatten (batch, time) into rows; the mask zeroes out
                    # padding (and curriculum-excluded) timesteps.
                    logprob = F.log_softmax(scores, dim=1)
                    loss = lossFn(
                        logprob, actions_out_var[:, :action_lengths.max()]
                        .contiguous().view(-1, 1),
                        masks_var[:, :action_lengths.max()].contiguous().view(
                            -1, 1))

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update([loss.data[0]])
                    logging.info("TRAIN LSTM loss: {:.6f}".format(loss.data[0]))

                    # backprop and update
                    loss.backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                    print('[CHECK][Cache:%d][Total:%d]' %
                          (len(train_loader.dataset.img_data_cache),
                           len(train_loader.dataset.env_list)))
                    logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                        len(train_loader.dataset.img_data_cache), len(train_loader.dataset.env_list)))


                # Same env-paging logic as the cnn branch.
                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True

        # ---- PACMAN branch: hierarchical planner + controller -------------
        elif 'pacman' in args.model_type:

            planner_lossFn = MaskedNLLCriterion().cuda()
            controller_lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while done == False:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    # Batch carries parallel streams for the planner (macro
                    # actions) and the controller (low-level actions), plus
                    # indices mapping controller steps to planner hiddens.
                    idx, questions, _, planner_img_feats, planner_actions_in, \
                        planner_actions_out, planner_action_lengths, planner_masks, \
                        controller_img_feats, controller_actions_in, planner_hidden_idx, \
                        controller_outs, controller_action_lengths, controller_masks = batch

                    questions_var = Variable(questions.cuda())

                    planner_img_feats_var = Variable(planner_img_feats.cuda())
                    planner_actions_in_var = Variable(
                        planner_actions_in.cuda())
                    planner_actions_out_var = Variable(
                        planner_actions_out.cuda())
                    planner_action_lengths = planner_action_lengths.cuda()
                    planner_masks_var = Variable(planner_masks.cuda())

                    controller_img_feats_var = Variable(
                        controller_img_feats.cuda())
                    controller_actions_in_var = Variable(
                        controller_actions_in.cuda())
                    planner_hidden_idx_var = Variable(
                        planner_hidden_idx.cuda())
                    controller_outs_var = Variable(controller_outs.cuda())
                    controller_action_lengths = controller_action_lengths.cuda(
                    )
                    controller_masks_var = Variable(controller_masks.cuda())

                    # Sort by descending PLANNER sequence length (for packed
                    # RNN); permute all tensors — including the controller's —
                    # with the same index so samples stay aligned.
                    planner_action_lengths, perm_idx = planner_action_lengths.sort(
                        0, descending=True)

                    questions_var = questions_var[perm_idx]

                    planner_img_feats_var = planner_img_feats_var[perm_idx]
                    planner_actions_in_var = planner_actions_in_var[perm_idx]
                    planner_actions_out_var = planner_actions_out_var[perm_idx]
                    planner_masks_var = planner_masks_var[perm_idx]

                    controller_img_feats_var = controller_img_feats_var[
                        perm_idx]
                    controller_actions_in_var = controller_actions_in_var[
                        perm_idx]
                    controller_outs_var = controller_outs_var[perm_idx]
                    planner_hidden_idx_var = planner_hidden_idx_var[perm_idx]
                    controller_action_lengths = controller_action_lengths[
                        perm_idx]
                    controller_masks_var = controller_masks_var[perm_idx]

                    planner_scores, controller_scores, planner_hidden = model(
                        questions_var, planner_img_feats_var,
                        planner_actions_in_var,
                        planner_action_lengths.cpu().numpy(),
                        planner_hidden_idx_var, controller_img_feats_var,
                        controller_actions_in_var, controller_action_lengths)

                    planner_logprob = F.log_softmax(planner_scores, dim=1)
                    controller_logprob = F.log_softmax(
                        controller_scores, dim=1)

                    planner_loss = planner_lossFn(
                        planner_logprob,
                        planner_actions_out_var[:, :planner_action_lengths.max(
                        )].contiguous().view(-1, 1),
                        planner_masks_var[:, :planner_action_lengths.max()]
                        .contiguous().view(-1, 1))

                    controller_loss = controller_lossFn(
                        controller_logprob,
                        controller_outs_var[:, :controller_action_lengths.max(
                        )].contiguous().view(-1, 1),
                        controller_masks_var[:, :controller_action_lengths.max(
                        )].contiguous().view(-1, 1))

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update(
                        [planner_loss.data[0], controller_loss.data[0]])
                    logging.info("TRAINING PACMAN planner-loss: {:.6f} controller-loss: {:.6f}".format(
                        planner_loss.data[0], controller_loss.data[0]))

                    # backprop and update
                    # With a single controller action per planner step the
                    # controller is degenerate, so only the planner loss is
                    # backpropagated.
                    if args.max_controller_actions == 1:
                        (planner_loss).backward()
                    else:
                        (planner_loss + controller_loss).backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                    print('[CHECK][Cache:%d][Total:%d]' %
                          (len(train_loader.dataset.img_data_cache),
                           len(train_loader.dataset.env_list)))
                    logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                        len(train_loader.dataset.img_data_cache), len(train_loader.dataset.env_list)))

                # Same env-paging logic as the cnn branch.
                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True

        epoch += 1

        # Periodic per-worker checkpoint: model weights + optimizer state +
        # the (sanitized) args that produced them.
        if epoch % args.save_every == 0:

            model_state = get_state(model)
            optimizer_state = optim.state_dict()

            # Drop private/underscore-prefixed attributes from args so the
            # checkpoint stays JSON/pickle-friendly.
            aad = dict(args.__dict__)
            ad = {}
            for i in aad:
                if i[0] != '_':
                    ad[i] = aad[i]

            checkpoint = {'args': ad,
                          'state': model_state,
                          'epoch': epoch,
                          'optimizer': optimizer_state}

            checkpoint_path = '%s/epoch_%d_thread_%d.pt' % (
                args.checkpoint_dir, epoch, rank)
            print('Saving checkpoint to %s' % checkpoint_path)
            logging.info("TRAIN: Saving checkpoint to {}".format(checkpoint_path))
            torch.save(checkpoint, checkpoint_path)