from typing import TYPE_CHECKING

from airflow.models.taskinstance import clear_task_instances
from airflow.utils.context import Context
from airflow.utils.db import provide_session
from sqlalchemy.orm.session import Session

if TYPE_CHECKING:
    from airflow.models.dagrun import DagRun


@provide_session
def retry_tasks_callback(context: Context, session: Session | None = None) -> None:
    """
    Clear the task instances named by the `retry_tasks` task param.

    Intended to be used as an `on_retry_callback` so that other tasks in the same DAG run are
    also retried when a task fails.

    The `retry_tasks` param may be a single task id (a string) or a collection of task ids.
    When it is missing, `None` or empty, the callback returns without touching the database.

    :param context: Task context; `context["params"]` and `context["dag_run"]` are read.
    :param session: Database session used for the lookup and the clear; expected to be
        injected by `provide_session` when the caller does not pass one.
    """
    retry_task_ids: str | list[str] | None = context["params"].get("retry_tasks")
    if not retry_task_ids:
        # Nothing to retry: bail out before querying the task instances. This also covers an
        # explicit `None` param value, which would otherwise fail on the `in` test below.
        return
    if isinstance(retry_task_ids, str):
        # A bare string would make `in` do substring matching; wrap it to compare whole ids.
        retry_task_ids = [retry_task_ids]

    dag_run: DagRun = context["dag_run"]
    retry_task_instances = [
        task_instance
        for task_instance in dag_run.get_task_instances(session=session)
        if task_instance.task_id in retry_task_ids
    ]
    # Only call into the clearing machinery when at least one task id actually matched.
    if retry_task_instances:
        clear_task_instances(retry_task_instances, session=session)
