def train()

in training/train_vqa.py [0:0]


def _loss_to_float(loss):
    """Return a scalar loss as a plain Python number.

    Newer PyTorch releases expose ``item()`` on scalar tensors and reject
    ``loss.data[0]`` on 0-dim tensors; older releases only support the
    indexing form. Prefer ``item()`` when it exists and fall back otherwise.
    """
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


def _sync_from_shared(model, shared_model, freeze_cnn):
    """Copy the shared weights into the local model and ready it for a GPU step."""
    model.load_state_dict(shared_model.state_dict())
    model.train()
    if freeze_cnn:
        # The CNN sub-module is kept in eval mode while the rest trains.
        model.cnn.eval()
    model.cuda()


def _step_shared_model(model, shared_model, optim, lossFn, metrics, scores,
                       answers, answers_var):
    """Compute the loss, record metrics and apply one optimizer step.

    Statement order matters: gradients are cleared through the optimizer
    before ``backward()``, and the local model is moved to the CPU before its
    gradients are handed to ``ensure_shared_grads``.
    """
    loss = lossFn(scores, answers_var)

    # zero grad
    optim.zero_grad()

    # update metrics
    accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
    metrics.update([_loss_to_float(loss), accuracy, ranks, 1.0 / ranks])

    # backprop and update
    loss.backward()

    ensure_shared_grads(model.cpu(), shared_model)
    optim.step()


def _report_progress(step, args, metrics):
    """Print (and optionally dump) metrics every ``args.print_every`` steps."""
    if step % args.print_every != 0:
        return
    print(metrics.get_stat_string())
    # Comparison form kept as-is: the type of ``args.log`` is not visible
    # here, so a truthiness test could change behaviour.
    if args.log == True:
        metrics.dump_log()


def train(rank, args, shared_model):
    """Run one VQA training worker that updates ``shared_model``.

    A local model is built according to ``args.input_type``. On every batch
    the worker reloads the shared weights into the local model, runs the
    forward/backward pass on the GPU, passes the local model (moved to CPU)
    to ``ensure_shared_grads`` and steps an Adam optimizer constructed over
    the trainable parameters of ``shared_model``.

    Args:
        rank: Index of this worker; selects the GPU via
            ``rank % len(args.gpus)`` and names the per-worker log file.
        args: Namespace of options. Reads ``gpus``, ``input_type``,
            ``vocab_json``, ``learning_rate``, ``train_h5``, ``data_json``,
            ``batch_size``, ``num_frames``, ``max_threads_per_gpu``,
            ``cache``, ``log_dir``, ``max_epochs``, ``print_every`` and
            ``log``. ``args.output_log_path`` is set as a side effect.
        shared_model: Model whose parameters are optimized and whose state is
            copied into the local model before each step.

    Raises:
        ValueError: If ``args.input_type`` is neither ``'ques'`` nor
            ``'ques,image'``. Without this check no model would be built and
            the epoch loop would run without training anything.
    """
    if args.input_type not in ('ques', 'ques,image'):
        raise ValueError(
            "Unsupported input_type %r; expected 'ques' or 'ques,image'"
            % (args.input_type,))
    uses_images = args.input_type == 'ques,image'

    gpu_id = args.gpus[rank % len(args.gpus)]
    # set_device receives the position of the chosen id within args.gpus.
    torch.cuda.set_device(args.gpus.index(gpu_id))

    model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    if uses_images:
        model = VqaLstmCnnAttentionModel(**model_kwargs)
    else:
        model = VqaLstmModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    # The optimizer owns the *shared* parameters; the local model only
    # supplies gradients.
    optim = torch.optim.Adam(
        filter(lambda p: p.requires_grad, shared_model.parameters()),
        lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': gpu_id,
        'to_cache': args.cache
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    metrics = VqaMetric(
        info={'split': 'train',
              'thread': rank},
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)
    dataset = train_loader.dataset
    if uses_images:
        dataset._load_envs(start_idx=0, in_order=True)

    print('train_loader has %d samples' % len(dataset))

    t = 0

    for _ in range(int(args.max_epochs)):

        if not uses_images:

            for batch in train_loader:

                t += 1

                _sync_from_shared(model, shared_model, freeze_cnn=False)

                _, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)

                _step_shared_model(model, shared_model, optim, lossFn,
                                   metrics, scores, answers, answers_var)
                _report_progress(t, args, metrics)

            continue

        # 'ques,image': keep iterating over the loader, loading further
        # environments between passes, until the dataset reports none left.
        all_envs_loaded = dataset._check_if_all_envs_loaded()
        done = False

        while not done:

            for batch in train_loader:

                t += 1

                _sync_from_shared(model, shared_model, freeze_cnn=True)

                _, questions, answers, images, _, _, _ = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())
                images_var = Variable(images.cuda())

                scores, _ = model(images_var, questions_var)

                _step_shared_model(model, shared_model, optim, lossFn,
                                   metrics, scores, answers, answers_var)
                _report_progress(t, args, metrics)

            # Comparison form kept as-is: the return type of
            # _check_if_all_envs_loaded() is not visible from here.
            if all_envs_loaded == False:
                print('[CHECK][Cache:%d][Total:%d]' % (len(dataset.img_data_cache),
                    len(dataset.env_list)))
                dataset._load_envs(in_order=True)
                if len(dataset.pruned_env_set) == 0:
                    done = True
            else:
                done = True