in pipeline/train/train.py [0:0]
def main() -> None:
    """Parse the training command line and run the training pipeline.

    The parsed arguments are handed to ``TrainCLI`` together with a fresh
    temporary working directory. The following steps then run in order:
    log the configuration, validate the arguments, build the datasets,
    generate the OpusTrainer config, and run training.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Preserves whitespace in the help text.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--model_type",
        type=ModelType,
        choices=ModelType,
        required=True,
        help="The type of model to train",
    )
    parser.add_argument(
        "--student_model",
        type=StudentModel,
        choices=StudentModel,
        required=False,
        default=StudentModel.tiny,
        help="Type of student model",
    )
    parser.add_argument(
        "--training_type",
        type=TrainingType,
        choices=TrainingType,
        help="Type of teacher training",
    )
    parser.add_argument(
        "--gpus",
        type=str,
        required=True,
        help='The indexes of the GPUs to use on a system, e.g. --gpus "0 1 2 3"',
    )
    parser.add_argument(
        "--marian_dir",
        type=Path,
        required=True,
        help="Path to Marian binary directory. This allows for overriding to use the browser-mt fork.",
    )
    parser.add_argument(
        "--workspace",
        type=str,
        required=True,
        help="The amount of Marian memory (in MB) to preallocate",
    )
    parser.add_argument("--src", type=str, help="Source language")
    parser.add_argument("--trg", type=str, help="Target language")
    parser.add_argument(
        "--train_set_prefixes",
        type=str,
        help="Comma separated prefixes to datasets for curriculum learning",
    )
    parser.add_argument("--validation_set_prefix", type=str, help="Prefix to validation dataset")
    parser.add_argument("--artifacts", type=Path, help="Where to save the model artifacts")
    parser.add_argument("--src_vocab", type=Path, help="Path to source language vocab file")
    parser.add_argument("--trg_vocab", type=Path, help="Path to target language vocab file")
    parser.add_argument(
        "--best_model_metric",
        type=BestModelMetric,
        # Matches the other enum-typed options: lists the valid metrics in
        # --help/usage output instead of leaving them undocumented.
        choices=BestModelMetric,
        help="Multiple metrics are gathered, but only the best model for a given metric will be retained",
    )
    parser.add_argument(
        "--alignments",
        type=str,
        help="Comma separated alignment paths corresponding to each training dataset, or 'None' to train without alignments",
    )
    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument(
        "--teacher_mode",
        type=TeacherMode,
        choices=TeacherMode,
        help="Teacher mode",
    )
    parser.add_argument(
        "extra_marian_args",
        # Collects every remaining command-line token verbatim.
        nargs=argparse.REMAINDER,
        help="Additional parameters for the training script",
    )

    # All pipeline steps run inside the context manager, so the temporary
    # directory given to TrainCLI exists for the whole run and is removed
    # once training finishes (or fails).
    with tempfile.TemporaryDirectory() as temp_dir:
        train_cli = TrainCLI(parser.parse_args(), Path(temp_dir))
        train_cli.log_config()
        train_cli.validate_args()
        train_cli.build_datasets()
        train_cli.generate_opustrainer_config()
        train_cli.run_training()