in task-sdk/src/airflow/sdk/bases/operator.py [0:0]
def _apply_defaults(cls, func: T) -> T:
    """
    Wrap an operator ``__init__`` so unspecified arguments are filled from ``default_args``.

    On every call, the returned wrapper:

    * rejects positional arguments (operators must be built with keywords);
    * resolves ``dag`` from the kwargs or, failing that, ``DagContext.get_current()``,
      and ``task_group`` from the kwargs or ``TaskGroupContext.get_current(dag)``;
    * obtains the merged ``default_args`` and ``params`` from ``get_merged_defaults``
      and fills every parameter of ``func``'s signature that the caller did not pass;
    * raises ``TypeError`` naming every required keyword argument that is still
      missing, which is clearer than Python's own error when several levels of
      inheritance and argument defaults are involved;
    * records the final kwargs on the instance (needed for ``task.map``
      serialization) and, once per instance, sets up XComArg dependencies.

    :param func: the ``__init__`` function to wrap.
    :return: the wrapper, carrying ``__non_optional_args`` and ``__param_names``
        attributes computed from ``func``'s signature.
    """
    # Compute the signature once per decorated function rather than on every
    # operator instantiation. Each function decorated this way gets its own
    # ``sig_cache``.
    sig_cache = inspect.signature(func)
    non_variadic_params = {
        name: param
        for name, param in sig_cache.parameters.items()
        if name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
    }
    # Parameters without a default must be supplied, explicitly or via
    # ``default_args``. ``task_id`` is deliberately left out of this check.
    # NOTE(review): presumably ``task_id`` is validated elsewhere -- confirm.
    # ``is`` (not ``==``) is used against the ``Parameter.empty`` sentinel so a
    # default value with a custom ``__eq__`` (arrays, always-equal objects)
    # can neither raise nor be misclassified as "required".
    non_optional_args = {
        name
        for name, param in non_variadic_params.items()
        if param.default is param.empty and name != "task_id"
    }

    fixup_decorator_warning_stack(func)

    @wraps(func)
    def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
        from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext

        if args:
            raise TypeError("Use keyword arguments when initializing operators")

        # An explicit kwarg wins. Otherwise reuse the flag that an enclosing
        # (subclass) wrapper already stored on the instance before calling its
        # ``__init__``.
        instantiated_from_mapped = kwargs.pop(
            "_airflow_from_mapped",
            getattr(self, "_BaseOperator__from_mapped", False),
        )

        dag: DAG | None = kwargs.get("dag")
        if dag is None:
            dag = DagContext.get_current()
            if dag is not None:
                kwargs["dag"] = dag

        task_group: TaskGroup | None = kwargs.get("task_group")
        if dag and not task_group:
            task_group = TaskGroupContext.get_current(dag)
            if task_group is not None:
                kwargs["task_group"] = task_group

        default_args, merged_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=kwargs.pop("params", None),
            task_default_args=kwargs.pop("default_args", None),
        )

        # Only fill arguments that ``func`` accepts and the caller did not pass.
        for arg in sig_cache.parameters:
            if arg not in kwargs and arg in default_args:
                kwargs[arg] = default_args[arg]

        missing_args = non_optional_args.difference(kwargs)
        if len(missing_args) == 1:
            raise TypeError(f"missing keyword argument {missing_args.pop()!r}")
        if missing_args:
            display = ", ".join(repr(a) for a in sorted(missing_args))
            raise TypeError(f"missing keyword arguments {display}")

        if merged_params:
            kwargs["params"] = merged_params

        # Optional hook that may rewrite kwargs/default_args before ``__init__``
        # runs. The positional-args part of its return value is not used.
        hook = getattr(self, "_hook_apply_defaults", None)
        if hook:
            _, kwargs = hook(**kwargs, default_args=default_args)
            default_args = kwargs.pop("default_args", {})

        # ``object.__setattr__`` bypasses any ``__setattr__`` override on the operator.
        if not hasattr(self, "_BaseOperator__init_kwargs"):
            object.__setattr__(self, "_BaseOperator__init_kwargs", {})
        object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped)

        result = func(self, **kwargs, default_args=default_args)

        # Store the args passed to init -- we need them to support task.map serialization!
        self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore

        # Set upstream tasks defined by XComArgs passed to template fields of the
        # operator, but only ONCE rather than once per class in the hierarchy.
        # ``self.__init__`` resolves to the most-derived ``__init__``, so this
        # matches only the wrapper around that one.
        if not instantiated_from_mapped and func == self.__init__.__wrapped__:  # type: ignore[misc]
            self._set_xcomargs_dependencies()
            # Mark instance as instantiated so that future attr setting updates xcomarg-based deps.
            object.__setattr__(self, "_BaseOperator__instantiated", True)

        return result

    # NOTE: if this function is defined inside a class body, Python's private
    # name mangling applies to these double-underscore attribute names.
    apply_defaults.__non_optional_args = non_optional_args  # type: ignore
    apply_defaults.__param_names = set(non_variadic_params)  # type: ignore

    return cast("T", apply_defaults)