def gen_module()

in tensorflow_probability/python/internal/backend/meta/gen_linear_operators.py [0:0]


def gen_module(module_name):
  """Rewrite for numpy the code loaded from the given linalg module.

  Loads the source text of `tensorflow.python.ops.linalg.<module_name>`,
  applies an ordered series of textual substitutions (import redirection,
  op renames, shape/dtype idiom rewrites) and prints the rewritten module to
  stdout, preceded by an auto-generation banner and followed by the extra
  imports that the rewritten code refers to.

  The substitutions are order-dependent: several steps act on text produced by
  earlier steps, so they must not be reordered without checking the output.

  Args:
    module_name: Name of a module under `tensorflow.python.ops.linalg`, e.g.
      `'linear_operator_util'`.
  """
  module = importlib.import_module(
      'tensorflow.python.ops.linalg.{}'.format(module_name))
  code = inspect.getsource(module)

  # Redirect the imports listed in MODULE_MAPPINGS to the numpy backend.
  for k, v in MODULE_MAPPINGS.items():
    code = code.replace(
        'from tensorflow.python.{}'.format(k),
        'from tensorflow_probability.python.internal.backend.numpy '
        'import {}'.format(v))
  # Comment out every snippet listed in COMMENT_OUT.
  for k in COMMENT_OUT:
    code = code.replace(k, '# {}'.format(k))
  # `tensor_shape` is imported from the `gen` subpackage rather than from the
  # numpy backend package itself.
  code = code.replace(
      '.backend.numpy import tensor_shape',
      '.backend.numpy.gen import tensor_shape')
  code = code.replace(
      'from tensorflow.python.platform import tf_logging',
      'from absl import logging')
  code = code.replace(
      'from tensorflow.python.framework import '
      'composite_tensor',
      'from tensorflow_probability.python.internal.backend.numpy '
      'import composite_tensor')
  code = code.replace(
      'from tensorflow.python.ops import '
      'resource_variable_ops',
      'from tensorflow_probability.python.internal.backend.numpy '
      'import resource_variable_ops')
  code = code.replace(
      'from tensorflow.python.framework import tensor_spec',
      'from tensorflow_probability.python.internal.backend.numpy import '
      'tensor_spec')
  code = code.replace(
      'from tensorflow.python.framework import type_spec',
      'from tensorflow_probability.python.internal.backend.numpy '
      'import type_spec')
  code = code.replace(
      'from tensorflow.python.ops import variables',
      'from tensorflow_probability.python.internal.backend.numpy '
      'import variables')
  code = code.replace(
      'from tensorflow.python.training.tracking '
      'import data_structures',
      'from tensorflow_probability.python.internal.backend.numpy '
      'import data_structures')
  code = re.sub(
      r'from tensorflow\.python\.linalg import (\w+)',
      'from tensorflow_probability.python.internal.backend.numpy.gen import \\1 '
      'as \\1', code)

  # Imports from `tensorflow.python.ops.linalg`: first comment all of them out,
  # then re-enable the ones named in `FLAGS.allowlist`, and finally retarget
  # the module path (in both the live and the commented-out lines) to the
  # generated numpy package.
  code = code.replace(
      'from tensorflow.python.ops.linalg import ',
      '# from tensorflow.python.ops.linalg import ')
  for f in FLAGS.allowlist:
    code = code.replace(
        '# from tensorflow.python.ops.linalg '
        'import {}'.format(f),
        'from tensorflow.python.ops.linalg '
        'import {}'.format(f))
  code = code.replace(
      'tensorflow.python.ops.linalg import',
      'tensorflow_probability.python.internal.backend.numpy.gen import')
  code = code.replace(
      'tensorflow.python.util import',
      'tensorflow_probability.python.internal.backend.numpy import')

  # Swap `tensor_util` helpers for their `ops` counterparts.
  code = code.replace('tensor_util.constant_value(', 'ops.get_static_value(')
  code = code.replace('tensor_util.is_tensor(', 'ops.is_tensor(')
  code = code.replace('tensor_util.is_tf_type(', 'ops.is_tensor(')
  code = code.replace(
      'from tensorflow.python.ops.distributions import '
      'util as distribution_util', UTIL_IMPORTS)
  code = code.replace(
      'control_flow_ops.with_dependencies',
      'distribution_util.with_dependencies')

  # Shape and dtype idioms.
  code = code.replace('.base_dtype', '')
  code = code.replace('.get_shape()', '.shape')
  # Wrap `<expr>.shape` in `tensor_shape.TensorShape(...)` unless it is
  # followed by `(` or `_` (i.e. it is a call or a longer identifier).
  code = re.sub(r'([_a-zA-Z0-9.\[\]]+\.shape)([^(_])',
                'tensor_shape.TensorShape(\\1)\\2', code)
  # Rewrite dtype predicates as `np.issubdtype` checks. The attribute dot is
  # escaped so that only genuine attribute access (`x.is_floating`) matches,
  # and not an arbitrary character before the predicate name (e.g.
  # `x_is_floating`).
  code = re.sub(r'([_a-zA-Z0-9.\[\]]+)\.is_floating',
                'np.issubdtype(\\1, np.floating)', code)
  code = re.sub(r'([_a-zA-Z0-9.\[\]]+)\.is_complex',
                'np.issubdtype(\\1, np.complexfloating)', code)
  code = re.sub(r'([_a-zA-Z0-9.\[\]]+)\.is_integer',
                'np.issubdtype(\\1, np.integer)', code)

  # Op renames.
  code = code.replace('array_ops.shape', 'prefer_static.shape')
  code = code.replace('array_ops.concat', 'prefer_static.concat')
  code = code.replace('array_ops.broadcast_static_shape',
                      '_ops.broadcast_static_shape')
  code = code.replace('array_ops.broadcast_to', '_ops.broadcast_to')
  # `matrix_diag_part` is handled before `matrix_diag`, which is a prefix of
  # it, so the longer name is rewritten by its own rule rather than by
  # accident of the shorter one.
  code = code.replace('array_ops.matrix_diag_part', '_linalg.diag_part')
  code = code.replace('array_ops.matrix_diag', '_linalg.diag')
  code = code.replace('array_ops.matrix_band_part', '_linalg.band_part')
  code = code.replace('array_ops.matrix_set_diag', '_linalg.set_diag')
  code = code.replace('array_ops.matrix_transpose', '_linalg.matrix_transpose')
  code = code.replace('array_ops.newaxis', '_ops.newaxis')
  code = code.replace('linalg_ops.matrix_determinant', '_linalg.det')
  code = code.replace('linalg_ops.matrix_solve', '_linalg.solve')
  code = code.replace('linalg_ops.matrix_triangular_solve',
                      'linalg_ops.triangular_solve')
  code = code.replace('math_ops.cast', '_ops.cast')
  code = code.replace('math_ops.matmul', '_linalg.matmul')
  # Done after the `array_ops.*` renames above, so the `array_ops.range`
  # produced here is left as is.
  code = code.replace('math_ops.range', 'array_ops.range')
  code = code.replace('ops.convert_to_tensor_v2_with_dispatch(',
                      'ops.convert_to_tensor(')
  code = code.replace('ops.convert_to_tensor(dim_value)',
                      'np.array(dim_value, np.int32)')

  # The `self.dtype` form is rewritten first so that the generic rule below
  # does not split it into `self.dtypes.real_dtype(dtype)`.
  code = code.replace('self.dtype.real_dtype', 'dtypes.real_dtype(self.dtype)')
  code = code.replace('dtype.real_dtype', 'dtypes.real_dtype(dtype)')
  code = code.replace('.as_numpy_dtype', '')

  # Replace `x.set_shape(...)` with `tensorshape_util.set_shape(x, ...)`.
  code = re.sub(r' (\w*)\.set_shape\(',
                ' tensorshape_util.set_shape(\\1, ', code)

  # Replace in-place Python operators (e.g. `+=`) with implicit copying.
  code = re.sub(r'([_a-zA-Z0-9.\[\]]+)[ ]{0,1}(\+|\-|\*|\/)[\=][ ]{0,1}',
                '\\1 = \\1 \\2 ', code)

  # Any `pylint: enable=<lint>` in the source would undo the file-level
  # disables printed below, so turn those into disables as well.
  for lint in DISABLED_LINTS:
    code = code.replace('pylint: enable={}'.format(lint),
                        'pylint: disable={}'.format(lint))

  # Emit the generated module: banner, lint pragmas, rewritten source, and then
  # the imports for the names introduced by the rewrites above.
  print('# Copyright 2020 The TensorFlow Probability Authors. '
        'All Rights Reserved.')
  print('# ' + '@' * 78)
  print('# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`.')
  print('# DO NOT MODIFY DIRECTLY.')
  print('# ' + '@' * 78)
  for lint in DISABLED_LINTS:
    print('# pylint: disable={}'.format(lint))
  print()
  print(code)
  print('import numpy as np')
  print('from tensorflow_probability.python.internal.backend.numpy import '
        'linalg_impl as _linalg')
  print('from tensorflow_probability.python.internal.backend.numpy import '
        'ops as _ops')
  print('from tensorflow_probability.python.internal.backend.numpy.gen import '
        'tensor_shape')
  if module_name == 'linear_operator_util':
    print(LINOP_UTIL_SUFFIX)
  print(UTIL_IMPORTS)