def get_imports()

in src/smolagents/_function_type_hints_utils.py [0:0]


def get_imports(code: str) -> list[str]:
    """
    Extracts all the libraries (not relative imports) that are imported in a code.

    Imports placed inside a `try`/`except` block, or under an `if is_flash_attn_*available():` check, are treated as
    optional and ignored. Detection is regex-based on the raw text, so import statements that appear inside string
    literals are picked up as well.

    Args:
        code (`str`): Code text to inspect.

    Returns:
        `list[str]`: Sorted, de-duplicated list of all packages required to use the input code.
    """
    # Filter out try/except blocks so that custom code can have optional (try/except) imports.
    # Both keywords are anchored to the start of a line: an unanchored `try\s*:` also matched inside identifiers
    # such as `entry:` and then silently deleted every import up to the next `except`.
    code = re.sub(r"^\s*try\s*:.*?^\s*except\b.*?:", "", code, flags=re.MULTILINE | re.DOTALL)

    # Filter out imports under is_flash_attn_2_available block to avoid import issues in cpu only environment
    code = re.sub(
        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
        "",
        code,
        flags=re.MULTILINE,
    )

    module_names = []
    # Imports of the form `import xxx`, `import xxx as yyy` or `import xxx, yyy as zzz`. The captured clause list
    # stops at a trailing comment (`#`) or at a `;` (any statement after it is ignored).
    for clause_list in re.findall(r"^\s*import[ \t]+([^#;\n]+)", code, flags=re.MULTILINE):
        for clause in clause_list.split(","):
            # Each clause is `module` or `module as alias`: keep only the module path.
            tokens = clause.split()
            if tokens:
                module_names.append(tokens[0])
    # Imports of the form `from xxx import yyy`
    module_names += re.findall(r"^\s*from\s+(\S+)\s+import", code, flags=re.MULTILINE)

    # Only keep the top-level module of absolute imports, and discard anything that is not a valid identifier
    # (e.g. a stray line-continuation backslash).
    top_level_names = {name.split(".")[0] for name in module_names if not name.startswith(".")}
    sorted_names = sorted(name for name in top_level_names if name.isidentifier())
    # Sorting makes the output deterministic. dict.fromkeys drops duplicates when several import names map to the
    # same package.
    return list(dict.fromkeys(get_package_name(import_name) for import_name in sorted_names))