def pair()

in jax_md/smap.py [0:0]


def pair(fn: Callable[..., Array],
         displacement_or_metric: DisplacementOrMetricFn,
         species: Optional[Array]=None,
         reduce_axis: Optional[Tuple[int, ...]]=None,
         keepdims: bool=False,
         ignore_unused_parameters: bool=False,
         **kwargs) -> Callable[..., Array]:
  """Promotes a function that acts on a pair of particles to one on a system.

  Args:
    fn: A function that takes an ndarray of pairwise distances or displacements
      of shape [n, m] or [n, m, d_in] respectively as well as kwargs specifying
      parameters for the function. fn returns an ndarray of evaluations of shape
      [n, m, d_out].
    displacement_or_metric: A function that takes two ndarrays of positions of
      shape [spatial_dimension] and [spatial_dimension] respectively and returns
      an ndarray of distances or displacements of shape [] or [d_in]
      respectively. Every dynamic keyword argument passed to the mapped
      function (for example a floating point time, if the metric accepts one)
      is forwarded to this function as a keyword argument.
    species: A list of species for the different particles. This should either
      be None (in which case it is assumed that all the particles have the same
      species), an integer ndarray of shape [n] with species data (species ids
      are iterated from 0 to max(species) inclusive), or an integer in which
      case the species data will be specified dynamically when calling the
      mapped function, with `species` giving the maximum number of types of
      particles (species ids are iterated over range(species)). Note that
      dynamic species specification is less efficient, because we cannot
      specialize shape information.
    reduce_axis: A list of axes to reduce over. This is supplied to jnp.sum and
      so the same convention is used. Not supported with static (ndarray)
      species.
    keepdims: A boolean specifying whether the empty dimensions should be kept
      upon reduction. This is supplied to jnp.sum and so the same convention is
      used. Not supported with static (ndarray) species.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.pair(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either 1) a
      scalar, 2) an ndarray of shape [n], 3) an ndarray of shape [n, n],
      4) a binary function that determines how per-particle parameters are to
      be combined, 5) a binary function as well as a default set of parameters
      as in 2). If unspecified then this is taken to be the average of the
      two per-particle parameters. If species information is provided then the
      parameters should be specified as either 1) a scalar or 2) an ndarray of
      shape [max_species, max_species].

  Returns:
    A function fn_mapped.

    If species is None or statically specified then fn_mapped takes as arguments
    an ndarray of positions of shape [n, spatial_dimension].

    If species is dynamic then fn_mapped takes as input an ndarray of positions
    of shape [n, spatial_dimension] and an integer ndarray of species of shape
    [n]. The number of species is fixed by the integer `species` passed to
    `pair`, not by an argument of fn_mapped.

    The mapped function can also optionally take keyword arguments that get
    threaded through the metric and merged with the parameters of fn.

  Raises:
    ValueError: If species is an ndarray and reduce_axis or keepdims is set,
      or if species is not None, an ndarray, or an integer.
  """

  # Each application of vmap adds a single batch dimension. For computations
  # over all pairs of particles, we would like to promote the metric function
  # from one that computes the displacement / distance between two vectors to
  # one that acts over the cartesian product of two sets of vectors. This is
  # equivalent to two applications of vmap adding one batch dimension for the
  # first set and then one for the second.

  kwargs, param_combinators = _split_params_and_combinators(kwargs)

  merge_dicts = partial(util.merge_dicts,
                        ignore_unused_parameters=ignore_unused_parameters)

  if species is None:
    def fn_mapped(R: Array, **dynamic_kwargs) -> Array:
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      _kwargs = _kwargs_to_parameters(None, _kwargs, param_combinators)
      dr = d(R, R)
      # NOTE(schsam): Currently we place a diagonal mask no matter what function
      # we are mapping. Should this be an option?
      # The sum runs over the full cartesian product, which visits every
      # unordered pair twice; the factor of 0.5 compensates.
      return high_precision_sum(_diagonal_mask(fn(dr, **_kwargs)),
                                axis=reduce_axis, keepdims=keepdims) * f32(0.5)
  elif util.is_array(species):
    species = onp.array(species)
    _check_species_dtype(species)
    species_count = int(onp.max(species))
    if reduce_axis is not None or keepdims:
      # TODO(schsam): Support reduce_axis with static species.
      raise ValueError(
          'reduce_axis and keepdims are not supported when species is '
          'specified statically as an ndarray; got reduce_axis={} and '
          'keepdims={}.'.format(reduce_axis, keepdims))
    def fn_mapped(R, **dynamic_kwargs):
      U = f32(0.0)
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      # The merged parameters do not depend on the species pair, so merge once
      # per call rather than once per (i, j).
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      for i in range(species_count + 1):
        Ra = R[species == i]
        # Only visit j >= i; the symmetric (j, i) block is not recomputed.
        for j in range(i, species_count + 1):
          s_kwargs = _kwargs_to_parameters((i, j), _kwargs, param_combinators)
          Rb = R[species == j]
          dr = d(Ra, Rb)
          if j == i:
            # Same-species block: every unordered pair appears twice and the
            # self-interaction diagonal must be removed.
            dU = high_precision_sum(_diagonal_mask(fn(dr, **s_kwargs)))
            U = U + f32(0.5) * dU
          else:
            dU = high_precision_sum(fn(dr, **s_kwargs))
            U = U + dU
      return U
  elif isinstance(species, int):
    species_count = species
    def fn_mapped(R, species, **dynamic_kwargs):
      _check_species_dtype(species)
      U = f32(0.0)
      N = R.shape[0]
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      dr = d(R, R)
      # Per-species membership indicators, built once and reused for every
      # (i, j) pair below.
      species_masks = [
          jnp.array(jnp.reshape(species == s, (N,)), dtype=R.dtype)
          for s in range(species_count)
      ]
      for i in range(species_count):
        mask_a = species_masks[i]
        for j in range(species_count):
          s_kwargs = _kwargs_to_parameters((i, j), _kwargs, param_combinators)
          mask_b = species_masks[j]
          mask = mask_a[:, jnp.newaxis] * mask_b[jnp.newaxis, :]
          if i == j:
            mask = mask * _diagonal_mask(mask)
          fn_out = fn(dr, **s_kwargs)
          # fn may return [N, N, d_out]; append trailing singleton axes so the
          # [N, N] mask broadcasts over the leading pair axes instead of being
          # aligned against the trailing output axis.
          extra_dims = jnp.ndim(fn_out) - 2
          if extra_dims > 0:
            mask = jnp.reshape(mask, mask.shape + (1,) * extra_dims)
          # NOTE(review): pairs are excluded by multiplying by zero, so a
          # non-finite value of fn (e.g. at dr == 0) would still yield NaN.
          dU = mask * fn_out
          U = U + high_precision_sum(dU, axis=reduce_axis, keepdims=keepdims)
      # Both (i, j) and (j, i) are visited, so every pair is counted twice.
      return U / f32(2.0)
  else:
    raise ValueError(
        'Species must be None, an ndarray, or an integer. Found {}.'.format(
          species))
  return fn_mapped