def train_action()

in taskcluster/translations_taskgraph/actions/train.py [0:0]


def train_action(parameters, graph_config, input, task_group_id, task_id):
    """Run a decision for the training pipeline described by the action input.

    The action input is used as the training config. Two keys are consumed
    from it before that happens:

    - ``start-stage`` (together with ``previous_group_ids``): tasks from the
      given task groups that are ancestors of the tasks tagged with this
      stage are recorded as ``existing_tasks`` so they are reused instead
      of being scheduled again.
    - ``existing_tasks``: an explicit label -> task id mapping that takes
      precedence over whatever was computed or inherited.

    Args:
        parameters: Parameters the action was triggered with. A copy is
            modified; neither the object passed in nor its
            ``existing_tasks`` mapping is mutated.
        graph_config: Graph configuration; only ``root_dir`` is read.
        input: Action input. NOTE: ``start-stage``, ``previous_group_ids``
            (only when ``start-stage`` is set) and ``existing_tasks`` are
            popped from this dict in place; what remains becomes
            ``parameters["training_config"]``.
        task_group_id: Unused by this function.
        task_id: Unused by this function.

    Raises:
        Exception: If ``start-stage`` is given without ``previous_group_ids``.
    """
    # TODO: Add a whack load of verification here. Things such as:
    # - datasets all exist
    # - locale pair exists for each dataset
    # - stage is valid
    # etc.

    # Shallow copy: top-level keys can be reassigned freely, but nested values
    # are still shared with the caller and must be copied before mutation.
    parameters = dict(parameters)

    start_stage = input.pop("start-stage", None)
    if start_stage:
        if "previous_group_ids" not in input:
            raise Exception(
                "'previous_group_ids' is required to use 'start-stage' (otherwise we can't skip earlier tasks)"
            )

        previous_group_ids = input.pop("previous_group_ids")

        # First, we create one big graph out of all of the tasks from the specified group IDs.
        # Later groups win when the same label appears in more than one group.
        label_to_task_id = {}
        combined_full_task_graph = {}
        for graph_id in previous_group_ids:
            label_to_task_id.update(get_artifact(graph_id, "public/label-to-taskid.json"))
            full_task_graph = get_artifact(graph_id, "public/full-task-graph.json")
            combined_full_task_graph.update(full_task_graph)
        _, combined_full_task_graph = TaskGraph.from_json(combined_full_task_graph)

        # Next, we find the task id(s) of the tasks that match the stage
        # we want to start at.
        start_task_ids = [
            label_to_task_id[label]
            for label, task in combined_full_task_graph.tasks.items()
            if task.attributes.get("stage") == start_stage
        ]
        if not start_task_ids:
            # Nothing matched, so nothing from the previous groups will be reused.
            # Surface that instead of silently rescheduling everything.
            logger.warning(
                f"No tasks with stage '{start_stage}' were found in the previous task groups"
            )

        # Finally, we walk up the graph from our starting point and add any tasks found
        # as `existing_tasks`. These map task labels (eg: backtranslations-train-backwards-model-ru-en) to
        # task ids, and will be used instead of scheduling new tasks for any tasks with
        # an identical name.
        # As of taskgraph 13.0 `get_ancestors` returns taskids -> labels
        # `existing_tasks` needs the opposite
        parameters["existing_tasks"] = {
            label: ancestor_id for ancestor_id, label in get_ancestors(start_task_ids).items()
        }
    else:
        # Work on our own copy so the override below cannot leak into the
        # mapping owned by the caller; also tolerates a missing/None value.
        parameters["existing_tasks"] = dict(parameters.get("existing_tasks") or {})

    merged_existing_tasks = parameters["existing_tasks"]

    # Override the `existing_tasks` explicitly provided in the action's input
    # (an explicit null in the input is treated as "no overrides").
    existing_tasks = input.pop("existing_tasks", None) or {}

    # Find and log `overridden_existing_tasks`
    overridden_existing_tasks = {
        label: merged_existing_tasks[label]
        for label in existing_tasks
        if label in merged_existing_tasks
    }

    if overridden_existing_tasks:
        logger.info(
            f"Old values for `overridden_existing_tasks`: {json.dumps(overridden_existing_tasks, indent=2)}"
        )

    # Do the override!
    merged_existing_tasks.update(existing_tasks)

    # Log the new values for the `overridden_existing_tasks`
    if overridden_existing_tasks:
        new_values_for_overridden = {
            label: merged_existing_tasks[label] for label in overridden_existing_tasks
        }
        logger.info(
            f"New values for `overridden_existing_tasks`: {json.dumps(new_values_for_overridden, indent=2)}"
        )

    parameters["target_tasks_method"] = "train-target-tasks"
    parameters["optimize_target_tasks"] = True
    parameters["tasks_for"] = "action"
    # Whatever is left of the action input after the pops above is the training config.
    parameters["training_config"] = input

    validate_pretrained_models(parameters)

    parameters = Parameters(**parameters)
    taskgraph_decision({"root": graph_config.root_dir}, parameters=parameters)