def _iter_template_context_keys_from_original_return()

in scripts/ci/pre_commit/check_template_context_variable_in_sync.py [0:0]


def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]:
    ti_mod = ast.parse(TASKRUNNER_PY.read_text("utf-8"), str(TASKRUNNER_PY))

    # Locate the RuntimeTaskInstance class definition
    runtime_task_instance_class = next(
        node
        for node in ast.iter_child_nodes(ti_mod)
        if isinstance(node, ast.ClassDef) and node.name == "RuntimeTaskInstance"
    )

    # Locate the get_template_context method in RuntimeTaskInstance
    fn_get_template_context = next(
        node
        for node in ast.iter_child_nodes(runtime_task_instance_class)
        if isinstance(node, ast.FunctionDef) and node.name == "get_template_context"
    )

    # Helper function to extract keys from a dictionary node
    def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]:
        for key in node.keys:
            if not isinstance(key, ast.Constant) or not isinstance(key.value, str):
                raise ValueError("Key in dictionary is not a string literal")
            yield key.value

    # Extract keys from the main `context` dictionary assignment
    context_assignment = next(
        stmt
        for stmt in fn_get_template_context.body
        if isinstance(stmt, ast.AnnAssign)
        and isinstance(stmt.target, ast.Name)
        and stmt.target.id == "context"
    )

    if not isinstance(context_assignment.value, ast.Dict):
        raise ValueError("'context' is not assigned a dictionary literal")
    yield from extract_keys_from_dict(context_assignment.value)

    # Handle keys added conditionally in `if from_server`
    for stmt in fn_get_template_context.body:
        if isinstance(stmt, ast.If) and isinstance(stmt.test, ast.Name) and stmt.test.id == "from_server":
            for sub_stmt in stmt.body:
                # Get keys from `context_from_server` assignment
                if (
                    isinstance(sub_stmt, ast.AnnAssign)
                    and isinstance(sub_stmt.target, ast.Name)
                    and isinstance(sub_stmt.value, ast.Dict)
                    and sub_stmt.target.id == "context_from_server"
                ):
                    yield from extract_keys_from_dict(sub_stmt.value)