in src/smolagents/_function_type_hints_utils.py [0:0]
def get_imports(code: str) -> list[str]:
    """
    Extracts all the libraries (not relative imports) that are imported in a code.

    Some optional imports are ignored so they are not reported as hard requirements:
    - the text from a `try:` up to and including its first `except ...:` header is dropped, so imports
      in a `try` body are skipped (imports in the `except` body are still counted);
    - `from flash_attn ...` imports directly under an `if is_flash_attn..._available():` check are dropped.

    Args:
        code (`str`): Code text to inspect.

    Returns:
        `list[str]`: List of all packages required to use the input code. There is one entry per distinct
        top-level module, mapped through `get_package_name`. Entries are in order of first appearance:
        plain `import` statements first, then `from ... import` statements.
    """
    # filter out try/except block so in custom code we can have try/except imports.
    # The word boundaries stop identifiers such as `registry:` / `entry:` from being taken for `try:`
    # (which used to delete unrelated code, including real imports). They also stop `"exceptions":`
    # from being taken for `except` (which used to end the deletion too early).
    code = re.sub(r"\s*\btry\s*:.*?\bexcept\b.*?:", "", code, flags=re.DOTALL)
    # filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
    code = re.sub(
        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
        "",
        code,
        flags=re.MULTILINE,
    )

    # Imports of the form `import a`, `import a.b as c`, `import a, b as c`. Each may be followed by `;`
    # and/or a trailing comment. The whole line must be a well-formed import statement, so prose such as
    # "import this module" inside a docstring is not picked up.
    identifier = r"[^\W\d]\w*"
    dotted_name = rf"{identifier}(?:\.{identifier})*"
    import_item = rf"{dotted_name}(?:[ \t]+as[ \t]+{identifier})?"
    import_pattern = rf"^\s*import[ \t]+({import_item}(?:[ \t]*,[ \t]*{import_item})*)[ \t]*;?[ \t]*(?:#.*)?\r?$"
    module_names = [
        item.split()[0]  # drop the optional `as alias` part
        for statement in re.findall(import_pattern, code, flags=re.MULTILINE)
        for item in statement.split(",")
    ]
    # Imports of the form `from xxx import yyy`
    module_names += re.findall(r"^\s*from\s+(\S+)\s+import", code, flags=re.MULTILINE)
    # Only keep the top-level module of absolute (non-relative) imports
    top_level_modules = [name.split(".")[0] for name in module_names if not name.startswith(".")]
    # dict.fromkeys de-duplicates while keeping first-seen order, so the result is deterministic.
    # Iterating a set of strings can change order between interpreter runs because of hash randomization.
    return [get_package_name(module_name) for module_name in dict.fromkeys(top_level_modules)]