def pmap()

in trax/tf_numpy/extensions/extensions.py [0:0]


def pmap(f, axis_name=None, devices=None):
  """Transforms a function into a multi-device function.

  The semantics are similar to JAX's pmap.

  Args:
    f: The function to be converted.
    axis_name: Used for nested pmap, which is not supported yet.
    devices: The devices over which the returned function will run. Defaults
      to `accelerators()` when `None`.

  Returns:
    A function that runs the underlying function `f` on `devices`. Its arguments
    can be `ShardedNdArray`s, tensors or other Python objects, and its return
    values are all `ShardedNdArray`s. If an input is a tensor, the length of its
    first dimension must equal the number of devices, and the tensor will be
    split along its first dimension among the devices. If an input is an
    unknown Python object, it will be replicated among the devices.

  Raises:
    ValueError: If `devices` is not a list or tuple, or is empty. The returned
      function raises `ValueError` when called inside another pmap, when a
      tensor input's first dimension does not equal the number of devices,
      when a `ShardedNdArray` input does not have exactly one shard per
      device, when the per-device results of `f` have different structures,
      or when a result leaf is not a `tf.Tensor`.
  """
  if devices is None:
    devices = accelerators()
  if not isinstance(devices, (list, tuple)):
    raise ValueError("Must pass a list or tuple of devices")
  num_devices = len(devices)
  if not num_devices:
    raise ValueError("There must be at least 1 device")
  has_tpu = bool(tpu_devices(devices))

  pmap_fn = _get_pmap_impl(f, devices, has_tpu)

  def wrapper(*args):
    """Wrapper that wraps/unwraps args, retvals, and runs the function."""
    if _pmap_config.devices() is not None:
      raise ValueError("Found a surrounding pmap. Nested pmap is not supported "
                       "yet.")
    # TODO(wangpeng): Maybe we should use `asarray` to convert everything
    # to ndarray first.

    # One flat list of leaves per device; leaf k of every list corresponds to
    # leaf k of the flattened `args`, so each list can be re-nested like `args`.
    flattened_per_device_args = [[] for _ in range(num_devices)]
    for arg in tf.nest.flatten(args):
      if isinstance(arg, tf.Tensor):
        # TODO(nareshmodi): Try and use the dynamic shape instead.
        if (not arg.shape.rank) or arg.shape[0] != num_devices:
          # TODO(nareshmodi): Fix this restriction
          raise ValueError(
              "Input tensors need to have a first dimension equal to "
              "the number of devices; got tensor of shape %s and %s devices" %
              (arg.shape, num_devices))
        # NOTE: Alternatively use tf.split, and place the split tensors on the
        # appropriate device. The best solution for this is to have an API that
        # splits a tensor across devices.
        for j, device in enumerate(devices):
          updated_arg = tf.gather(arg, j)
          # TODO(wangpeng): Investigate whether we need a tf.identity for TPU.
          if not has_tpu:
            # Copy the slice onto its target device.
            with tf.device(device):
              updated_arg = tf.identity(updated_arg)
          flattened_per_device_args[j].append(updated_arg)
      elif isinstance(arg, ShardedNdArray):
        # `zip` below would silently drop shards (or leave some devices without
        # this argument) if the shard count differed, so check it up front.
        if len(arg.tensors) != num_devices:
          raise ValueError(
              "ShardedNdArray inputs need to have one shard per device; got "
              "%s shards and %s devices" % (len(arg.tensors), num_devices))
        for device_args, tensor in zip(flattened_per_device_args, arg.tensors):
          device_args.append(tensor)
      else:
        # Unknown Python objects are replicated on every device.
        for device_args in flattened_per_device_args:
          device_args.append(arg)

    all_per_device_args = [
        tf.nest.pack_sequence_as(args, device_args)
        for device_args in flattened_per_device_args
    ]

    with pmap_config(axis_name, devices):
      results = pmap_fn(all_per_device_args)

    # Every device must return the same nested structure; otherwise the leaves
    # cannot be matched up across devices when building ShardedNdArrays.
    # check_types=False tolerates e.g. list-vs-tuple differences and means
    # only ValueError can be raised here.
    for result in results:
      tf.nest.assert_same_structure(results[0], result, check_types=False)

    # Transpose per-device leaf lists into per-leaf shard groups and wrap each
    # group (one tensor per device) in a ShardedNdArray.
    flattened_results = [tf.nest.flatten(result) for result in results]
    final_tree = []
    for shards in zip(*flattened_results):
      for shard in shards:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(shard, tf.Tensor):
          raise ValueError(
              "Currently only tensor return items are supported; got %r" %
              (type(shard),))
      final_tree.append(ShardedNdArray(list(shards)))

    return tf.nest.pack_sequence_as(results[0], final_tree)

  return wrapper