def main()

in tensorflow_probability/substrates/meta/rewrite.py [0:0]


def main(argv):

  disabled_by_pkg = dict(DISABLED_BY_PKG)
  for dep in FLAGS.omit_deps:
    if '/python/' in dep:
      pkg = 'python.' + dep.split('/python/')[1].split(':')[0].replace('/', '.')
    elif '/google/' in dep:
      pkg = 'google.' + dep.split('/google/')[1].split(':')[0].replace('/', '.')
    lib = dep.split(':')[1]
    if pkg.endswith('.{}'.format(lib)):
      pkg = pkg.replace('.{}'.format(lib), '')
      disabled_by_pkg.setdefault(pkg, ())
      disabled_by_pkg[pkg] += (lib,)
    else:
      disabled_by_pkg.setdefault(pkg, ())
      disabled_by_pkg[pkg] += (lib,)

  replacements = collections.OrderedDict(TF_REPLACEMENTS)
  for pkg, disabled in disabled_by_pkg.items():
    replacements.update({
        'from tensorflow_probability.{}.{} '.format(pkg, item):
        '# from tensorflow_probability.{}.{} '.format(pkg, item)
        for item in disabled
    })
    replacements.update({
        'from tensorflow_probability.{} import {}'.format(pkg, item):
        '# from tensorflow_probability.{} import {}'.format(pkg, item)
        for item in disabled
    })
  replacements.update({
      'tensorflow_probability.python.{}'.format(lib):
      'tensorflow_probability.substrates.numpy.{}'.format(lib)
      for lib in LIBS
  })
  replacements.update({
      'tensorflow_probability.python import {}'.format(lib):
      'tensorflow_probability.substrates.numpy import {}'.format(lib)
      for lib in LIBS
  })
  replacements.update({
      'tensorflow_probability.google.{}'.format(lib):
      'tensorflow_probability.substrates.numpy.google.{}'.format(lib)
      for lib in LIBS
  })
  replacements.update({
      'tensorflow_probability.google import {}'.format(lib):
      'tensorflow_probability.substrates.numpy.google import {}'.format(lib)
      for lib in LIBS
  })
  replacements.update({
      'tensorflow_probability.python.internal.{}'.format(internal):
      'tensorflow_probability.substrates.numpy.internal.{}'.format(internal)
      for internal in INTERNALS
  })
  # pylint: disable=g-complex-comprehension
  replacements.update({
      'tensorflow_probability.python.internal import {}'.format(internal):
          'tensorflow_probability.substrates.numpy.internal import {}'.format(
              internal)
      for internal in INTERNALS
  })
  replacements.update({
      'tensorflow.python.ops import {}'.format(private):
      'tensorflow_probability.python.internal.backend.numpy import private'
      ' as {}'.format(private)
      for private in PRIVATE_TF_PKGS
  })
  replacements.update({
      'tensorflow.python.framework.ops import {}'.format(
          private):
      'tensorflow_probability.python.internal.backend.numpy import private'
      ' as {}'.format(private)
      for private in PRIVATE_TF_PKGS
  })
  # pylint: enable=g-complex-comprehension

  # TODO(bjp): Delete this block after TFP uses stateless samplers.
  replacements.update({
      'tf.random.{}'.format(sampler): 'tf.random.stateless_{}'.format(sampler)
      for sampler in SAMPLERS
  })
  replacements.update({
      'self._maybe_assert_dtype': '# self._maybe_assert_dtype',
      'SKIP_DTYPE_CHECKS = False': 'SKIP_DTYPE_CHECKS = True',
      '@test_util.test_all_tf_execution_regimes':
          '# @test_util.test_all_tf_execution_regimes',
      '@test_util.test_graph_and_eager_modes':
          '# @test_util.test_graph_and_eager_modes',
      '@test_util.test_graph_mode_only':
          '# @test_util.test_graph_mode_only',
      'TestCombinationsTest(test_util.TestCase)':
          'TestCombinationsDoNotTest(object)',
      '@six.add_metaclass(TensorMetaClass)':
          '# @six.add_metaclass(TensorMetaClass)',
  })

  filename = argv[1]
  contents = open(filename, encoding='utf-8').read()
  if '__init__.py' in filename:
    # Comment out items from __all__.
    for pkg, disabled in disabled_by_pkg.items():
      for item in disabled:
        def disable_all(name):
          replacements.update({
              '"{}"'.format(name): '# "{}"'.format(name),
              '\'{}\''.format(name): '# \'{}\''.format(name),
          })
        if 'from tensorflow_probability.{} import {}'.format(
            pkg, item) in contents:
          disable_all(item)
        for segment in contents.split(
            'from tensorflow_probability.{}.{} import '.format(
                pkg, item)):
          disable_all(segment.split('\n')[0])

  for find, replace in replacements.items():
    contents = contents.replace(find, replace)

  disabler = 'JAX_DISABLE' if FLAGS.numpy_to_jax else 'NUMPY_DISABLE'
  lines = contents.split('\n')
  for i, l in enumerate(lines):
    if disabler in l:
      lines[i] = '# {}'.format(l)
  contents = '\n'.join(lines)

  if not FLAGS.numpy_to_jax:
    contents = contents.replace('NUMPY_MODE = False', 'NUMPY_MODE = True')
  if FLAGS.numpy_to_jax:
    contents = contents.replace('tfp.substrates.numpy', 'tfp.substrates.jax')
    contents = contents.replace('substrates.numpy', 'substrates.jax')
    contents = contents.replace('backend.numpy', 'backend.jax')
    contents = contents.replace('backend import numpy as tf',
                                'backend import jax as tf')

    contents = contents.replace('def _call_jax', 'def __call__')
    contents = contents.replace('JAX_MODE = False', 'JAX_MODE = True')
    contents = contents.replace('SKIP_DTYPE_CHECKS = True',
                                'SKIP_DTYPE_CHECKS = False')

  substrate = 'jax' if FLAGS.numpy_to_jax else 'numpy'
  if '/python/' in filename:
    path = filename.split('/python/')[1]
  elif '/google/' in filename:
    path = 'google/' + filename.split('/google/')[1]
  footer = '\n'.join([
      '\n',
      '# ' + '@' * 78,
      '# This file is auto-generated by substrates/meta/rewrite.py',
      '# It will be surfaced by the build system as a symlink at:',
      '#   `tensorflow_probability/substrates/{}/{}`'.format(substrate, path),
      '# For more info, see substrate_runfiles_symlinks in build_defs.bzl',
      '# ' + '@' * 78,
      ])

  print(contents + footer, file=open(1, 'w', encoding='utf-8', closefd=False))