def setup_training_model()

in pytorch_translate/train.py


def setup_training_model(args):
    """Parse args, load dataset, and build model with criterion."""
    if not torch.cuda.is_available():
        print("Warning: training without CUDA is likely to be slow!")
    else:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Set up the task; datasets are loaded per task type below
    task = tasks.setup_task(args)

    # Build model and criterion
    model = task.build_model(args)
    print("| building criterion")
    criterion = task.build_criterion(args)
    print(f"| model {args.arch}, criterion {criterion.__class__.__name__}")
    print(
        f"| num. model params: "
        f"{sum(p.numel() for p in model.parameters())}"
    )

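    # Load the training split; each task variant expects different
    # dataset arguments, so dispatch on args.task.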
    if args.task == constants.SEMI_SUPERVISED_TASK:
        # TODO(T35638969): hide this inside the task itself, just use self.args
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            forward_model=task.forward_model,
            backward_model=task.backward_model,
        )
    elif args.task == "pytorch_translate_denoising_autoencoder":
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            seed=args.seed,
            use_noiser=True,
        )
    elif args.task == "dual_learning_task":
        task.load_dataset(split=args.train_subset, seed=args.seed)
    elif args.task == "pytorch_translate_knowledge_distillation":
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            weights_file=getattr(args, "train_weights_path", None),
            is_train=True,
        )
    elif args.task == "pytorch_translate_cross_lingual_lm":
        task.load_dataset(args.train_subset, combine=True, epoch=0)
    elif args.task == "pytorch_translate":
        # Support both single and multi path loading for now
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            weights_file=getattr(args, "train_weights_path", None),
            is_npz=not args.fairseq_data_format,
        )
    else:
        # Support both single and multi path loading for now
        task.load_dataset(
            split=args.train_subset,
            src_bin_path=args.train_source_binary_path,
            tgt_bin_path=args.train_target_binary_path,
            weights_file=getattr(args, "train_weights_path", None),
        )

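    # Load the validation split, mirroring the per-task dispatch above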
    if args.task == "dual_learning_task":
        task.load_dataset(split=args.valid_subset, seed=args.seed)
    elif args.task == "pytorch_translate_cross_lingual_lm":
        task.load_dataset(args.valid_subset, combine=True, epoch=0)
    elif args.task == "pytorch_translate":
        task.load_dataset(
            split=args.valid_subset,
            src_bin_path=args.eval_source_binary_path,
            tgt_bin_path=args.eval_target_binary_path,
            is_npz=not args.fairseq_data_format,
        )
    else:
        task.load_dataset(
            split=args.valid_subset,
            src_bin_path=args.eval_source_binary_path,
            tgt_bin_path=args.eval_target_binary_path,
        )

    return task, model, criterion
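

For context, a minimal driver sketch of how this helper might be wired into a
training entry point. It assumes fairseq's standard option parsing
(options.get_training_parser() / options.parse_args_and_arch()) is enough to
populate the fields setup_training_model() reads; the real pytorch_translate
entry point defines additional arguments (e.g. the *_binary_path flags) on top
of this, so treat the sketch as illustrative rather than the actual train.py
main().

# Hypothetical driver sketch; the real entry point adds pytorch_translate
# specific arguments before parsing.
import torch
from fairseq import options


def main():
    # Build fairseq's standard training argument parser and parse args,
    # resolving the --arch-specific model arguments.
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    # Set up the task, datasets, model, and criterion as defined above.
    task, model, criterion = setup_training_model(args)

    # Move model and criterion to GPU when available.
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    # ... build the optimizer / trainer and enter the training loop ...


if __name__ == "__main__":
    main()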