in jax_md/smap.py [0:0]
def pair_neighbor_list(fn: Callable[..., Array],
                       displacement_or_metric: DisplacementOrMetricFn,
                       species: Optional[Array]=None,
                       reduce_axis: Optional[Tuple[int, ...]]=None,
                       ignore_unused_parameters: bool=False,
                       **kwargs) -> Callable[..., Array]:
  """Promotes a function acting on pairs of particles to use neighbor lists.

  Args:
    fn: A function that takes an ndarray of pairwise distances or displacements
      of shape [n, m] or [n, m, d_in] respectively as well as kwargs specifying
      parameters for the function. fn returns an ndarray of evaluations of
      shape [n, m, d_out].
    displacement_or_metric: A function that takes two ndarrays of positions of
      shape [spatial_dimension] and [spatial_dimension] respectively and
      returns an ndarray of distances or displacements of shape [] or [d_in]
      respectively. The metric can optionally take a floating point time as a
      third argument. All keyword arguments passed to the mapped function are
      forwarded to it.
    species: Species information for the different particles. Should be
      either None (in which case it is assumed that all the particles have the
      same species) or an integer array of shape [n] with species data. Note
      that species data can be specified dynamically by passing a `species`
      keyword argument to the mapped function.
    reduce_axis: A sequence of axes to reduce over. We use a convention where
      axis 0 corresponds to the particles, axis 1 corresponds to neighbors, and
      the remaining axes correspond to the output axes of `fn`. Note that it is
      not well-defined to sum over particles without summing over neighbors.
      One also cannot report per-particle values (excluding axis 0) for
      neighbor lists whose format is `OrderedSparse`.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.pair_neighbor_list(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either 1) a
      scalar, 2) an ndarray of shape [n], 3) an ndarray of shape [n, n], or
      4) a binary function that determines how per-particle parameters are to
      be combined. If unspecified then this is taken to be the average of the
      two per-particle parameters. If species information is provided then the
      parameters should be specified as either 1) a scalar or 2) an ndarray of
      shape [max_species, max_species].

  Returns:
    A function `fn_mapped(R, neighbor, **dynamic_kwargs)`.
    - `R` is an ndarray of floats of shape [N, d_in] holding positions.
    - `neighbor` is a `partition.NeighborList`; for the `Dense` format its
      `idx` is an ndarray of integers of shape [N, max_neighbors].

    Neighbor entries whose index is >= N are masked out of the output of
    `fn`. The masked output is then reduced according to `reduce_axis`.

    The reduced result is divided by 2 for `Dense` and `Sparse` neighbor
    lists and by 1 for `OrderedSparse` ones, presumably because the former
    store each pair in both directions. This normalization is applied
    consistently on every reduction path.

    `fn_mapped` raises `ValueError` in two cases:
    - `reduce_axis` contains 0 but not 1;
    - per-particle values are requested for an `OrderedSparse` neighbor list.
  """
  kwargs, param_combinators = _split_params_and_combinators(kwargs)

  merge_dicts = partial(util.merge_dicts,
                        ignore_unused_parameters=ignore_unused_parameters)

  if reduce_axis is not None:
    # Accept any sequence of ints (e.g. a list) and treat it exactly like the
    # equivalent tuple in the membership tests and reductions below.
    reduce_axis = tuple(reduce_axis)

  def fn_mapped(R: Array, neighbor: partition.NeighborList, **dynamic_kwargs
                ) -> Array:
    d = partial(displacement_or_metric, **dynamic_kwargs)
    _species = dynamic_kwargs.get('species', species)

    # Divisor applied to every reduced result; overridden to 1 below for
    # OrderedSparse neighbor lists.
    normalization = 2.0

    if partition.is_sparse(neighbor.format):
      # Sparse formats: idx[0] / idx[1] hold the two endpoints of each pair.
      d = space.map_bond(d)
      dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
      # Entries whose index is >= N (presumably padding) are masked out.
      mask = neighbor.idx[0] < R.shape[0]
      if neighbor.format is partition.OrderedSparse:
        normalization = 1.0
    else:
      # Dense format: idx has one row of neighbor indices per particle.
      d = space.map_neighbor(d)
      R_neigh = R[neighbor.idx]
      dR = d(R, R_neigh)
      mask = neighbor.idx < R.shape[0]

    merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
    merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                   neighbor.idx,
                                                   _species,
                                                   merged_kwargs,
                                                   param_combinators)
    out = fn(dR, **merged_kwargs)

    if out.ndim > mask.ndim:
      # Broadcast the mask across the trailing output axes of `fn`.
      ddim = out.ndim - mask.ndim
      mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
    out *= mask

    if reduce_axis is None:
      return util.high_precision_sum(out) / normalization

    if 0 in reduce_axis and 1 not in reduce_axis:
      raise ValueError(
          'Cannot reduce over particles (axis 0) without also reducing over '
          'neighbors (axis 1); got reduce_axis={}.'.format(reduce_axis))

    if not partition.is_sparse(neighbor.format):
      return util.high_precision_sum(out, reduce_axis) / normalization

    # In sparse formats the particle and neighbor axes (0 and 1) are merged
    # into the single pair axis 0 of `out`, so fn's output axes shift down by
    # one.
    _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

    if 0 in reduce_axis:
      # Normalize here too, so the total matches the Dense path and the
      # reduce_axis=None path (previously this return skipped the division).
      return (util.high_precision_sum(out, (0,) + _reduce_axis) /
              normalization)

    if neighbor.format is partition.OrderedSparse:
      raise ValueError('Cannot report per-particle values with a neighbor '
                       'list whose format is `OrderedSparse`. Please use '
                       'either `Dense` or `Sparse`.')

    # Per-particle values: accumulate each pair's contribution onto the
    # particle given by idx[0].
    out = util.high_precision_sum(out, _reduce_axis)
    return ops.segment_sum(out, neighbor.idx[0], R.shape[0]) / normalization

  return fn_mapped