in pytorch_translate/train.py [0:0]
def setup_training_model(args):
    """Parse args, load dataset, and build model with criterion.

    Args:
        args: parsed argument namespace. Must carry at least `device_id`,
            `seed`, `task`, `arch`, the train/eval binary-path attributes,
            and `train_subset` / `valid_subset` split names.

    Returns:
        Tuple of (task, model, criterion), with the train and validation
        datasets for `args.task` already loaded into `task`.
    """
    if not torch.cuda.is_available():
        print("Warning: training without CUDA is likely to be slow!")
    else:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    # Setup task and load dataset
    task = tasks.setup_task(args)
    # Build model and criterion
    model = task.build_model(args)
    print("| building criterion")
    criterion = task.build_criterion(args)
    print(f"| model {args.arch}, criterion {criterion.__class__.__name__}")
    # NOTE: previously a backslash-continued f-string here leaked the
    # continuation line's indentation into the printed message; collapsed
    # to a single-line f-string so the output has no stray spaces.
    print(f"| num. model params: {sum(p.numel() for p in model.parameters())}")
    _load_training_dataset(task, args)
    _load_validation_dataset(task, args)
    return task, model, criterion


def _load_training_dataset(task, args):
    """Load the training split, passing the task-specific kwargs each task type expects."""
    if args.task == constants.SEMI_SUPERVISED_TASK:
        # TODO(T35638969): hide this inside the task itself, just use self.args
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            forward_model=task.forward_model,
            backward_model=task.backward_model,
        )
    elif args.task == "pytorch_translate_denoising_autoencoder":
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            seed=args.seed,
            use_noiser=True,
        )
    elif args.task == "dual_learning_task":
        task.load_dataset(split=args.train_subset, seed=args.seed)
    elif args.task == "pytorch_translate_knowledge_distillation":
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            weights_file=getattr(args, "train_weights_path", None),
            is_train=True,
        )
    elif args.task == "pytorch_translate_cross_lingual_lm":
        task.load_dataset(args.train_subset, combine=True, epoch=0)
    elif args.task == "pytorch_translate":
        # Support both single and multi path loading for now
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            weights_file=getattr(args, "train_weights_path", None),
            is_npz=not args.fairseq_data_format,
        )
    else:
        # Support both single and multi path loading for now
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            weights_file=getattr(args, "train_weights_path", None),
        )


def _load_validation_dataset(task, args):
    """Load the validation split, mirroring the task-type dispatch used for training."""
    if args.task == "dual_learning_task":
        task.load_dataset(split=args.valid_subset, seed=args.seed)
    elif args.task == "pytorch_translate_cross_lingual_lm":
        task.load_dataset(args.valid_subset, combine=True, epoch=0)
    elif args.task == "pytorch_translate":
        task.load_dataset(
            split=args.valid_subset,
            src_bin_path=args.eval_source_binary_path,
            tgt_bin_path=args.eval_target_binary_path,
            is_npz=not args.fairseq_data_format,
        )
    else:
        task.load_dataset(
            split=args.valid_subset,
            src_bin_path=args.eval_source_binary_path,
            tgt_bin_path=args.eval_target_binary_path,
        )