airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/scripter.py (88 lines of code) (raw):

import ast
import functools
import inspect
import sys
import textwrap

# Name of the decorator that drives script generation. It shows up in the
# decorated function's decorator list (and therefore among its used names),
# but it must never be copied into a generated script.
_EXCLUDED_NAME = "task_context"


def _collect_names(node):
    """Return the identifier of every ``ast.Name`` found anywhere under *node*.

    Attribute access such as ``time.sleep`` contributes ``time``, because the
    base of the attribute is itself an ``ast.Name`` node.
    """
    return {child.id for child in ast.walk(node) if isinstance(child, ast.Name)}


def _alias_is_used(alias, used_names, top_level_only):
    """Tell whether an import alias binds a name present in *used_names*.

    With ``top_level_only`` (plain ``import a.b``) only the first dotted
    component of the imported name is compared, since that is the name the
    statement binds when no ``as`` clause is given.
    """
    if alias.asname and alias.asname in used_names:
        return True
    name = alias.name.split(".")[0] if top_level_only else alias.name
    return name in used_names


def _render_import(node, used_names):
    """Return the source text of an import statement, or ``None`` to drop it.

    ``import ...`` statements are kept verbatim when any alias is used.
    For ``from ... import ...`` the ``task_context`` alias is removed first and
    the statement is kept (without that alias) when any remaining alias is
    used, so the decision no longer depends on the order of the aliases.
    """
    if isinstance(node, ast.Import):
        if any(_alias_is_used(alias, used_names, True) for alias in node.names):
            return ast.unparse(node).strip()
        return None

    kept = [alias for alias in node.names if alias.name != _EXCLUDED_NAME]
    if not any(_alias_is_used(alias, used_names, False) for alias in kept):
        return None
    if len(kept) == len(node.names):
        return ast.unparse(node).strip()
    pruned = ast.ImportFrom(module=node.module, names=kept, level=node.level)
    return ast.unparse(pruned).strip()


def scriptize(func):
    """Wrap *func* so that calling it returns a script instead of running it.

    The returned wrapper binds its arguments against the signature of *func*
    and returns a string made of, in order:

    1. the module-level imports and functions of the defining module that the
       function needs (directly, or through another copied function),
    2. one ``name = repr(value)`` assignment per bound argument,
    3. the unparsed statements of the function body.

    Only top-level ``import``/``from ... import`` statements and top-level
    ``def`` statements of the module are considered; the decorated function
    itself and anything named ``task_context`` are never copied.

    Raises:
        RuntimeError: if the module of *func* is not in ``sys.modules``.
        ValueError: if the source of *func* holds no function definition.
        OSError / TypeError: propagated from ``inspect.getsource(func)`` when
            the source of the function cannot be retrieved.
    """
    source_code = textwrap.dedent(inspect.getsource(func))
    func_tree = ast.parse(source_code)

    module_name = func.__module__
    if module_name not in sys.modules:
        raise RuntimeError(f"Cannot find module {module_name} for function {func.__name__}")
    module = sys.modules[module_name]

    # The module source may be unavailable (e.g. in a Jupyter notebook); in
    # that case fall back to an empty module so only args and body are emitted.
    try:
        module_source = textwrap.dedent(inspect.getsource(module))
    except (TypeError, OSError):
        module_tree = ast.parse("")
    else:
        module_tree = ast.parse(module_source)

    func_def = next(
        (node for node in func_tree.body if isinstance(node, ast.FunctionDef)),
        None,
    )
    if func_def is None:
        raise ValueError("No function definition found in func_tree.")

    # Candidate helpers: every top-level function except the decorated one
    # and the decorator itself.
    excluded = {_EXCLUDED_NAME, func.__name__}
    helper_defs = [
        node
        for node in module_tree.body
        if isinstance(node, ast.FunctionDef) and node.name not in excluded
    ]

    # Grow the set of used names until it is stable: a selected helper may in
    # turn reference other helpers or imports that the script then needs too.
    used_names = _collect_names(func_def)
    selected = set()
    grew = True
    while grew:
        grew = False
        for node in helper_defs:
            if id(node) not in selected and node.name in used_names:
                selected.add(id(node))
                used_names |= _collect_names(node)
                grew = True

    # Emit dependencies in their original module order.
    relevant_code_blocks = []
    for node in module_tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            rendered = _render_import(node, used_names)
            if rendered is not None:
                relevant_code_blocks.append(rendered)
        elif isinstance(node, ast.FunctionDef) and id(node) in selected:
            relevant_code_blocks.append(ast.unparse(node).strip())

    # None of the following depends on the call arguments, so build it once.
    body_source_code = "\n".join(ast.unparse(stmt) for stmt in func_def.body)
    func_signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """Return the generated script for one concrete set of arguments."""
        bound_args = func_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()

        arg_assignments = [
            f"{arg_name} = {arg_value!r}"
            for arg_name, arg_value in bound_args.arguments.items()
        ]

        combined_code_parts = []
        if relevant_code_blocks:
            combined_code_parts.append("\n\n".join(relevant_code_blocks))
        if arg_assignments:
            if combined_code_parts:
                combined_code_parts.append("")  # blank line before args
            combined_code_parts.extend(arg_assignments)
            combined_code_parts.append("")  # blank line before body
        combined_code_parts.append(body_source_code)

        return "\n".join(combined_code_parts).strip()

    return wrapper