example_zoo/tools/cmle_package.py (222 lines of code) (raw):

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import shutil

try:
    import urllib2
except ImportError:
    # Python 3 has no `urllib2` module. Keep the name bound so that this
    # module stays importable under both interpreters.
    import urllib.request as urllib2


class Pipe(object):
    """Copy a source path to a destination, optionally transforming the content.

    A pipe is a triplet of (source path, destination path, list of
    transformations).  Each transformation is a callable that takes the file
    content as a string and returns the new content.  Transformations are only
    applied when the source is a regular file; directories are copied as-is.
    """

    def __init__(self, source, destination, transformations=None):
        """Store the pipe configuration.

        Args:
            source: path of the file or directory to copy.
            destination: path to copy to.
            transformations: optional list of `content -> content` callables
                applied in order.  Defaults to no transformations.
        """
        self.source = source
        self.destination = destination
        # A fresh list per instance: a `[]` default in the signature would be
        # shared by every Pipe constructed without explicit transformations.
        self.transformations = [] if transformations is None else transformations

    def _handle_dir(self):
        """Recursively copy the source directory, skipping tests and test data."""
        shutil.copytree(
            self.source,
            self.destination,
            ignore=shutil.ignore_patterns('*_test.py', '*testing*')
        )

    def _handle_file(self):
        """Copy the source file, running its content through the transformations."""
        parent, _ = os.path.split(self.destination)
        # `parent` is empty when the destination has no directory component;
        # os.makedirs('') would raise, and there is nothing to create anyway.
        if parent and not os.path.exists(parent):
            os.makedirs(parent)

        with open(self.source, 'r') as source_file, \
                open(self.destination, 'w') as destination_file:
            content = source_file.read()
            for transformation in self.transformations:
                content = transformation(content)
            destination_file.write(content)

    def run(self):
        """Execute the copy.

        Raises:
            ValueError: if the source path does not exist.
        """
        if not os.path.exists(self.source):
            raise ValueError('{} does not exist'.format(self.source))

        if os.path.isdir(self.source):
            self._handle_dir()
        elif os.path.isfile(self.source):
            self._handle_file()


class CMLEPackage(object):
    """Builds a trainer package for one sample script out of a checked-out repo.

    The package is assembled as a list of `Pipe` objects (templates, the
    sample script itself and the modules it imports) which `generate()` then
    runs into `output_dir`.
    """

    WEB_BASE = 'https://github.com/{org}/{repository}/blob/{branch}/{full_path}'

    # Template files copied (and formatted) into every generated package.
    TEMPLATE_FILENAMES = [
        'setup.py',
        'config.yaml',
        'submit_27.sh',
        'submit_35.sh',
        'README.md',
    ]

    def __init__(self, sample_dict, repo):
        """Read the sample configuration and check out its branch.

        Args:
            sample_dict: configuration mapping.  Required keys: `org`,
                `repository`, `runtime_version`, `branch`, `script_name`,
                `artifact`, `wait_time`.  Optional keys: `module_path`,
                `script_path`, `replace`, `other_sources`, `args`,
                `requires`, `tfgfile_wrap`.
            repo: repository object exposing `git.checkout(branch)` and
                `working_dir`.

        Raises:
            KeyError: if a required key is missing from `sample_dict`.
        """
        self.org = sample_dict['org']
        self.repository = sample_dict['repository']
        self.runtime_version = sample_dict['runtime_version']
        self.branch = sample_dict['branch']
        self.module_path = sample_dict.get('module_path', '')
        self.script_path = sample_dict.get('script_path', '')
        self.script_name = sample_dict['script_name']
        self.replace = sample_dict.get('replace', [])
        self.artifact = sample_dict['artifact']
        self.wait_time = sample_dict['wait_time']

        # check out the specified branch
        self.repo = repo
        self.repo.git.checkout(self.branch)
        self.working_dir = self.repo.working_dir

        # optional configs
        self.other_sources = sample_dict.get('other_sources', [])

        if 'args' in sample_dict:
            # every argument goes on its own backslash-continued line
            sep = ' \\\n '
            self.args = sep + sep.join(sample_dict['args'])
        else:
            self.args = ''

        if 'requires' in sample_dict:
            # rendered as a comma separated list of single-quoted names
            self.requires = ','.join(
                "'{}'".format(req) for req in sample_dict['requires'])
        else:
            self.requires = ''

        self.tfgfile_wrap = sample_dict.get('tfgfile_wrap', [])

        # NOTE: build_pipes() appends to this list and never clears it.
        self.pipes = []

    def format(self, content):
        """Fill the `{placeholders}` of a template with `format_dict`."""
        return content.format(**self.format_dict)

    def add_tfgfile_wrapper(self, content):
        """Import `tfgfile_wrapper` and decorate the configured functions.

        The import is inserted before the first line containing `import` that
        is not a `from __future__` line.  `@tfgfile_wrapper` is inserted
        before every line containing `def <name>` for each name in
        `self.tfgfile_wrap`, at the same indentation as that line so that
        methods and nested functions stay syntactically valid.
        """
        lines = []
        import_pending = True

        for line in content.split('\n'):
            if import_pending and 'import' in line and 'from __future__' not in line:
                lines.append(self.tfgfile_wrapper_import)
                import_pending = False

            # leading whitespace of the current line, reused for the decorator
            indent = line[:len(line) - len(line.lstrip())]
            for to_wrap in self.tfgfile_wrap:
                if 'def {}'.format(to_wrap) in line:
                    lines.append(indent + '@tfgfile_wrapper')

            lines.append(line)

        return '\n'.join(lines)

    def add_job_dir_flag(self, content):
        """Define the `job-dir` flag right after `from absl import flags`.

        The definition is emitted immediately after every line that is exactly
        `from absl import flags`, including when that line is the last one.
        """
        flags_define = 'flags.DEFINE_string(name="job-dir", default="/tmp", help="AI Platform Training passes this to the training script.")'

        lines = []
        for line in content.split('\n'):
            lines.append(line)
            # inject the flags define line right after the import
            if line == 'from absl import flags':
                lines.append(flags_define)

        return '\n'.join(lines)

    def make_replace_transformation(self, match, replace):
        """Return a transformation replacing every `match` with `replace`."""
        def replace_transformation(content):
            return content.replace(match, replace)

        return replace_transformation

    def build_pipes(self):
        """Append to `self.pipes` every pipe needed to assemble the package."""
        # templates
        for template_filename in self.TEMPLATE_FILENAMES:
            self.pipes.append(
                Pipe(
                    os.path.join('templates', template_filename),
                    os.path.join(self.output_dir, template_filename),
                    [self.format]
                )
            )

        # test
        self.pipes.append(
            Pipe(
                'templates/cmle_test.py',
                os.path.join(self.output_dir, self.test_name),
                [self.format]
            )
        )

        # tfgfile_wrapper if needed
        if self.tfgfile_wrap:
            self.pipes.append(
                Pipe(
                    'templates/tfgfile_wrapper.py',
                    os.path.join(
                        self.output_dir, self.output_script_path, 'tfgfile_wrapper.py')
                )
            )

        # source
        source_transformations = [self.add_job_dir_flag]

        if self.tfgfile_wrap:
            source_transformations.append(self.add_tfgfile_wrapper)

        for match, replace in self.replace:
            source_transformations.append(
                self.make_replace_transformation(match, replace)
            )

        script_source = os.path.join(
            self.working_dir, self.module_path, self.script_path, self.script_name)

        self.pipes.append(
            Pipe(
                script_source,
                os.path.join(self.output_dir, self.output_script_path, self.script_name),
                source_transformations
            )
        )

        # other source files/directories
        # use source_finder to find minimally required other source files
        from source_finder import SourceFinder

        sf = SourceFinder(
            os.path.join(self.working_dir, self.module_path, self.package_path),
            script_source
        )
        sf.process()

        for imported_path in sf.script_imports:
            # skip the script itself
            if imported_path == sf.script_path:
                continue

            rel_path = sf.path_to_relative_path(imported_path)
            self.pipes.append(
                Pipe(
                    imported_path,
                    os.path.join(self.output_dir, rel_path)
                )
            )

    @property
    def name(self):
        """Script name up to its first dot."""
        return self.script_name.split('.')[0]

    @property
    def test_name(self):
        """File name of the generated test."""
        return 'cmle_{}_test.py'.format(self.name)

    # for the generated package, putting the script into a `trainer` directory
    # if no script_path is specified
    @property
    def output_script_path(self):
        """Directory of the script inside the package (`trainer` by default)."""
        return self.script_path or 'trainer'

    @property
    def output_package_path(self):
        """First `/`-separated component of `output_script_path`."""
        return self.output_script_path.split('/')[0]

    @property
    def package_path(self):
        """First `/`-separated component of `script_path`, or `trainer`."""
        if self.script_path:
            return self.script_path.split('/')[0]
        else:
            return 'trainer'

    @property
    def module_parent(self):
        """`output_script_path` with `/` replaced by `.`."""
        return self.output_script_path.replace('/', '.')

    @property
    def module_name(self):
        """Dotted module name of the script inside the generated package."""
        return '{}.{}'.format(self.module_parent, self.name)

    @property
    def output_dir(self):
        """Directory the package is generated into."""
        return os.path.join('..', self.org, self.repository, self.name)

    @property
    def web_url(self):
        """`WEB_BASE` filled in with this sample's org, repo, branch and path."""
        return self.WEB_BASE.format(
            org=self.org,
            repository=self.repository,
            branch=self.branch,
            full_path=self.full_path
        )

    @property
    def full_path(self):
        """`module_path`, `script_path` and `script_name` joined as a path."""
        return os.path.join(self.module_path, self.script_path, self.script_name)

    @property
    def tfgfile_wrapper_import(self):
        """Import statement injected by `add_tfgfile_wrapper`."""
        return 'from {}.tfgfile_wrapper import tfgfile_wrapper'.format(self.module_parent)

    @property
    def format_dict(self):
        """Values substituted into the templates by `format`."""
        return {
            'org': self.org,
            'repository': self.repository,
            'name': self.name,
            'runtime_version': self.runtime_version,
            'output_package_path': self.output_package_path,
            'module_name': self.module_name,
            'full_path': self.full_path,
            'requires': self.requires,
            'web_url': self.web_url,
            'artifact': self.artifact,
            'wait_time': self.wait_time,
            'args': self.args
        }

    def generate(self):
        """Rebuild the package from scratch in `output_dir`."""
        print('Building package for {}'.format(self.name))

        # clean up previously generated package
        if os.path.exists(self.output_dir):
            shutil.rmtree(self.output_dir)

        os.makedirs(os.path.join(self.output_dir, self.output_script_path))

        self.build_pipes()

        for pipe in self.pipes:
            pipe.run()

        # add __init__.py to all directories
        for path, _, files in os.walk(self.output_dir):
            if path != self.output_dir and '__init__.py' not in files:
                Pipe(
                    'templates/__init__.py',
                    os.path.join(path, '__init__.py')
                ).run()