aws_xray_sdk/core/async_context.py (56 lines of code) (raw):
import asyncio
import copy
from .context import Context as _Context
class AsyncContext(_Context):
    """
    Async Context for storing segments.

    Inherits nearly everything from the main Context class.
    Replaces threading.local with a task based local storage class,
    Also overrides clear_trace_entities
    """
    def __init__(self, *args, loop=None, use_task_factory=True, **kwargs):
        """
        :param args: positional arguments forwarded unchanged to the base
            Context constructor
        :param loop: event loop whose tasks carry the trace entities;
            when None, ``asyncio.get_event_loop()`` is used
        :param bool use_task_factory: when true, install ``task_factory``
            on the loop so newly created tasks receive the context of the
            task that created them
        :param kwargs: keyword arguments forwarded unchanged to the base
            Context constructor
        """
        super().__init__(*args, **kwargs)

        # Resolve the loop exactly once so that the task factory and the
        # task local storage are guaranteed to refer to the same loop.
        self._loop = asyncio.get_event_loop() if loop is None else loop

        if use_task_factory:
            self._loop.set_task_factory(task_factory)

        # Pass the resolved loop (never None) instead of the raw argument;
        # this avoids a second, redundant event loop lookup.
        self._local = TaskLocalStorage(loop=self._loop)

    def clear_trace_entities(self):
        """
        Clear all trace_entities stored in the task local context.
        """
        if self._local is not None:
            self._local.clear()
class TaskLocalStorage:
    """
    Minimal attribute store scoped to the currently running asyncio task.

    Values are kept in a ``context`` dictionary attached to the task
    object itself; only ``_loop`` lives on the storage instance.
    """
    def __init__(self, loop=None):
        """
        :param loop: event loop used to look up the current task; when
            None, ``asyncio.get_event_loop()`` is used
        """
        self._loop = asyncio.get_event_loop() if loop is None else loop

    def __setattr__(self, name, value):
        """
        Store ``value`` under ``name`` in the current task's context.

        ``_loop`` is the one real instance attribute. Outside of a task
        the assignment is silently dropped.
        """
        if name == '_loop':
            # Regular instance attribute, bypass the task context
            object.__setattr__(self, name, value)
            return

        running = asyncio.current_task(loop=self._loop)
        if running is None:
            # No task to attach the value to
            return

        # Lazily create the per-task dictionary on first write
        if not hasattr(running, 'context'):
            running.context = {}
        running.context[name] = value

    def __getattribute__(self, item):
        """
        Look ``item`` up in the current task's context.

        ``_loop`` and ``clear`` resolve on the instance as usual.
        Returns None outside of a task, and raises AttributeError when
        the task has no value stored under ``item``.
        """
        if item == '_loop' or item == 'clear':
            # Real members of this object, not task local values
            return object.__getattribute__(self, item)

        running = asyncio.current_task(loop=self._loop)
        if running is None:
            return None

        if not hasattr(running, 'context') or item not in running.context:
            raise AttributeError('Task context does not have attribute {0}'.format(item))
        return running.context[item]

    def clear(self):
        """
        Empty the context dictionary of the current task, if there is one.
        """
        running = asyncio.current_task(loop=self._loop)
        if running is None or not hasattr(running, 'context'):
            return
        running.context.clear()
def task_factory(loop, coro, **kwargs):
    """
    Task factory function

    Function closely mirrors the logic inside of
    asyncio.BaseEventLoop.create_task. Then if there is a current
    task and the current task has a context then share that context
    with the new task

    :param loop: event loop the new task is bound to
    :param coro: coroutine object to wrap in a task
    :param kwargs: extra keyword arguments (for example ``name`` or
        ``context``) forwarded unchanged to ``asyncio.Task``. Newer
        event loops pass such arguments to custom task factories, so
        rejecting them would make ``loop.create_task`` raise TypeError.
    :return: the newly created ``asyncio.Task``
    """
    task = asyncio.Task(coro, loop=loop, **kwargs)
    # Drop the frame of this factory from the debug traceback, as
    # create_task does for its own frame.
    if task._source_traceback:  # flake8: noqa
        del task._source_traceback[-1]  # flake8: noqa

    # Share context with new task if possible
    current_task = asyncio.current_task(loop=loop)
    if current_task is not None and hasattr(current_task, 'context'):
        if current_task.context.get('entities'):
            # NOTE: (enowell) Because the `AWSXRayRecorder`'s `Context` decides
            # the parent by looking at its `_local.entities`, we must copy the entities
            # for concurrent subsegments. Otherwise, the subsegments would be
            # modifying the same `entities` list and subsegments would take other
            # subsegments as parents instead of the original `segment`.
            #
            # See more: https://github.com/aws/aws-xray-sdk-python/blob/0f13101e4dba7b5c735371cb922f727b1d9f46d8/aws_xray_sdk/core/context.py#L90-L101
            new_context = copy.copy(current_task.context)
            new_context['entities'] = list(current_task.context['entities'])
        else:
            # Nothing to protect from concurrent mutation: share the
            # very same dictionary with the new task.
            new_context = current_task.context
        task.context = new_context
    return task