in scripts/train_model.py [0:0]
def main(args):
    """Top-level training entry point.

    Optionally randomizes the checkpoint path, optionally stages the HDF5
    inputs on local disk, loads the vocab and (optional) question-family
    split, builds train/val ClevrDataLoaders, runs the training loop, and
    removes any staged local copies afterwards.

    Args:
        args: parsed argparse namespace; attributes used here include
            checkpoint_path, vocab_json, the four *_h5 paths, batch_size,
            loader/shuffle/sample-count options, and the integer flags
            randomize_checkpoint_path, use_local_copies,
            cleanup_local_copies, shuffle_train_data.
    """
    if args.randomize_checkpoint_path == 1:
        # Append a random 6-digit suffix so parallel runs don't clobber
        # each other's checkpoints.
        name, ext = os.path.splitext(args.checkpoint_path)
        num = random.randint(1, 1000000)
        args.checkpoint_path = '%s_%06d%s' % (name, num, ext)

    vocab = utils.load_vocab(args.vocab_json)

    if args.use_local_copies == 1:
        # Stage the HDF5 files on local disk for faster reads.
        # NOTE(review): fixed /tmp names mean two concurrent runs on the
        # same host would collide — confirm this is acceptable.
        shutil.copy(args.train_question_h5, '/tmp/train_questions.h5')
        shutil.copy(args.train_features_h5, '/tmp/train_features.h5')
        shutil.copy(args.val_question_h5, '/tmp/val_questions.h5')
        shutil.copy(args.val_features_h5, '/tmp/val_features.h5')
        args.train_question_h5 = '/tmp/train_questions.h5'
        args.train_features_h5 = '/tmp/train_features.h5'
        args.val_question_h5 = '/tmp/val_questions.h5'
        args.val_features_h5 = '/tmp/val_features.h5'

    question_families = None
    if args.family_split_file is not None:
        with open(args.family_split_file, 'r') as f:
            question_families = json.load(f)

    train_loader_kwargs = {
        'question_h5': args.train_question_h5,
        'feature_h5': args.train_features_h5,
        'vocab': vocab,
        'batch_size': args.batch_size,
        'shuffle': args.shuffle_train_data == 1,
        'question_families': question_families,
        'max_samples': args.num_train_samples,
        'num_workers': args.loader_num_workers,
    }
    val_loader_kwargs = {
        'question_h5': args.val_question_h5,
        'feature_h5': args.val_features_h5,
        'vocab': vocab,
        'batch_size': args.batch_size,
        'question_families': question_families,
        'max_samples': args.num_val_samples,
        'num_workers': args.loader_num_workers,
    }

    try:
        with ClevrDataLoader(**train_loader_kwargs) as train_loader, \
             ClevrDataLoader(**val_loader_kwargs) as val_loader:
            train_loop(args, train_loader, val_loader)
    finally:
        # Fix: cleanup previously ran only on the success path, leaking the
        # staged copies in /tmp whenever train_loop raised. try/finally
        # removes them in either case.
        if args.use_local_copies == 1 and args.cleanup_local_copies == 1:
            os.remove('/tmp/train_questions.h5')
            os.remove('/tmp/train_features.h5')
            os.remove('/tmp/val_questions.h5')
            os.remove('/tmp/val_features.h5')