in utils/model_registry.py [0:0]
def _find_latest_task_by_labels(tasks: list[Task], *label_patterns: str):
    """
    Find the latest task whose label matches one of several alternative patterns.

    The patterns are tried in order, each with its own ``find_latest_task``
    lookup, and the first lookup that finds a task wins. This is used for task
    labels that were renamed: writing ``match_by_label(a) or match_by_label(b)``
    does not work, because ``or`` is applied to the matcher objects (which are
    truthy) rather than to their results, so the second pattern is never used.

    Returns the matched task, or None when no pattern matches any task.
    """
    for label_pattern in label_patterns:
        task = find_latest_task(tasks, match_by_label(label_pattern))
        if task:
            return task
    return None


def collect_models(tasks: list[Task], training_run: TrainingRun, upload: bool):
    """
    Lookup models from Google Cloud Storage.

    For every model kind (backwards, teachers, student, finetuned student,
    quantized, exported) the latest matching task is looked up in ``tasks``
    and, when one is found, the corresponding attribute of ``training_run``
    is populated in place.

    Args:
        tasks: The tasks to search through.
        training_run: The training run whose model attributes are assigned.
        upload: Forwarded to ``get_model`` / ``get_model_without_evals`` for
            every model except the exported one, which always uses
            ``upload=False``.
    """
    backwards = _find_latest_task_by_labels(
        tasks,
        # This was renamed
        r"^train-backwards-",
        r"^backtranslations-train-backwards-model-",
    )
    if backwards:
        training_run.backwards = get_model_without_evals(
            backwards,
            training_run,
            upload,
            model_name="backward",
        )

    train_teacher_1 = _find_latest_task_by_labels(
        tasks,
        r"^train-teacher-.*-1",
        r"^train-teacher-model-.*-1",
    )
    if train_teacher_1:
        training_run.teacher_1 = get_model(
            train_teacher_1,
            training_run,
            tasks,
            upload,
            tc_model_name="teacher",
            gcs_model_name="teacher0",
            gcs_eval_name="teacher0",
        )

    # NOTE(review): unlike teacher 1, only the "train-teacher-model-" label is
    # matched here. Confirm whether the older "train-teacher-...-2" naming
    # should be supported as well.
    train_teacher_2 = find_latest_task(tasks, match_by_label(r"^train-teacher-model-.*-2"))
    if train_teacher_2:
        training_run.teacher_2 = get_model(
            train_teacher_2,
            training_run,
            tasks,
            upload,
            tc_model_name="teacher",
            gcs_model_name="teacher1",
            gcs_eval_name="teacher1",
        )

    student_finetuned = _find_latest_task_by_labels(
        tasks,
        r"^finetune-student",
        r"^distillation-student-model-finetune-",
    )
    if student_finetuned:
        training_run.student_finetuned = get_model(
            student_finetuned,
            training_run,
            tasks,
            upload,
            tc_model_name="finetuned-student",
            gcs_model_name="student-finetuned",
            gcs_eval_name="student-finetuned",
        )

    train_student_task = _find_latest_task_by_labels(
        tasks,
        r"^train-student-",
        r"^distillation-student-model-train-",
    )
    if train_student_task:
        training_run.student = get_model(
            train_student_task,
            training_run,
            tasks,
            upload,
            tc_model_name="student",
            gcs_model_name="student",
            gcs_eval_name="student",
        )

    student_quantize_task = find_latest_task(tasks, match_by_label(r"^quantize-"))
    if student_quantize_task:
        training_run.student_quantized = get_model(
            student_quantize_task,
            training_run,
            tasks,
            upload,
            tc_model_name="quantized",
            gcs_model_name="quantized",
            gcs_eval_name="speed",
        )

    student_export_task = find_latest_task(tasks, match_by_label(r"^export-"))
    if student_export_task:
        training_run.student_exported = get_model(
            student_export_task,
            training_run,
            tasks,
            # These logs aren't useful to retain, as there is no training happening here.
            upload=False,
            tc_model_name="export",
            gcs_model_name="exported",
            gcs_eval_name="exported",
        )
        if training_run.student_quantized:
            # The export step doesn't have an explicit eval, so take
            # the one from the quantized step.
            training_run.student_exported.flores = training_run.student_quantized.flores