def pair_neighbor_list()

in jax_md/smap.py [0:0]


def pair_neighbor_list(fn: Callable[..., Array],
                       displacement_or_metric: DisplacementOrMetricFn,
                       species: Optional[Array]=None,
                       reduce_axis: Optional[Tuple[int, ...]]=None,
                       ignore_unused_parameters: bool=False,
                       **kwargs) -> Callable[..., Array]:
  """Promotes a function acting on pairs of particles to use neighbor lists.

  Args:
    fn: A function that takes an ndarray of pairwise distances or displacements
      of shape [n, m] or [n, m, d_in] respectively as well as kwargs specifying
      parameters for the function. fn returns an ndarray of evaluations of
      shape [n, m, d_out].
    displacement_or_metric: A function that takes two ndarray of positions of
      shape [spatial_dimension] and [spatial_dimension] respectively and
      returns an ndarray of distances or displacements of shape [] or [d_in]
      respectively. The metric can optionally take a floating point time as a
      third argument.
    species: Species information for the different particles. Should either
      be None (in which case it is assumed that all the particles have the same
      species), or an integer array of shape [n] with species data. Note that
      species data can be specified dynamically by passing a `species` keyword
      argument to the mapped function.
    reduce_axis: A list of axes to reduce over. We use a convention where axis
      0 corresponds to the particles, axis 1 corresponds to neighbors, and the
      remaining axes correspond to the output axes of `fn`. Note that it is not
      well-defined to sum over particles without summing over neighbors. One
      also cannot report per-particle values (excluding axis 0) for neighbor
      lists whose format is `OrderedSparse`. If None, all axes are summed.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.pair_neighbor_list(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either 1) a
      scalar, 2) an ndarray of shape [n], 3) an ndarray of shape [n, n],
      4) a binary function that determines how per-particle parameters are to
      be combined. If unspecified then this is taken to be the average of the
      two per-particle parameters. If species information is provided then the
      parameters should be specified as either 1) a scalar or 2) an ndarray of
      shape [max_species, max_species].

  Returns:
    A function fn_mapped that takes an ndarray of floats of positions (one row
    per particle) and a `partition.NeighborList` (whose `idx` and `format`
    attributes are used) specifying neighbors, plus optional keyword arguments
    that are forwarded to `displacement_or_metric` and merged into the
    parameters passed to `fn`.

  Raises:
    ValueError: Raised by the mapped function if `reduce_axis` contains axis 0
      but not axis 1, or if per-particle values are requested from a neighbor
      list whose format is `OrderedSparse`.
  """
  kwargs, param_combinators = _split_params_and_combinators(kwargs)
  merge_dicts = partial(util.merge_dicts,
                        ignore_unused_parameters=ignore_unused_parameters)

  def fn_mapped(R: Array, neighbor: partition.NeighborList, **dynamic_kwargs
                ) -> Array:
    """Evaluates `fn` over the pairs in `neighbor` and reduces the result."""
    d = partial(displacement_or_metric, **dynamic_kwargs)
    # A `species` passed at call time takes precedence over the static one.
    _species = dynamic_kwargs.get('species', species)

    # Divisor applied to reduced results; 1.0 only for `OrderedSparse`.
    # NOTE(review): presumably Dense and Sparse lists contain every pair twice
    # while OrderedSparse contains each pair once -- confirm in `partition`.
    normalization = 2.0

    is_sparse = partition.is_sparse(neighbor.format)

    if is_sparse:
      d = space.map_bond(d)
      dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
      # Entries whose index is >= the particle count are zeroed out below.
      mask = neighbor.idx[0] < R.shape[0]
      if neighbor.format is partition.OrderedSparse:
        normalization = 1.0
    else:
      d = space.map_neighbor(d)
      R_neigh = R[neighbor.idx]
      dR = d(R, R_neigh)
      mask = neighbor.idx < R.shape[0]

    merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
    merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                   neighbor.idx,
                                                   _species,
                                                   merged_kwargs,
                                                   param_combinators)
    out = fn(dR, **merged_kwargs)
    if out.ndim > mask.ndim:
      # Append trailing singleton axes so the mask broadcasts against the
      # output axes of `fn`.
      ddim = out.ndim - mask.ndim
      mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
    out *= mask

    if reduce_axis is None:
      return util.high_precision_sum(out) / normalization

    if 0 in reduce_axis and 1 not in reduce_axis:
      raise ValueError('Cannot reduce over particles (axis 0) without also '
                       'reducing over neighbors (axis 1); got '
                       f'reduce_axis={reduce_axis}.')

    if not is_sparse:
      return util.high_precision_sum(out, reduce_axis) / normalization

    # In the sparse layout the particle and neighbor axes are collapsed into a
    # single leading axis, so output axes (>= 2) shift down by one.
    _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

    if 0 in reduce_axis:
      # Apply the same normalization as every other return path so that the
      # sparse result agrees with the dense result for the same `reduce_axis`.
      total = util.high_precision_sum(out, (0,) + _reduce_axis)
      return total / normalization

    if neighbor.format is partition.OrderedSparse:
      raise ValueError('Cannot report per-particle values with a neighbor '
                       'list whose format is `OrderedSparse`. Please use '
                       'either `Dense` or `Sparse`.')

    # Per-particle values: accumulate each entry onto the particle given by
    # the first row of `neighbor.idx`.
    out = util.high_precision_sum(out, _reduce_axis)
    return ops.segment_sum(out, neighbor.idx[0], R.shape[0]) / normalization
  return fn_mapped