in scripts/train_model.py
def get_baseline_model(args):
  # Load the question vocabulary; it defines the token embedding table.
  vocab = utils.load_vocab(args.vocab_json)
  if args.baseline_start_from is not None:
    # Resume from a previously saved baseline checkpoint.
    model, kwargs = utils.load_baseline(args.baseline_start_from)
  elif args.model_type == 'LSTM':
    # Question-only LSTM baseline.
    kwargs = {
      'vocab': vocab,
      'rnn_wordvec_dim': args.rnn_wordvec_dim,
      'rnn_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
      'fc_dims': parse_int_list(args.classifier_fc_dims),
      'fc_use_batchnorm': args.classifier_batchnorm == 1,
      'fc_dropout': args.classifier_dropout,
    }
    model = LstmModel(**kwargs)
  elif args.model_type == 'CNN+LSTM':
    # LSTM question encoder combined with CNN image features.
    kwargs = {
      'vocab': vocab,
      'rnn_wordvec_dim': args.rnn_wordvec_dim,
      'rnn_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
      'cnn_feat_dim': parse_int_list(args.feature_dim),
      'cnn_num_res_blocks': args.cnn_num_res_blocks,
      'cnn_res_block_dim': args.cnn_res_block_dim,
      'cnn_proj_dim': args.cnn_proj_dim,
      'cnn_pooling': args.cnn_pooling,
      'fc_dims': parse_int_list(args.classifier_fc_dims),
      'fc_use_batchnorm': args.classifier_batchnorm == 1,
      'fc_dropout': args.classifier_dropout,
    }
    model = CnnLstmModel(**kwargs)
  elif args.model_type == 'CNN+LSTM+SA':
    # CNN + LSTM with stacked attention over image features.
    kwargs = {
      'vocab': vocab,
      'rnn_wordvec_dim': args.rnn_wordvec_dim,
      'rnn_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
      'cnn_feat_dim': parse_int_list(args.feature_dim),
      'stacked_attn_dim': args.stacked_attn_dim,
      'num_stacked_attn': args.num_stacked_attn,
      'fc_dims': parse_int_list(args.classifier_fc_dims),
      'fc_use_batchnorm': args.classifier_batchnorm == 1,
      'fc_dropout': args.classifier_dropout,
    }
    model = CnnLstmSaModel(**kwargs)
  else:
    raise ValueError('Unrecognized model type: %s' % args.model_type)
  if model.rnn.token_to_idx != vocab['question_token_to_idx']:
    # Make sure the new vocab is a superset of the old one.
    for k, v in model.rnn.token_to_idx.items():
      assert k in vocab['question_token_to_idx']
      assert vocab['question_token_to_idx'][k] == v
    # Adopt the new vocab and grow the embedding table to cover it.
    for token, idx in vocab['question_token_to_idx'].items():
      model.rnn.token_to_idx[token] = idx
    kwargs['vocab'] = vocab
    model.rnn.expand_vocab(vocab['question_token_to_idx'])
  # Move the model to the GPU and put it in training mode.
  model.cuda()
  model.train()
  return model, kwargs
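

# Illustrative usage sketch, not part of the original script: build an
# argparse-style namespace carrying the attributes that get_baseline_model
# reads for the question-only LSTM branch, then construct the model. The
# concrete values and the vocab path below are placeholder assumptions, and
# actually running this requires a CUDA device plus a real vocab JSON.
import argparse

example_args = argparse.Namespace(
  baseline_start_from=None,
  model_type='LSTM',
  vocab_json='data/vocab.json',   # placeholder path
  rnn_wordvec_dim=300,
  rnn_hidden_dim=512,
  rnn_num_layers=2,
  rnn_dropout=0,
  classifier_fc_dims='1024',
  classifier_batchnorm=0,
  classifier_dropout=0,
)
model, kwargs = get_baseline_model(example_args)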