example_zoo/tools/source_finder.py (48 lines of code) (raw):

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Helper functions to find all local dependencies between modules within a root package.

import ast
import os
import re


class SourceFinder(object):
    """Collect the local modules a script depends on within a root package.

    Starting from ``script_path``, every visited file is parsed and its
    top-level ``import`` / ``from ... import`` statements are split into:

    * local imports (inside the root package), stored per file in
      ``script_imports`` and followed transitively by ``process``;
    * everything else, accumulated in ``externals``.

    Attributes:
        package_path: root package directory without trailing '/'.
        script_path: path of the entry script.
        parent: directory containing the root package.
        package_name: name of the root package directory.
        externals: set of imported names that are not part of the package.
        script_imports: dict mapping a file path to the set of local module
            names imported by that file.
    """

    def __init__(self, package_path, script_path):
        """Store the package/script locations and initialise empty results.

        Args:
            package_path: absolute path of the root package; a trailing '/'
                is removed.
            script_path: absolute path of the entry script.
        """
        # root should be an absolute path
        # remove trailing '/'
        self.package_path = package_path.rstrip('/')

        # script_path should be absolute path to script
        self.script_path = script_path

        # parent is used to figure out the absolute path of dependency scripts
        self.parent, self.package_name = os.path.split(self.package_path)

        # for example:
        # script_path = /tmp/a/b/c/d.py
        # package_path = /tmp/a/b
        # parent = /tmp/a
        # package_name = b

        self.externals = set()

        # keys are absolute paths, values are sets of module names
        self.script_imports = {}

    def process(self):
        """Walk the import graph breadth-first starting at ``script_path``.

        Side effect: fills ``self.script_imports`` (one key per visited file)
        and ``self.externals``.

        Raises:
            FileNotFoundError: if a local import cannot be mapped to a file.
        """
        # to_visit is a list of absolute paths; queued mirrors everything ever
        # put on it so that a file is never scheduled (and parsed) twice
        to_visit = [self.script_path]
        queued = {self.script_path}

        while to_visit:
            visit_path = to_visit.pop(0)

            # sorted() keeps the traversal order (and therefore which missing
            # module is reported first) reproducible between runs
            for module_name in sorted(self.process_script(visit_path)):
                module_path = self._resolve_module_path(module_name)

                # add to the to_visit list if not yet visited or scheduled
                if module_path in self.script_imports or module_path in queued:
                    continue
                queued.add(module_path)
                to_visit.append(module_path)

    def process_script(self, path):
        """Parse one file and classify its top-level imports.

        Side effect: updates set ``self.externals`` and dict
        ``self.script_imports`` (adding a value only for key = ``path``).

        Args:
            path: path of the Python file to parse.

        Returns:
            The set of local module names imported by ``path`` (the same set
            object that is stored in ``self.script_imports[path]``).
        """
        with open(path, 'r') as f:
            code = f.read()

        tree = ast.parse(code, filename=path)
        local_names = self.script_imports[path] = set()

        # only statements at module level are inspected
        for node in tree.body:
            if isinstance(node, ast.Import):
                module_names = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                base = self._import_from_base(path, node)
                # base is empty for an unresolvable "from . import x"; use the
                # bare alias instead of formatting None into the name
                module_names = [
                    '{}.{}'.format(base, alias.name) if base else alias.name
                    for alias in node.names
                ]
            else:
                continue

            for module_name in module_names:
                if self._is_local(module_name):
                    local_names.add(module_name)
                else:
                    self.externals.add(module_name)

        return local_names

    def module_name_to_path(self, module_name):
        """Convert a dotted module name ('a.b.c') to a file path ('a/b/c.py')."""
        return module_name.replace('.', '/') + '.py'

    def path_to_relative_path(self, path):
        """Return ``path`` with the leading ``self.parent + '/'`` removed.

        The result starts with ``self.package_name`` for files inside the
        package; a path that does not start with the parent directory is
        returned unchanged.
        """
        # this is used in cmle_package.py
        # re.escape: the directory name is literal text, not a pattern
        return re.sub('^{}/'.format(re.escape(self.parent)), '', path)

    def _is_local(self, module_name):
        """Return True if module_name is the root package or lives inside it."""
        # compare whole dotted components so that e.g. package 'b' does not
        # claim the unrelated top-level module 'base64'
        return (module_name == self.package_name
                or module_name.startswith(self.package_name + '.'))

    def _import_from_base(self, path, node):
        """Return the dotted module that a ``from ... import`` statement names.

        Absolute imports return ``node.module`` as written. Relative imports
        (``node.level`` > 0) are anchored at the package that contains
        ``path``, expressed relative to ``self.parent``. When the anchor cannot
        be determined (file outside ``self.parent`` or more leading dots than
        enclosing directories) the module is returned as written, which may be
        None.
        """
        if not node.level:
            return node.module

        script_dir = os.path.dirname(os.path.abspath(path))
        rel_dir = os.path.relpath(script_dir, self.parent or os.curdir)
        parts = [] if rel_dir == os.curdir else rel_dir.split(os.sep)

        # one leading dot is the current package, each extra dot goes one up
        keep = len(parts) - (node.level - 1)
        if keep < 1 or os.pardir in parts:
            return node.module

        anchored = parts[:keep]
        if node.module:
            anchored.append(node.module)
        return '.'.join(anchored)

    def _resolve_module_path(self, module_name):
        """Map a local dotted name to the absolute path of the file behind it.

        Candidates, in order ('a.b.c' as example):
          1. a/b/c.py           - the name is a module
          2. a/b.py             - a variable was imported from module a.b
          3. a/b/c/__init__.py  - the name is a package
          4. a/b/__init__.py    - a variable was imported from package a.b

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        # turn this into absolute path
        module_path = os.path.join(self.parent, self.module_name_to_path(module_name))
        if os.path.exists(module_path):
            return module_path

        # sometimes a variable is imported, in which case we back track one level
        owner_dir, _ = os.path.split(module_path)
        owner_module = owner_dir + '.py'
        package_dir = module_path[:-len('.py')]

        fallbacks = (
            owner_module,
            os.path.join(package_dir, '__init__.py'),
            os.path.join(owner_dir, '__init__.py'),
        )
        for candidate in fallbacks:
            if os.path.exists(candidate):
                return candidate

        # at this point the file should exist
        raise FileNotFoundError(owner_module)