in jax_md/smap.py [0:0]
def pair(fn: Callable[..., Array],
         displacement_or_metric: DisplacementOrMetricFn,
         species: Optional[Array]=None,
         reduce_axis: Optional[Tuple[int, ...]]=None,
         keepdims: bool=False,
         ignore_unused_parameters: bool=False,
         **kwargs) -> Callable[..., Array]:
  """Promotes a function that acts on a pair of particles to one on a system.

  Args:
    fn: A function that takes an ndarray of pairwise distances or displacements
      of shape [n, m] or [n, m, d_in] respectively as well as kwargs specifying
      parameters for the function. fn returns an ndarray of evaluations of shape
      [n, m, d_out].
    displacement_or_metric: A function that takes two ndarrays of positions of
      shape [spatial_dimension] and [spatial_dimension] respectively and
      returns an ndarray of distances or displacements of shape [] or [d_in]
      respectively. The metric can optionally take a floating point time as a
      third argument.
    species: Species information for the different particles. This should
      either be None (in which case it is assumed that all the particles have
      the same species), an integer ndarray of shape [n] with species data, or
      an integer, in which case the species data will be specified dynamically
      when calling the mapped function, with `species` giving the maximum
      number of types of particles. Note that dynamic species specification is
      less efficient, because we cannot specialize shape information.
    reduce_axis: A list of axes to reduce over. This is supplied to jnp.sum and
      so the same convention is used. Not supported when `species` is an
      ndarray.
    keepdims: A boolean specifying whether the empty dimensions should be kept
      upon reduction. This is supplied to jnp.sum and so the same convention is
      used. Not supported when `species` is an ndarray.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.pair(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either 1) a
      scalar, 2) an ndarray of shape [n], 3) an ndarray of shape [n, n],
      4) a binary function that determines how per-particle parameters are to
      be combined, or 5) a binary function as well as a default set of
      parameters as in 2). If unspecified then this is taken to be the average
      of the two per-particle parameters. If species information is provided
      then the parameters should be specified as either 1) a scalar or 2) an
      ndarray of shape [max_species, max_species].

  Returns:
    A function fn_mapped.

    If species is None or statically specified then fn_mapped takes as
    arguments an ndarray of positions of shape [n, spatial_dimension].

    If species is dynamic then fn_mapped takes as input an ndarray of shape
    [n, spatial_dimension] and an integer ndarray of species of shape [n]. The
    maximum number of species is the integer that was passed as `species` to
    this function.

    The mapped function can also optionally take keyword arguments. These are
    threaded through the displacement / metric function and are also merged
    with `kwargs` (via `util.merge_dicts`) to form the parameters of `fn`.

  Raises:
    ValueError: If `species` is an ndarray and `reduce_axis` or `keepdims` is
      set, or if `species` is not None, an ndarray, or an integer.
  """
  # Each application of vmap adds a single batch dimension. For computations
  # over all pairs of particles, we would like to promote the metric function
  # from one that computes the displacement / distance between two vectors to
  # one that acts over the cartesian product of two sets of vectors. This is
  # equivalent to two applications of vmap adding one batch dimension for the
  # first set and then one for the second.
  kwargs, param_combinators = _split_params_and_combinators(kwargs)

  merge_dicts = partial(util.merge_dicts,
                        ignore_unused_parameters=ignore_unused_parameters)

  if species is None:
    def fn_mapped(R: Array, **dynamic_kwargs) -> Array:
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      _kwargs = _kwargs_to_parameters(None, _kwargs, param_combinators)
      dr = d(R, R)
      # NOTE(schsam): Currently we place a diagonal mask no matter what function
      # we are mapping. Should this be an option?
      # The [n, n] product visits every unordered pair twice, hence the 0.5.
      return high_precision_sum(_diagonal_mask(fn(dr, **_kwargs)),
                                axis=reduce_axis, keepdims=keepdims) * f32(0.5)
  elif util.is_array(species):
    species = onp.array(species)
    _check_species_dtype(species)
    species_count = int(onp.max(species))
    if reduce_axis is not None or keepdims:
      # TODO(schsam): Support reduce_axis with static species.
      raise ValueError(
        'reduce_axis and keepdims are not supported when species is specified '
        'statically as an array. Found reduce_axis={} and keepdims={}.'.format(
          reduce_axis, keepdims))
    def fn_mapped(R, **dynamic_kwargs):
      U = f32(0.0)
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      # The merged parameters do not depend on the species pair, so merge them
      # once instead of once per (i, j).
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      # `species` was converted to a concrete numpy array above, so each
      # boolean mask selects a fixed subset of particles. Select each species'
      # positions once and reuse them for every pair involving that species.
      R_by_species = [R[species == s] for s in range(species_count + 1)]
      for i in range(species_count + 1):
        for j in range(i, species_count + 1):
          s_kwargs = _kwargs_to_parameters((i, j), _kwargs, param_combinators)
          dr = d(R_by_species[i], R_by_species[j])
          if j == i:
            # Same-species block: every pair appears twice and the diagonal
            # holds self-pairs, so apply the diagonal mask and halve.
            dU = high_precision_sum(_diagonal_mask(fn(dr, **s_kwargs)))
            U = U + f32(0.5) * dU
          else:
            # Cross-species block with i < j: every pair appears exactly once.
            dU = high_precision_sum(fn(dr, **s_kwargs))
            U = U + dU
      return U
  elif isinstance(species, int):
    species_count = species
    def fn_mapped(R, species, **dynamic_kwargs):
      _check_species_dtype(species)
      U = f32(0.0)
      N = R.shape[0]
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      dr = d(R, R)
      # One [N] indicator vector per species, built once and reused for every
      # (i, j) pair below.
      masks = [jnp.array(jnp.reshape(species == s, (N,)), dtype=R.dtype)
               for s in range(species_count)]
      for i in range(species_count):
        for j in range(species_count):
          s_kwargs = _kwargs_to_parameters((i, j), _kwargs, param_combinators)
          mask = masks[i][:, jnp.newaxis] * masks[j][jnp.newaxis, :]
          if i == j:
            # Exclude self-pairs within a species.
            mask = mask * _diagonal_mask(mask)
          dU = mask * fn(dr, **s_kwargs)
          U = U + high_precision_sum(dU, axis=reduce_axis, keepdims=keepdims)
      # Summing over all ordered (i, j) species pairs on the full [N, N]
      # product counts every particle pair twice.
      return U / f32(2.0)
  else:
    raise ValueError(
      'Species must be None, an ndarray, or an integer. Found {}.'.format(
        species))
  return fn_mapped