in scripts/ci/pre_commit/check_template_context_variable_in_sync.py [0:0]
def _iter_template_context_keys_from_original_return(
    *, source_path: "pathlib.Path | None" = None
) -> typing.Iterator[str]:
    """Yield the template-context keys declared in ``RuntimeTaskInstance.get_template_context``.

    The source file is parsed with :mod:`ast` (never imported). Keys are read
    statically from two dictionary literals inside ``get_template_context``:

    * the annotated ``context: ... = {...}`` assignment in the method body, and
    * any annotated ``context_from_server: ... = {...}`` assignment placed
      directly inside an ``if from_server:`` block of the method body.

    Keys are yielded in source order, with the ``context`` keys first.

    :param source_path: File to parse. Defaults to the module-level
        ``TASKRUNNER_PY``. It can be overridden so the parser can be exercised
        against other sources, for example in tests.
    :raises ValueError: if the ``RuntimeTaskInstance`` class, its
        ``get_template_context`` method or the ``context`` assignment cannot be
        found; if ``context`` is not assigned a dict literal; or if a dict key
        is not a plain string literal (including ``**`` unpacking).
    """
    path = TASKRUNNER_PY if source_path is None else source_path
    ti_mod = ast.parse(path.read_text("utf-8"), str(path))

    # Every lookup below passes ``None`` as the default to ``next``. A bare
    # StopIteration escaping a generator is turned into an opaque
    # "generator raised StopIteration" RuntimeError (PEP 479). Raising a
    # ValueError that names the missing node tells the user what changed.

    # Locate the RuntimeTaskInstance class definition (module top level only).
    runtime_task_instance_class = next(
        (
            node
            for node in ast.iter_child_nodes(ti_mod)
            if isinstance(node, ast.ClassDef) and node.name == "RuntimeTaskInstance"
        ),
        None,
    )
    if runtime_task_instance_class is None:
        raise ValueError(f"Could not find class 'RuntimeTaskInstance' in {path}")

    # Locate the get_template_context method directly inside RuntimeTaskInstance.
    fn_get_template_context = next(
        (
            node
            for node in ast.iter_child_nodes(runtime_task_instance_class)
            if isinstance(node, ast.FunctionDef) and node.name == "get_template_context"
        ),
        None,
    )
    if fn_get_template_context is None:
        raise ValueError(f"Could not find method 'RuntimeTaskInstance.get_template_context' in {path}")

    def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]:
        """Yield each key of a dict literal, requiring every key to be a string constant."""
        for key in node.keys:
            if key is None:
                # ``{**other}`` is stored by ``ast`` as a ``None`` key. Its keys
                # cannot be known statically, so refuse rather than miss them.
                raise ValueError("Dictionary unpacking (**) is not supported; list keys explicitly")
            if not isinstance(key, ast.Constant) or not isinstance(key.value, str):
                raise ValueError("Key in dictionary is not a string literal")
            yield key.value

    # Extract keys from the main annotated ``context`` assignment in the method body.
    context_assignment = next(
        (
            stmt
            for stmt in fn_get_template_context.body
            if isinstance(stmt, ast.AnnAssign)
            and isinstance(stmt.target, ast.Name)
            and stmt.target.id == "context"
        ),
        None,
    )
    if context_assignment is None:
        raise ValueError(f"Could not find an annotated 'context' assignment in get_template_context in {path}")
    if not isinstance(context_assignment.value, ast.Dict):
        raise ValueError("'context' is not assigned a dictionary literal")
    yield from extract_keys_from_dict(context_assignment.value)

    # Handle keys added conditionally inside ``if from_server:`` blocks.
    for stmt in fn_get_template_context.body:
        if isinstance(stmt, ast.If) and isinstance(stmt.test, ast.Name) and stmt.test.id == "from_server":
            for sub_stmt in stmt.body:
                # Only an annotated ``context_from_server`` assigned a dict literal contributes keys.
                if (
                    isinstance(sub_stmt, ast.AnnAssign)
                    and isinstance(sub_stmt.target, ast.Name)
                    and isinstance(sub_stmt.value, ast.Dict)
                    and sub_stmt.target.id == "context_from_server"
                ):
                    yield from extract_keys_from_dict(sub_stmt.value)