def expand_mapped_task()

in airflow-core/src/airflow/models/taskmap.py [0:0]


    def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
        """
        Create the mapped task instances for mapped task.

        The unmapped task instance (``map_index == -1``), if it still exists
        and is unfinished, is either marked ``UPSTREAM_FAILED`` / ``SKIPPED``,
        converted in place into map index 0, or deleted when an index-0 task
        instance already exists. Task instances are then created for every
        index above the current maximum and below the expanded length, and any
        existing task instance whose index is at or above that length is set
        to ``REMOVED``.

        :param task: The operator to expand; must be an SDK ``BaseOperator`` or
            ``MappedOperator`` instance.
        :param run_id: The DAG run whose task instances are expanded.
        :param session: Database session used for all queries and writes.
        :raise NotMapped: If this task does not need expansion.
        :raise RuntimeError: If ``task`` is not a recognized operator type.
        :return: The newly created mapped task instances (if any) in ascending
            order by map index, and the maximum map index value (``-1`` when
            the expanded length is zero or could not be determined).
        """
        from airflow.models.baseoperator import BaseOperator as DBBaseOperator
        from airflow.models.expandinput import NotFullyPopulated
        from airflow.models.taskinstance import TaskInstance
        from airflow.sdk.bases.operator import BaseOperator
        from airflow.sdk.definitions.mappedoperator import MappedOperator
        from airflow.settings import task_instance_mutation_hook

        if not isinstance(task, (BaseOperator, MappedOperator)):
            raise RuntimeError(
                f"cannot expand unrecognized operator type {type(task).__module__}.{type(task).__name__}"
            )

        try:
            total_length: int | None = DBBaseOperator.get_mapped_ti_count(task, run_id, session=session)
        except NotFullyPopulated as e:
            # Missing upstream values are only reported as an error when the
            # DAG is not partial; either way the length is treated as unknown.
            if not task.dag or not task.dag.partial:
                task.log.error(
                    "Cannot expand %r for run %s; missing upstream values: %s",
                    task,
                    run_id,
                    sorted(e.missing),
                )
            total_length = None

        # State copied onto newly created task instances; stays None unless the
        # unmapped task instance is expanded below.
        state: TaskInstanceState | None = None
        unmapped_ti: TaskInstance | None = session.scalars(
            select(TaskInstance).where(
                TaskInstance.dag_id == task.dag_id,
                TaskInstance.task_id == task.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
        ).one_or_none()

        all_expanded_tis: list[TaskInstance] = []

        if unmapped_ti:
            if TYPE_CHECKING:
                assert task.dag is None or isinstance(task.dag, SchedulerDAG)

            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the DAG is partial, it's likely that the upstream tasks
                # are not done yet, so the task can't fail yet.
                if not task.dag or not task.dag.partial:
                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                task.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                zero_index_ti_exists = exists_query(
                    TaskInstance.dag_id == task.dag_id,
                    TaskInstance.task_id == task.task_id,
                    TaskInstance.run_id == run_id,
                    TaskInstance.map_index == 0,
                    session=session,
                )
                if not zero_index_ti_exists:
                    # Otherwise convert this into the first mapped index, and create
                    # TaskInstance for other indexes.
                    unmapped_ti.map_index = 0
                    task.log.debug("Updated in place to become %s", unmapped_ti)
                    all_expanded_tis.append(unmapped_ti)
                    # execute hook for task instance map index 0
                    task_instance_mutation_hook(unmapped_ti)
                    session.flush()
                else:
                    task.log.debug("Deleting the original task instance: %s", unmapped_ti)
                    session.delete(unmapped_ti)
                # New task instances inherit the state of the unmapped one.
                state = unmapped_ti.state

        if total_length is None or total_length < 1:
            # Nothing to fixup.
            indexes_to_map: Iterable[int] = ()
        else:
            # Only create "missing" ones.
            current_max_mapping = session.scalar(
                select(func.max(TaskInstance.map_index)).where(
                    TaskInstance.dag_id == task.dag_id,
                    TaskInstance.task_id == task.task_id,
                    TaskInstance.run_id == run_id,
                )
            )
            # MAX() over zero rows yields None; in that case there is nothing
            # to skip and expansion starts at index 0.
            first_missing_index = 0 if current_max_mapping is None else current_max_mapping + 1
            indexes_to_map = range(first_missing_index, total_length)

        # New task instances reuse the unmapped task instance's DAG version when
        # it is available, otherwise the latest known version (or None).
        if unmapped_ti:
            dag_version_id = unmapped_ti.dag_version_id
        elif dag_version := DagVersion.get_latest_version(task.dag_id, session=session):
            dag_version_id = dag_version.id
        else:
            dag_version_id = None

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(
                task,
                run_id=run_id,
                map_index=index,
                state=state,
                dag_version_id=dag_version_id,
            )
            task.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(task)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- these two are almost treated identically,
        # except the unmapped ti (if exists) is marked to different states.
        total_expanded_ti_count = total_length or 0

        # Any (old) task instances with inapplicable indexes (>= the total
        # number we need) are set to "REMOVED".
        query = select(TaskInstance).where(
            TaskInstance.dag_id == task.dag_id,
            TaskInstance.task_id == task.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        )
        query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True)
        to_update = session.scalars(query)
        for ti in to_update:
            ti.state = TaskInstanceState.REMOVED
        session.flush()
        # The highest valid map index is one less than the expanded count.
        return all_expanded_tis, total_expanded_ti_count - 1