in training/train_vqa.py [0:0]
def train(rank, args, shared_model):
    """Worker entry point for asynchronous (Hogwild-style) VQA training.

    Each worker repeatedly pulls the latest weights from ``shared_model``
    into a local GPU replica, runs a forward/backward pass on one batch,
    copies the resulting gradients back into the shared (CPU) model via
    ``ensure_shared_grads``, and steps an Adam optimizer that holds the
    *shared* parameters.

    Args:
        rank: integer id of this worker process; selects the GPU
            (round-robin over ``args.gpus``) and names the log file.
        args: parsed command-line namespace (hyper-parameters, paths,
            logging flags).
        shared_model: model living in shared memory, updated by all
            workers concurrently.
    """
    # NOTE(review): with unique entries in args.gpus this expression is
    # just rank % len(args.gpus); .index() returns the position of the
    # first match. Kept as-is to avoid a behavior change — confirm intent.
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    # Build a local replica matching the shared model's architecture.
    if args.input_type == 'ques':
        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)
    elif args.input_type == 'ques,image':
        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    # The optimizer references the *shared* parameters, so optim.step()
    # below updates the shared model directly (Hogwild pattern).
    optim = torch.optim.Adam(
        filter(lambda p: p.requires_grad, shared_model.parameters()),
        lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.cache
    }

    # Per-worker JSON log so concurrent workers never clobber each other.
    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')
    metrics = VqaMetric(
        info={'split': 'train',
              'thread': rank},
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)
    if args.input_type == 'ques,image':
        # Image input requires environments to be paged into the cache.
        train_loader.dataset._load_envs(start_idx=0, in_order=True)
    print('train_loader has %d samples' % len(train_loader.dataset))

    def _step(batch, with_images):
        """One gradient step: sync weights from the shared model, run the
        forward/backward pass locally, push gradients to shared_model and
        step the shared optimizer. Shared by both input types."""
        model.load_state_dict(shared_model.state_dict())
        model.train()
        if with_images:
            model.cnn.eval()  # keep the CNN feature extractor frozen
        model.cuda()

        if with_images:
            idx, questions, answers, images, _, _, _ = batch
        else:
            idx, questions, answers = batch
        questions_var = Variable(questions.cuda())
        answers_var = Variable(answers.cuda())

        if with_images:
            images_var = Variable(images.cuda())
            scores, att_probs = model(images_var, questions_var)
        else:
            scores = model(questions_var)
        loss = lossFn(scores, answers_var)

        # zero grad
        optim.zero_grad()

        # update metrics
        accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
        metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

        # backprop and update: the replica is moved to CPU so its grads
        # can be copied into the shared-memory model before stepping.
        loss.backward()
        ensure_shared_grads(model.cpu(), shared_model)
        optim.step()

        if t % args.print_every == 0:
            print(metrics.get_stat_string())
            if args.log:
                metrics.dump_log()

    t, epoch = 0, 0
    while epoch < int(args.max_epochs):
        if args.input_type == 'ques':
            for batch in train_loader:
                t += 1
                _step(batch, with_images=False)

        elif args.input_type == 'ques,image':
            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()
            while not done:
                for batch in train_loader:
                    t += 1
                    _step(batch, with_images=True)

                if not all_envs_loaded:
                    # Page in the next chunk of environments; stop once the
                    # set of environments still to visit is exhausted.
                    print('[CHECK][Cache:%d][Total:%d]' % (
                        len(train_loader.dataset.img_data_cache),
                        len(train_loader.dataset.env_list)))
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1