context.py (103 lines of code) (raw):
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Context vars for managing task groups."""
import asyncio
from collections.abc import Coroutine, Iterable
import contextvars
from typing import Any, TypeVar
from absl import logging
# Context variable holding the `CancellableContextTaskGroup` that is currently
# entered. It is set in `CancellableContextTaskGroup.__aenter__` and reset in
# `__aexit__`. No default is configured, so readers must supply their own
# fallback (`task_group` passes `None`).
_PROCESSOR_TASK_GROUP: contextvars.ContextVar['CancellableContextTaskGroup'] = (
contextvars.ContextVar('processor_task_group')
)
# Stream name constants. DEBUG_STREAM and STATUS_STREAM make up the default
# reserved set below; PROMPT_STREAM is not part of that default.
PROMPT_STREAM = 'prompt'
DEBUG_STREAM = 'debug'
STATUS_STREAM = 'status'
# Context variable holding the reserved substream names. `is_reserved_substream`
# treats each entry as a prefix. A `CancellableContextTaskGroup` constructed
# with `reserved_substreams` replaces this value while it is entered.
_PROCESSOR_RESERVED_SUBSTREAMS: contextvars.ContextVar[frozenset[str]] = (
contextvars.ContextVar(
'processor_reserved_substreams',
default=frozenset({DEBUG_STREAM, STATUS_STREAM}),
)
)
def raise_flattened_exception_group(exception: Exception):
e = exception
while isinstance(e, ExceptionGroup):
e = e.exceptions[0]
if e is exception:
raise exception
else:
raise e from exception
class CancellableContextTaskGroup(asyncio.TaskGroup):
  """TaskGroup that adds itself to a contextvar to be accessed by create_task.

  While the group is entered it is stored in `_PROCESSOR_TASK_GROUP`, and, when
  `reserved_substreams` was given, `_PROCESSOR_RESERVED_SUBSTREAMS` is replaced
  by those names. Both context variables are restored on exit.

  Includes a method for cancelling all tasks in the group.
  """

  def __init__(
      self, *args, reserved_substreams: Iterable[str] | None = None, **kwargs
  ):
    """Initializes the task group.

    Args:
      *args: Positional arguments forwarded to `asyncio.TaskGroup`.
      reserved_substreams: Optional substream names to install as the reserved
        set while this group is entered. `None` leaves the current reserved
        set untouched.
      **kwargs: Keyword arguments forwarded to `asyncio.TaskGroup`.
    """
    super().__init__(*args, **kwargs)
    # Tasks created through this group that have not completed yet.
    self._cancel_tasks: set[asyncio.Task] = set()
    self._reserved_substreams = reserved_substreams

  def create_task(self, *args, **kwargs) -> asyncio.Task:
    """Creates a task in this group and tracks it until it is done.

    Args:
      *args: Positional arguments forwarded to `TaskGroup.create_task`.
      **kwargs: Keyword arguments forwarded to `TaskGroup.create_task`.

    Returns:
      The created task.
    """
    t = super().create_task(*args, **kwargs)
    self._cancel_tasks.add(t)
    # Stop tracking the task as soon as it completes.
    t.add_done_callback(self._cancel_tasks.discard)
    return t

  async def __aenter__(self) -> 'CancellableContextTaskGroup':
    """Enters the task group and publishes it through the context variables."""
    # Enter the underlying TaskGroup first. If that raises, `__aexit__` will
    # not run, so the context variables (and any token stored by a previous
    # successful entry) must not have been modified yet.
    entered = await super().__aenter__()
    self._current_taskgroup_token = _PROCESSOR_TASK_GROUP.set(self)
    if self._reserved_substreams is not None:
      reserved = frozenset(self._reserved_substreams)
      self._reserved_substreams_token = _PROCESSOR_RESERVED_SUBSTREAMS.set(
          reserved
      )
    else:
      self._reserved_substreams_token = None
    return entered

  async def __aexit__(self, et, exc, tb):
    """Exits the task group and restores the context variables.

    Args:
      et: Type of the exception raised in the `async with` body, if any.
      exc: The exception instance raised in the body, if any.
      tb: The traceback of that exception, if any.

    Returns:
      Whatever `asyncio.TaskGroup.__aexit__` returns.

    Raises:
      The first leaf exception of an exception group raised by the underlying
      TaskGroup (see `raise_flattened_exception_group`).
    """
    try:
      return await super().__aexit__(et, exc, tb)
    except BaseExceptionGroup as e:
      raise_flattened_exception_group(e)
    finally:
      try:
        _PROCESSOR_TASK_GROUP.reset(self._current_taskgroup_token)
        if self._reserved_substreams_token is not None:
          _PROCESSOR_RESERVED_SUBSTREAMS.reset(self._reserved_substreams_token)
      except ValueError:
        # ValueError is raised when the Context self._current_taskgroup_token
        # was created in doesn't match the current Context.
        # This can happen when an asyncgenerator is garbage collected.
        # The task loop will call loop.call_soon(loop.create_task, agen.aclose).
        # aclose raises GeneratorExit in the asyncgenerator which is seen here
        # but from a different context.
        # The failed reset is deliberately ignored; only the GeneratorExit
        # case is reported, and only once.
        if et is GeneratorExit:
          logging.log_first_n(
              logging.WARNING,
              'GeneratorExit was seen in processors.context. This'
              ' indicates that the asyncgenerator that opened the context has'
              ' been closed from a different context. For example, it has been'
              ' garbage collected. This is usually means a task executing the'
              ' generator was also garbage collected. Consider turning on'
              ' asyncio debug mode to investigate further.',
              1,
          )

  def cancel(self):
    """Requests cancellation of every tracked task that is not yet done."""
    for task in self._cancel_tasks:
      task.cancel()
def context(
    reserved_substreams: Iterable[str] | None = None,
) -> CancellableContextTaskGroup:
  """Builds a new `CancellableContextTaskGroup`.

  Args:
    reserved_substreams: Optional substream names forwarded to the task group
      constructor. `None` is forwarded as-is.

  Returns:
    A new task group that has not been entered yet.
  """
  group = CancellableContextTaskGroup(reserved_substreams=reserved_substreams)
  return group
def task_group() -> CancellableContextTaskGroup | None:
  """Returns the task group stored in the current context, or `None`."""
  # `None` is passed as the fallback because the variable has no default.
  current_group = _PROCESSOR_TASK_GROUP.get(None)
  return current_group
def get_reserved_substreams() -> frozenset[str]:
  """Returns the reserved substream names active in the current context."""
  reserved = _PROCESSOR_RESERVED_SUBSTREAMS.get()
  return reserved
def is_reserved_substream(substream_name: str) -> bool:
  """Tells whether `substream_name` starts with a reserved substream name.

  Args:
    substream_name: The substream name to check.

  Returns:
    True if any name returned by `get_reserved_substreams` is a prefix of
    `substream_name`, False otherwise.
  """
  for reserved_prefix in get_reserved_substreams():
    if substream_name.startswith(reserved_prefix):
      return True
  return False
# If a task is created without a task group then a reference to it must be kept.
# `create_task` adds such tasks to this set and registers a done callback that
# removes them again once they complete.
_without_context_background_tasks: set[asyncio.Task] = set()
def create_task(*args, **kwargs) -> asyncio.Task:
  """Schedules a task on the task group of the current context.

  When no task group is present in the context, `asyncio.create_task` is used
  instead and the task is held in `_without_context_background_tasks` until it
  is done.

  Args:
    *args: Positional arguments forwarded to the task group's `create_task`,
      or to `asyncio.create_task` when there is no task group.
    **kwargs: Keyword arguments forwarded the same way.

  Returns:
    An asyncio task.
  """
  group = task_group()
  if group is not None:
    return group.create_task(*args, **kwargs)

  # No task group in this context: keep a reference until the task finishes.
  background_task = asyncio.create_task(*args, **kwargs)
  _without_context_background_tasks.add(background_task)
  background_task.add_done_callback(_without_context_background_tasks.discard)
  return background_task
# Result type of the coroutine wrapped by `context_cancel_coro`.
_T = TypeVar('_T')
async def context_cancel_coro(
    f: Coroutine[Any, Any, _T],
) -> _T:
  """Awaits `f` inside a fresh context, cancelling its tasks on cancellation.

  Args:
    f: The coroutine to await.

  Returns:
    The result of awaiting `f`.

  Raises:
    asyncio.CancelledError: Re-raised after every task tracked by the context
      has been asked to cancel.
  """
  async with context() as group:
    try:
      result = await f
    except asyncio.CancelledError:
      # Propagate the cancellation to every task created in this context.
      group.cancel()
      raise
    return result