utils/trigger_training.py (222 lines of code) (raw):
"""
Trigger a training task from the CLI on your current branch.
For example:
task train -- --config configs/experiments-2024-H2/en-lt-experiments-2024-H2-hplt-nllb.yml
"""
import argparse
import datetime
import json
from pathlib import Path
import subprocess
import sys
from time import sleep
from typing import Any, Optional, Tuple
from github import Github
import os
import yaml
import jsone
from taskgraph.util.taskcluster import get_artifact
from taskcluster import Hooks
from taskcluster.helper import TaskclusterConfig
ROOT_URL = "https://firefox-ci-tc.services.mozilla.com"
def run(command: list[str], env={}):
return subprocess.run(
command, capture_output=True, check=True, text=True, env={**os.environ, **env}
).stdout.strip()
def check_if_pushed(branch: str) -> bool:
try:
remote_commit = run(["git", "rev-parse", f"origin/{branch}"])
local_commit = run(["git", "rev-parse", branch])
return local_commit == remote_commit
except subprocess.CalledProcessError:
return False
def get_decision_task_push(branch: str):
g = Github()
repo_name = "mozilla/translations"
print(f'Looking up "{repo_name}"')
repo = g.get_repo(repo_name)
ref = f"heads/{branch}"
print('Finding the "Decision Task (push)"')
checks = repo.get_commit(ref).get_check_runs()
decision_task = None
for check in checks:
if check.name == "Decision Task (push)":
decision_task = check
return decision_task
def get_task_id_from_url(task_url: str):
"""
Extract the task id from a task url
e.g. https://firefox-ci-tc.services.mozilla.com/tasks/PhAMJTZBSmSeWStXbR72xA
returns
"PhAMJTZBSmSeWStXbR72xA"
"""
return task_url.split("/")[-1]
def get_train_action(decision_task_id: str):
actions_json = get_artifact(decision_task_id, "public/actions.json")
for action in actions_json["actions"]:
if action["name"] == "train":
return action
print("Could not find the train action.")
print(actions_json)
sys.exit(1)
def trigger_training(decision_task_id: str, config: dict[str, Any]) -> Optional[str]:
taskcluster = TaskclusterConfig(ROOT_URL)
taskcluster.auth()
hooks: Hooks = taskcluster.get_service("hooks")
train_action = get_train_action(decision_task_id)
# Render the payload using the jsone schema.
hook_payload = jsone.render(
train_action["hookPayload"],
{
"input": config,
"taskId": None,
"taskGroupId": decision_task_id,
},
)
start_stage: str = config["target-stage"]
if start_stage.startswith("train"):
evaluate_stage = start_stage.replace("train-", "evaluate-")
red = "\033[91m"
reset = "\x1b[0m"
print(
f'\n{red}WARNING:{reset} target-stage is "{start_stage}", did you mean "{evaluate_stage}"'
)
confirmation = input("\nStart training? [Y,n]\n")
if confirmation and confirmation.lower() != "y":
return None
# https://docs.taskcluster.net/docs/reference/core/hooks/api#triggerHook
response: Any = hooks.triggerHook(
train_action["hookGroupId"], train_action["hookId"], hook_payload
)
action_task_id = response["status"]["taskId"]
print(f"Train action triggered: {ROOT_URL}/tasks/{action_task_id}")
return action_task_id
def validate_taskcluster_credentials():
try:
run(["taskcluster", "--help"])
except Exception:
print("The taskcluster client library must be installed on the system.")
print("https://github.com/taskcluster/taskcluster/tree/main/clients/client-shell")
sys.exit(1)
if not os.environ.get("TASKCLUSTER_ACCESS_TOKEN"):
print("You must log in to Taskcluster. Run the following:")
print(f'eval `TASKCLUSTER_ROOT_URL="{ROOT_URL}" taskcluster signin`')
sys.exit(1)
try:
run(
[
"taskcluster",
"signin",
"--check",
],
{"TASKCLUSTER_ROOT_URL": ROOT_URL},
)
except Exception:
print("Your Taskcluster credentials have expired. Run the following:")
print(f'eval `TASKCLUSTER_ROOT_URL="{ROOT_URL}" taskcluster signin`')
sys.exit(1)
def log_config_info(config_path: Path, config: dict):
print(f"\nUsing config: {config_path}\n")
experiment = config["experiment"]
config_details: list[Tuple[str, Any]] = []
config_details.append(("experiment.name", experiment["name"]))
config_details.append(("experiment.src", experiment["src"]))
config_details.append(("experiment.trg", experiment["trg"]))
if config.get("start-stage"):
config_details.append(("start-stage", config["start-stage"]))
config_details.append(("target-stage", config["target-stage"]))
previous_group_ids = config.get("previous_group_ids")
if previous_group_ids:
config_details.append(("previous_group_ids", previous_group_ids))
pretrained_models: Optional[dict] = experiment.get("pretrained-models")
if pretrained_models:
for key, value in pretrained_models.items():
config_details.append((key, json.dumps(value, indent=2)))
key_len = 0
for key, _ in config_details:
key_len = max(key_len, len(key))
for key, value in config_details:
if "\n" in value:
# Nicely indent any multiline value.
padding = " " * (key_len + 6)
lines = [padding + n for n in value.split("\n")]
value = "\n".join(lines).strip() # noqa: PLW2901
print(f"{key.rjust(key_len + 4, ' ')}: {value}")
def write_to_log(config_path: Path, config: dict, action_task_id: str, branch: str):
"""
Persist the training log to disk.
"""
training_log = Path(__file__).parent / "../trigger-training.log"
experiment = config["experiment"]
git_hash = run(["git", "rev-parse", "--short", branch]).strip()
with open(training_log, "a") as file:
lines = [
"",
f"config: {config_path}",
f"name: {experiment['name']}",
f"langpair: {experiment['src']}-{experiment['trg']}",
f"time: {datetime.datetime.now()}",
f"train action: {ROOT_URL}/tasks/{action_task_id}",
f"branch: {branch}",
f"hash: {git_hash}",
]
for line in lines:
file.write(line + "\n")
def main() -> None:
parser = argparse.ArgumentParser(
description=__doc__,
# Preserves whitespace in the help text.
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--config", type=Path, required=True, help="Path the config")
parser.add_argument(
"--branch",
type=str,
required=False,
help="The name of the branch, defaults to the current branch",
)
parser.add_argument(
"--force",
action="store_true",
help="Skip the checks for the branch being up to date",
)
parser.add_argument(
"--no_interactive",
action="store_true",
help="Skip the confirmation",
)
args = parser.parse_args()
branch = args.branch
validate_taskcluster_credentials()
if branch:
print(f"Using --branch: {branch}")
else:
branch = run(["git", "branch", "--show-current"])
print(f"Using current branch: {branch}")
if branch != "main" and not branch.startswith("dev") and not branch.startswith("release"):
print(f'The git branch "{branch}" must be "main", or start with "dev" or "release"')
sys.exit(1)
if check_if_pushed(branch):
print(f"Branch '{branch}' is up to date with origin.")
elif args.force:
print(
f"Branch '{branch}' is not fully pushed to origin, bypassing this check because of --force."
)
else:
print(
f"Error: Branch '{branch}' is not fully pushed to origin. Use --force or push your changes."
)
sys.exit(1)
if branch != "main" and not branch.startswith("dev") and not branch.startswith("release"):
print(
f"Branch must be `main` or start with `dev` or `release` for training to run. Detected branch was {branch}"
)
timeout = 20
while True:
decision_task = get_decision_task_push(branch)
if decision_task:
if decision_task.status == "completed" and decision_task.conclusion == "success":
# The decision task is completed.
break
elif decision_task.status == "queued":
print(f"Decision task is queued, trying again in {timeout} seconds")
elif decision_task.status == "in_progress":
print(f"Decision task is in progress, trying again in {timeout} seconds")
else:
# The task failed.
print(
f'Decision task is "{decision_task.status}" with the conclusion "{decision_task.conclusion}"'
)
sys.exit(1)
else:
print(f"Decision task is not available, trying again in {timeout} seconds")
sleep(timeout)
decision_task_id = get_task_id_from_url(decision_task.details_url)
with args.config.open() as file:
config: dict = yaml.safe_load(file)
log_config_info(args.config, config)
action_task_id = trigger_training(decision_task_id, config)
if action_task_id:
write_to_log(args.config, config, action_task_id, branch)
if __name__ == "__main__":
main()