def main()

in taskcluster/scripts/pipeline/train_taskcluster.py


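main() depends on a handful of module-level imports and constants defined earlier in train_taskcluster.py. The sketch below is a hedged reconstruction: the two URL templates follow the standard Taskcluster Queue API, while the remaining constant values are illustrative placeholders rather than the real definitions.

import logging
import os
import subprocess
import sys

import requests

# Standard Taskcluster Queue endpoints for listing a run's artifacts and
# fetching an individual artifact.
ARTIFACTS_URL = "{root_url}/api/queue/v1/task/{task_id}/runs/{run_id}/artifacts"
ARTIFACT_URL = "{root_url}/api/queue/v1/task/{task_id}/runs/{run_id}/artifacts/{artifact_name}"

# Illustrative placeholders; the real values live at module level in train_taskcluster.py.
TRAINING_SCRIPT = os.path.join(os.path.dirname(__file__), "train-taskcluster.sh")
CONTINUATION_ARTIFACTS = {"model.npz", "model.npz.optimizer.npz", "model.npz.progress.yml"}
PRETRAINED_MODEL_MODE_ARG_NUMBER = 12  # placeholder: 1-based position of the pretrained model mode argument
DOWNLOAD_ERROR_EXIT_CODE = 17  # placeholder: distinct exit code signalling a download failure
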
def main(args):
    logging.basicConfig(level=logging.INFO)

    script_args = list(args)
    src = args[2]
    trg = args[3]
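    # TASK_ID, RUN_ID, and TASKCLUSTER_ROOT_URL are provided in the task's
    # environment by the Taskcluster worker.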
    task_id = os.environ["TASK_ID"]
    run_id = int(os.environ["RUN_ID"])
    root_url = os.environ["TASKCLUSTER_ROOT_URL"]
    # Must line up with where model_dir is in `train-taskcluster.sh` while that script
    # still exists.
    model_dir = script_args[6]
    pretrained_model_mode = None
    if len(args) >= PRETRAINED_MODEL_MODE_ARG_NUMBER:
        pretrained_model_mode = script_args[PRETRAINED_MODEL_MODE_ARG_NUMBER - 1]

    os.makedirs(model_dir, exist_ok=True)

    if run_id > 0:
        logging.info("run_id > 0, attempting to resume training from an earlier run...")
        prev_run_id = run_id - 1

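        # Walk backwards through this task's earlier runs until one of them has a
        # complete set of continuation artifacts (or we run out of runs to check).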
        while prev_run_id >= 0:
            try:
                resp = requests.get(
                    ARTIFACTS_URL.format(root_url=root_url, task_id=task_id, run_id=prev_run_id)
                )
                resp.raise_for_status()
            except Exception:
                logging.exception(
                    f"Failed to list artifacts for run {prev_run_id}, exiting with distinct code..."
                )
                sys.exit(DOWNLOAD_ERROR_EXIT_CODE)

            run_artifacts = {os.path.basename(a["name"]) for a in resp.json()["artifacts"]}

            resumable = True

            if run_artifacts.issuperset(
                CONTINUATION_ARTIFACTS.union({f"vocab.{src}.spm", f"vocab.{trg}.spm"})
            ) or run_artifacts.issuperset(CONTINUATION_ARTIFACTS.union({"vocab.spm"})):
                logging.info(
                    f"Run {prev_run_id} appears to have the artifacts we need! Downloading them..."
                )
            else:
                logging.info(f"Run {prev_run_id} is missing some necessary artifacts...")
                resumable = False

            if resumable:
                for artifact in resp.json()["artifacts"]:
                    # Skip Taskcluster logs - we only care about artifacts that the training tools create.
                    if artifact["name"].startswith("public/log"):
                        continue
                    out_name = os.path.basename(artifact["name"])
                    logging.info(f"Fetching {artifact['name']}...")

                    r = requests.get(
                        ARTIFACT_URL.format(
                            root_url=root_url,
                            task_id=task_id,
                            run_id=prev_run_id,
                            artifact_name=artifact["name"],
                        ),
                        stream=True,
                    )
                    if 400 <= r.status_code < 500:
                        logging.error(
                            f"Got a 4xx error for {artifact['name']}, cannot resume from run {prev_run_id}..."
                        )
                        resumable = False
                        break
                    elif r.status_code >= 500:
                        logging.error(
                            f"Got a 5xx error for {artifact['name']}, exiting with distinct code..."
                        )
                        sys.exit(DOWNLOAD_ERROR_EXIT_CODE)

                    with open(os.path.join(model_dir, out_name), "wb") as fd:
                        for chunk in r.iter_content(chunk_size=8192):
                            fd.write(chunk)

            if resumable:
                # We successfully downloaded all the artifacts from a previous run. Override
                # the pretrained model mode and we're done!
                pretrained_model_mode = "continue"
                break
            else:
                # We weren't able to get all of the necessary artifacts; try the next earlier run.
                prev_run_id -= 1

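    # Forward the (possibly overridden) pretrained model mode to the training
    # script, appending it if it wasn't part of the original arguments.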
    if pretrained_model_mode:
        if len(script_args) < PRETRAINED_MODEL_MODE_ARG_NUMBER:
            script_args.append(pretrained_model_mode)
        else:
            script_args[PRETRAINED_MODEL_MODE_ARG_NUMBER - 1] = pretrained_model_mode
    subprocess.run([TRAINING_SCRIPT, *script_args], check=True)
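
For completeness, a minimal sketch of the module entry point, assuming main() is handed the training wrapper's positional arguments (everything after the program name), which matches the argument indices used above:

if __name__ == "__main__":
    # Assumed invocation: forward the positional arguments (model type, training
    # type, src, trg, ..., model_dir, ...) straight through to main().
    main(sys.argv[1:])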