in airflow-core/src/airflow/jobs/scheduler_job_runner.py
def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
    """
    Find TIs that are ready for execution based on conditions.

    Conditions include:
    - pool limits
    - DAG max_active_tasks
    - executor state
    - priority
    - max active tis per DAG
    - max active tis per DAG run

    :param max_tis: Maximum number of TIs to queue in this loop.
    :return: list[airflow.models.TaskInstance]
    """
    from airflow.models.pool import Pool
    from airflow.utils.db import DBLocks

    executable_tis: list[TI] = []
    if session.get_bind().dialect.name == "postgresql":
        # Optimization: to avoid littering the DB log with "ERROR: canceling statement due to lock
        # timeout" errors, try to take out a transactional advisory lock (unlocks automatically on
        # COMMIT/ROLLBACK)
        lock_acquired = session.execute(
            text("SELECT pg_try_advisory_xact_lock(:id)").bindparams(
                id=DBLocks.SCHEDULER_CRITICAL_SECTION.value
            )
        ).scalar()
        if not lock_acquired:
            # Raise an error like the one that would happen with NOWAIT
            raise OperationalError(
                "Failed to acquire advisory lock", params=None, orig=RuntimeError("55P03")
            )
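    # (pg_try_advisory_xact_lock() returns true/false immediately instead of blocking, and the
    # lock is transaction-scoped, so nothing has to be released by hand. The "55P03" above
    # mirrors Postgres's lock_not_available SQLSTATE, i.e. what a SELECT ... FOR UPDATE NOWAIT
    # would have raised.)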
    # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section"
    # Throws an exception if lock cannot be obtained, rather than blocking
    pools = Pool.slots_stats(lock_rows=True, session=session)

    # If the pools are full, there is no point doing anything!
    # If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL
    pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())

    if pool_slots_free == 0:
        self.log.debug("All pools are full!")
        return []

    max_tis = min(max_tis, pool_slots_free)
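    # For example (hypothetical numbers): with max_tis=32 but only 10 open slots summed across
    # all pools, the critical-section query below is capped at 10 rows, since no more than 10
    # TIs could be queued this loop anyway.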
    starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0}

    # Counts of running/queued tasks, keyed by (dag_id, run_id), by (dag_id, task_id), and by
    # (dag_id, run_id, task_id).
    concurrency_map = ConcurrencyMap()
    concurrency_map.load(session=session)
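    # Sketch of what the loaded map holds (illustrative values only):
    #   concurrency_map.dag_run_active_tasks_map[("some_dag", "run_1")] == 3
    #   concurrency_map.task_concurrency_map[("some_dag", "extract")] == 2
    #   concurrency_map.task_dagrun_concurrency_map[("some_dag", "run_1", "extract")] == 1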
    # Number of tasks that cannot be scheduled because of no open slot in pool
    num_starving_tasks_total = 0

    # dag and task ids that can't be queued because of concurrency limits
    starved_dags: set[str] = set()
    starved_tasks: set[tuple[str, str]] = set()
    starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] = set()

    pool_num_starving_tasks: dict[str, int] = Counter()
    for loop_count in itertools.count(start=1):
        num_starved_pools = len(starved_pools)
        num_starved_dags = len(starved_dags)
        num_starved_tasks = len(starved_tasks)
        num_starved_tasks_task_dagrun_concurrency = len(starved_tasks_task_dagrun_concurrency)

        query = (
            select(TI)
            .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
            .join(TI.dag_run)
            .where(DR.state == DagRunState.RUNNING)
            .join(TI.dag_model)
            .where(~DM.is_paused)
            .where(TI.state == TaskInstanceState.SCHEDULED)
            .where(DM.bundle_name.is_not(None))
            .options(selectinload(TI.dag_model))
            .order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
        )
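        # Ordering: highest priority_weight first (hence the negation), then oldest
        # logical_date, then map_index, so older runs and lower-numbered mapped task
        # instances win ties.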
        if starved_pools:
            query = query.where(TI.pool.not_in(starved_pools))

        if starved_dags:
            query = query.where(TI.dag_id.not_in(starved_dags))

        if starved_tasks:
            query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks))

        if starved_tasks_task_dagrun_concurrency:
            query = query.where(
                tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
            )

        query = query.limit(max_tis)
        timer = Stats.timer("scheduler.critical_section_query_duration")
        timer.start()

        try:
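            # SKIP LOCKED makes concurrent schedulers pass over rows another scheduler has
            # already locked rather than block on them, which is what allows several schedulers
            # to run this critical section against the same DB.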
            query = with_row_locks(query, of=TI, session=session, skip_locked=True)
            task_instances_to_examine: list[TI] = session.scalars(query).all()

            timer.stop(send=True)
        except OperationalError as e:
            timer.stop(send=False)
            raise e
        # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
        # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))

        if not task_instances_to_examine:
            self.log.debug("No tasks to consider for execution.")
            break

        # Put one task instance on each line
        task_instance_str = "\n".join(f"\t{x!r}" for x in task_instances_to_examine)
        self.log.info("%s tasks up for execution:\n%s", len(task_instances_to_examine), task_instance_str)
        executor_slots_available: dict[ExecutorName, int] = {}
        # First get a mapping of executor names to slots they have available
        for executor in self.job.executors:
            if TYPE_CHECKING:
                # All executors should have a name if they are initted from the executor_loader.
                # But we need to check for None to make mypy happy.
                assert executor.name
            executor_slots_available[executor.name] = executor.slots_available
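        # (slots_available is, roughly, the executor's parallelism minus its currently running
        # and queued tasks; the exact accounting is up to each executor implementation.)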
        for task_instance in task_instances_to_examine:
            pool_name = task_instance.pool

            pool_stats = pools.get(pool_name)
            if not pool_stats:
                self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
                starved_pools.add(pool_name)
                continue

            # Make sure to emit metrics if pool has no starving tasks
            pool_num_starving_tasks.setdefault(pool_name, 0)

            pool_total = pool_stats["total"]
            open_slots = pool_stats["open"]

            if open_slots <= 0:
                self.log.info(
                    "Not scheduling since there are %s open slots in pool %s", open_slots, pool_name
                )
                # Can't schedule any more since there are no more open slots.
                pool_num_starving_tasks[pool_name] += 1
                num_starving_tasks_total += 1
                starved_pools.add(pool_name)
                continue
            if task_instance.pool_slots > pool_total:
                self.log.warning(
                    "Not executing %s. Requested pool slots (%s) are greater than "
                    "total pool slots: '%s' for pool: %s.",
                    task_instance,
                    task_instance.pool_slots,
                    pool_total,
                    pool_name,
                )

                pool_num_starving_tasks[pool_name] += 1
                num_starving_tasks_total += 1
                starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                continue

            if task_instance.pool_slots > open_slots:
                self.log.info(
                    "Not executing %s since it requires %s slots "
                    "but there are %s open slots in the pool %s.",
                    task_instance,
                    task_instance.pool_slots,
                    open_slots,
                    pool_name,
                )
                pool_num_starving_tasks[pool_name] += 1
                num_starving_tasks_total += 1
                starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                # Though we can execute tasks with lower priority if there's enough room
                continue
            # Check that the DAG's max_active_tasks limit hasn't been reached for this DAG run.
            dag_id = task_instance.dag_id
            dag_run_key = (dag_id, task_instance.run_id)
            current_active_tasks_per_dag_run = concurrency_map.dag_run_active_tasks_map[dag_run_key]
            dag_max_active_tasks = task_instance.dag_model.max_active_tasks
            self.log.info(
                "DAG %s has %s/%s running and queued tasks",
                dag_id,
                current_active_tasks_per_dag_run,
                dag_max_active_tasks,
            )
            if current_active_tasks_per_dag_run >= dag_max_active_tasks:
                self.log.info(
                    "Not executing %s since the number of tasks running or queued "
                    "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
                    task_instance,
                    dag_id,
                    dag_max_active_tasks,
                )
                starved_dags.add(dag_id)
                continue
            if task_instance.dag_model.has_task_concurrency_limits:
                # Many dags don't have a task_concurrency, so wherever we can avoid loading the
                # full serialized DAG, the better.
                serialized_dag = self.scheduler_dag_bag.get_dag(
                    dag_run=task_instance.dag_run, session=session
                )
                # If the dag is missing, fail the task and continue to the next task.
                if not serialized_dag:
                    self.log.error(
                        "DAG '%s' for task instance %s not found in serialized_dag table",
                        dag_id,
                        task_instance,
                    )
                    session.execute(
                        update(TI)
                        .where(TI.dag_id == dag_id, TI.state == TaskInstanceState.SCHEDULED)
                        .values(state=TaskInstanceState.FAILED)
                        .execution_options(synchronize_session="fetch")
                    )
                    continue
                task_concurrency_limit: int | None = None
                if serialized_dag.has_task(task_instance.task_id):
                    task_concurrency_limit = serialized_dag.get_task(
                        task_instance.task_id
                    ).max_active_tis_per_dag

                if task_concurrency_limit is not None:
                    current_task_concurrency = concurrency_map.task_concurrency_map[
                        (task_instance.dag_id, task_instance.task_id)
                    ]

                    if current_task_concurrency >= task_concurrency_limit:
                        self.log.info(
                            "Not executing %s since the task concurrency for this task has been reached.",
                            task_instance,
                        )
                        starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                        continue
                task_dagrun_concurrency_limit: int | None = None
                if serialized_dag.has_task(task_instance.task_id):
                    task_dagrun_concurrency_limit = serialized_dag.get_task(
                        task_instance.task_id
                    ).max_active_tis_per_dagrun

                if task_dagrun_concurrency_limit is not None:
                    current_task_dagrun_concurrency = concurrency_map.task_dagrun_concurrency_map[
                        (task_instance.dag_id, task_instance.run_id, task_instance.task_id)
                    ]

                    if current_task_dagrun_concurrency >= task_dagrun_concurrency_limit:
                        self.log.info(
                            "Not executing %s since the task concurrency per DAG run for"
                            " this task has been reached.",
                            task_instance,
                        )
                        starved_tasks_task_dagrun_concurrency.add(
                            (task_instance.dag_id, task_instance.run_id, task_instance.task_id)
                        )
                        continue
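            # Illustration of the two limits above: a task with max_active_tis_per_dag=4 and
            # max_active_tis_per_dagrun=1 may have four instances active across all runs of the
            # DAG, but never more than one in any individual run.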
            if executor_obj := self._try_to_load_executor(task_instance.executor):
                if TYPE_CHECKING:
                    # All executors should have a name if they are initted from the executor_loader.
                    # But we need to check for None to make mypy happy.
                    assert executor_obj.name
                if executor_slots_available[executor_obj.name] <= 0:
                    self.log.debug(
                        "Not scheduling %s since its executor %s does not currently have any more "
                        "available slots",
                        task_instance,
                        executor_obj.name,
                    )
                    starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                    continue
                executor_slots_available[executor_obj.name] -= 1
            else:
                # This is a defensive guard for if we happen to have a task whose executor cannot be
                # found. The check in the dag parser should make this not realistically possible, but
                # the loader can fail if some direct DB modification has happened or another as yet
                # unknown edge case. _try_to_load_executor will log an error message explaining the
                # executor cannot be found.
                starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                continue
            executable_tis.append(task_instance)
            open_slots -= task_instance.pool_slots
            concurrency_map.dag_run_active_tasks_map[dag_run_key] += 1
            concurrency_map.task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
            concurrency_map.task_dagrun_concurrency_map[
                (task_instance.dag_id, task_instance.run_id, task_instance.task_id)
            ] += 1

            pool_stats["open"] = open_slots
        is_done = executable_tis or len(task_instances_to_examine) < max_tis
        # Check this to avoid accidental infinite loops
        found_new_filters = (
            len(starved_pools) > num_starved_pools
            or len(starved_dags) > num_starved_dags
            or len(starved_tasks) > num_starved_tasks
            or len(starved_tasks_task_dagrun_concurrency) > num_starved_tasks_task_dagrun_concurrency
        )

        if is_done or not found_new_filters:
            break

        self.log.info(
            "Found no task instances to queue on query iteration %s "
            "but there could be more candidate task instances to check.",
            loop_count,
        )
    for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
        Stats.gauge(f"pool.starving_tasks.{pool_name}", num_starving_tasks)
        # Same metric with tagging
        Stats.gauge("pool.starving_tasks", num_starving_tasks, tags={"pool_name": pool_name})

    Stats.gauge("scheduler.tasks.starving", num_starving_tasks_total)
    Stats.gauge("scheduler.tasks.executable", len(executable_tis))
    if executable_tis:
        task_instance_str = "\n".join(f"\t{x!r}" for x in executable_tis)
        self.log.info("Setting the following tasks to queued state:\n%s", task_instance_str)

        # set TIs to queued state
        filter_for_tis = TI.filter_for_tis(executable_tis)
        session.execute(
            update(TI)
            .where(filter_for_tis)
            .values(
                # TODO[ha]: should we use func.now()? How does that work with DB timezone
                # on mysql when it's not UTC?
                state=TaskInstanceState.QUEUED,
                queued_dttm=timezone.utcnow(),
                queued_by_job_id=self.job.id,
            )
            .execution_options(synchronize_session=False)
        )
        for ti in executable_tis:
            ti.emit_state_change_metric(TaskInstanceState.QUEUED)

    for ti in executable_tis:
        make_transient(ti)
    return executable_tis