in airflow-core/src/airflow/models/taskmap.py [0:0]
def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
    """
    Create the mapped task instances for mapped task.

    The unmapped placeholder task instance (``map_index == -1``) is looked up
    first; it is only considered if its state is unfinished or unset. When
    found, it is resolved according to the number of values to expand:

    * length unknown (upstream values not fully populated): marked
      ``UPSTREAM_FAILED``, unless the task's DAG is partial;
    * length below one: marked ``SKIPPED``;
    * otherwise: converted in place into map index 0, or deleted if a task
      instance with map index 0 already exists.

    Task instances are then created for every index above the current maximum
    map index, and any existing task instance whose map index is not below the
    expanded total is marked ``REMOVED``. The session is flushed before
    returning.

    :param task: The operator to expand; must be an SDK ``BaseOperator`` or
        ``MappedOperator`` instance.
    :param run_id: The DAG run whose task instances are expanded.
    :param session: Database session used for all queries and writes.
    :raise RuntimeError: If *task* is not a recognized operator type.
    :raise NotMapped: If this task does not need expansion (not raised
        directly here; presumably propagated from ``get_mapped_ti_count``).
    :return: The newly created mapped task instances (if any) in ascending
        order by map index, and the maximum map index value (``-1`` when
        nothing is expanded).
    """
    from airflow.models.baseoperator import BaseOperator as DBBaseOperator
    from airflow.models.expandinput import NotFullyPopulated
    from airflow.models.taskinstance import TaskInstance
    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.definitions.mappedoperator import MappedOperator
    from airflow.settings import task_instance_mutation_hook

    if not isinstance(task, (BaseOperator, MappedOperator)):
        raise RuntimeError(
            f"cannot expand unrecognized operator type {type(task).__module__}.{type(task).__name__}"
        )

    try:
        total_length: int | None = DBBaseOperator.get_mapped_ti_count(task, run_id, session=session)
    except NotFullyPopulated as e:
        # Only report an error when the DAG is not partial; for a partial DAG
        # the missing upstream values are not treated as a failure below either.
        if not task.dag or not task.dag.partial:
            task.log.error(
                "Cannot expand %r for run %s; missing upstream values: %s",
                task,
                run_id,
                sorted(e.missing),
            )
        total_length = None

    # State given to newly created task instances; inherited from the unmapped
    # task instance when one is expanded.
    state: TaskInstanceState | None = None
    unmapped_ti: TaskInstance | None = session.scalars(
        select(TaskInstance).where(
            TaskInstance.dag_id == task.dag_id,
            TaskInstance.task_id == task.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == -1,
            or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
        )
    ).one_or_none()

    all_expanded_tis: list[TaskInstance] = []

    if unmapped_ti:
        if TYPE_CHECKING:
            assert task.dag is None or isinstance(task.dag, SchedulerDAG)

        # The unmapped task instance still exists and is unfinished, i.e. we
        # haven't tried to run it before.
        if total_length is None:
            # If the DAG is partial, it's likely that the upstream tasks
            # are not done yet, so the task can't fail yet.
            if not task.dag or not task.dag.partial:
                unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
        elif total_length < 1:
            # If the upstream maps this to a zero-length value, simply mark
            # the unmapped task instance as SKIPPED (if needed).
            task.log.info(
                "Marking %s as SKIPPED since the map has %d values to expand",
                unmapped_ti,
                total_length,
            )
            unmapped_ti.state = TaskInstanceState.SKIPPED
        else:
            zero_index_ti_exists = exists_query(
                TaskInstance.dag_id == task.dag_id,
                TaskInstance.task_id == task.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == 0,
                session=session,
            )
            if not zero_index_ti_exists:
                # Otherwise convert this into the first mapped index, and create
                # TaskInstance for other indexes.
                unmapped_ti.map_index = 0
                task.log.debug("Updated in place to become %s", unmapped_ti)
                all_expanded_tis.append(unmapped_ti)
                # execute hook for task instance map index 0
                task_instance_mutation_hook(unmapped_ti)
                # Flush so the max-map-index query below sees the new index 0.
                session.flush()
            else:
                # Index 0 already exists, so the unmapped placeholder is redundant.
                task.log.debug("Deleting the original task instance: %s", unmapped_ti)
                session.delete(unmapped_ti)
            state = unmapped_ti.state

    if total_length is None or total_length < 1:
        # Nothing to fixup.
        indexes_to_map: Iterable[int] = ()
    else:
        # Only create "missing" ones.
        current_max_mapping = session.scalar(
            select(func.max(TaskInstance.map_index)).where(
                TaskInstance.dag_id == task.dag_id,
                TaskInstance.task_id == task.task_id,
                TaskInstance.run_id == run_id,
            )
        )
        if current_max_mapping is None:
            # MAX() over zero rows is NULL: no task instance exists for this
            # task in this run, so start expanding from index 0.
            current_max_mapping = -1
        indexes_to_map = range(current_max_mapping + 1, total_length)

    # New task instances share the unmapped task instance's DAG version when
    # there is one, and fall back to the latest known DAG version otherwise.
    if unmapped_ti:
        dag_version_id = unmapped_ti.dag_version_id
    elif dag_version := DagVersion.get_latest_version(task.dag_id, session=session):
        dag_version_id = dag_version.id
    else:
        dag_version_id = None

    for index in indexes_to_map:
        # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
        ti = TaskInstance(
            task,
            run_id=run_id,
            map_index=index,
            state=state,
            dag_version_id=dag_version_id,
        )
        task.log.debug("Expanding TIs upserted %s", ti)
        task_instance_mutation_hook(ti)
        ti = session.merge(ti)
        ti.refresh_from_task(task)  # session.merge() loses task information.
        all_expanded_tis.append(ti)

    # Coerce the None case to 0 -- these two are almost treated identically,
    # except the unmapped ti (if exists) is marked to different states.
    total_expanded_ti_count = total_length or 0

    # Any (old) task instances with inapplicable indexes (>= the total
    # number we need) are set to "REMOVED".
    query = select(TaskInstance).where(
        TaskInstance.dag_id == task.dag_id,
        TaskInstance.task_id == task.task_id,
        TaskInstance.run_id == run_id,
        TaskInstance.map_index >= total_expanded_ti_count,
    )
    query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True)
    to_update = session.scalars(query)
    for ti in to_update:
        ti.state = TaskInstanceState.REMOVED
    session.flush()
    return all_expanded_tis, total_expanded_ti_count - 1