in task-sdk/src/airflow/sdk/execution_time/supervisor.py [0:0]
def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger) -> None:
    """
    Dispatch one message received from the task runner.

    The message type selects the action:

    * State-reporting messages (``TaskState``, ``SucceedTask``, ``RetryTask``,
      ``DeferTask``, ``RescheduleTask``) record the terminal state on this
      supervisor and, for all but ``TaskState``, forward the change through
      ``self.client``.
    * Query messages (``GetConnection``, ``GetVariable``, ``GetXCom``, ...)
      call ``self.client`` and send the result back with ``self.send_msg``.
    * Write-only messages (``SetXCom``, ``DeleteXCom``, ``PutVariable``,
      ``SetRenderedFields``, ``SkipDownstreamTasks``) call ``self.client``
      and produce no reply.

    An unrecognised message is logged as an error and nothing is sent back.

    :param msg: The decoded request from the task runner.
    :param log: Logger used for the debug trace and for unhandled requests.
    """
    log.debug("Received message from task runner", msg=msg)
    resp: BaseModel | None = None
    # Extra keyword arguments forwarded to ``send_msg`` together with ``resp``
    # (presumably pydantic dump options -- confirm against ``send_msg``).
    dump_opts: dict[str, bool] = {}

    if isinstance(msg, TaskState):
        # Only recorded locally; no API call is made in this branch.
        self._terminal_state = msg.state
        self._task_end_time_monotonic = time.monotonic()
        self._rendered_map_index = msg.rendered_map_index
    elif isinstance(msg, SucceedTask):
        self._terminal_state = msg.state
        self._task_end_time_monotonic = time.monotonic()
        self._rendered_map_index = msg.rendered_map_index
        self.client.task_instances.succeed(
            id=self.id,
            when=msg.end_date,
            task_outlets=msg.task_outlets,
            outlet_events=msg.outlet_events,
            rendered_map_index=self._rendered_map_index,
        )
    elif isinstance(msg, RetryTask):
        self._terminal_state = msg.state
        self._task_end_time_monotonic = time.monotonic()
        self._rendered_map_index = msg.rendered_map_index
        self.client.task_instances.retry(
            id=self.id,
            end_date=msg.end_date,
            rendered_map_index=self._rendered_map_index,
        )
    elif isinstance(msg, GetConnection):
        conn = self.client.connections.get(msg.conn_id)
        if isinstance(conn, ConnectionResponse):
            # Register sensitive fields with the secrets masker before the
            # connection is handed back to the task process.
            if conn.password:
                mask_secret(conn.password)
            if conn.extra:
                mask_secret(conn.extra)
            resp = ConnectionResult.from_conn_response(conn)
            dump_opts = {"exclude_unset": True, "by_alias": True}
        else:
            # Anything that is not a ConnectionResponse is relayed unchanged.
            resp = conn
    elif isinstance(msg, GetVariable):
        var = self.client.variables.get(msg.key)
        if isinstance(var, VariableResponse):
            if var.value:
                mask_secret(var.value)
            resp = VariableResult.from_variable_response(var)
            dump_opts = {"exclude_unset": True}
        else:
            resp = var
    elif isinstance(msg, GetXCom):
        xcom = self.client.xcoms.get(
            msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates
        )
        resp = XComResult.from_xcom_response(xcom)
    elif isinstance(msg, GetXComCount):
        # Named ``xcom_count`` rather than ``len`` so the builtin is not
        # shadowed for the remainder of this method.
        xcom_count = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key)
        resp = XComCountResponse(len=xcom_count)
    elif isinstance(msg, GetXComSequenceItem):
        xcom = self.client.xcoms.get_sequence_item(
            msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
        )
        if isinstance(xcom, XComResponse):
            resp = XComResult.from_xcom_response(xcom)
        else:
            resp = xcom
    elif isinstance(msg, DeferTask):
        self._terminal_state = TaskInstanceState.DEFERRED
        self._rendered_map_index = msg.rendered_map_index
        self.client.task_instances.defer(self.id, msg)
    elif isinstance(msg, RescheduleTask):
        self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
        self.client.task_instances.reschedule(self.id, msg)
    elif isinstance(msg, SkipDownstreamTasks):
        self.client.task_instances.skip_downstream_tasks(self.id, msg)
    elif isinstance(msg, SetXCom):
        self.client.xcoms.set(
            msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length
        )
    elif isinstance(msg, DeleteXCom):
        self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
    elif isinstance(msg, PutVariable):
        self.client.variables.set(msg.key, msg.value, msg.description)
    elif isinstance(msg, SetRenderedFields):
        self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
    elif isinstance(msg, GetAssetByName):
        asset_resp = self.client.assets.get(name=msg.name)
        if isinstance(asset_resp, AssetResponse):
            resp = AssetResult.from_asset_response(asset_resp)
            dump_opts = {"exclude_unset": True}
        else:
            resp = asset_resp
    elif isinstance(msg, GetAssetByUri):
        asset_resp = self.client.assets.get(uri=msg.uri)
        if isinstance(asset_resp, AssetResponse):
            resp = AssetResult.from_asset_response(asset_resp)
            dump_opts = {"exclude_unset": True}
        else:
            resp = asset_resp
    elif isinstance(msg, GetAssetEventByAsset):
        asset_event_resp = self.client.asset_events.get(uri=msg.uri, name=msg.name)
        resp = AssetEventsResult.from_asset_events_response(asset_event_resp)
        dump_opts = {"exclude_unset": True}
    elif isinstance(msg, GetAssetEventByAssetAlias):
        asset_event_resp = self.client.asset_events.get(alias_name=msg.alias_name)
        resp = AssetEventsResult.from_asset_events_response(asset_event_resp)
        dump_opts = {"exclude_unset": True}
    elif isinstance(msg, GetPrevSuccessfulDagRun):
        dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)
        resp = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
        dump_opts = {"exclude_unset": True}
    elif isinstance(msg, TriggerDagRun):
        resp = self.client.dag_runs.trigger(
            msg.dag_id,
            msg.run_id,
            msg.conf,
            msg.logical_date,
            msg.reset_dag_run,
        )
    elif isinstance(msg, GetDagRunState):
        dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
        resp = DagRunStateResult.from_api_response(dr_resp)
    elif isinstance(msg, GetTaskRescheduleStartDate):
        resp = self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number)
    elif isinstance(msg, GetTICount):
        resp = self.client.task_instances.get_count(
            dag_id=msg.dag_id,
            map_index=msg.map_index,
            task_ids=msg.task_ids,
            task_group_id=msg.task_group_id,
            logical_dates=msg.logical_dates,
            run_ids=msg.run_ids,
            states=msg.states,
        )
    elif isinstance(msg, GetTaskStates):
        task_states_map = self.client.task_instances.get_task_states(
            dag_id=msg.dag_id,
            map_index=msg.map_index,
            task_ids=msg.task_ids,
            task_group_id=msg.task_group_id,
            logical_dates=msg.logical_dates,
            run_ids=msg.run_ids,
        )
        if isinstance(task_states_map, TaskStatesResponse):
            resp = TaskStatesResult.from_api_response(task_states_map)
        else:
            resp = task_states_map
    elif isinstance(msg, GetDRCount):
        resp = self.client.dag_runs.get_count(
            dag_id=msg.dag_id,
            logical_dates=msg.logical_dates,
            run_ids=msg.run_ids,
            states=msg.states,
        )
    elif isinstance(msg, DeleteVariable):
        resp = self.client.variables.delete(msg.key)
    else:
        # Unknown message type: log it and send nothing back.
        log.error("Unhandled request", msg=msg)
        return

    # Branches that only perform a side effect leave ``resp`` as None and so
    # produce no reply.
    if resp:
        self.send_msg(resp, **dump_opts)