in jax_md/partition.py [0:0]
def cell_list(box_size: Box,
              minimum_cell_size: float,
              buffer_size_multiplier: float = 1.25
              ) -> CellListFns:
  r"""Returns a function that partitions point data spatially.

  Given a set of points {x_i \in R^d} with associated data {k_i \in R^m} it is
  often useful to partition the points / data spatially. A simple partitioning
  that can be implemented efficiently within XLA is a dense partition into a
  uniform grid called a cell list.

  Since XLA requires that shapes be statically specified inside of a JIT block,
  the cell list code can operate in two modes: allocation and update.

  Allocation creates a new cell list that uses a set of input positions to
  estimate the capacity of the cell list. This capacity can be adjusted by
  setting the `buffer_size_multiplier` here or by passing `extra_capacity` to
  the allocate function. Allocation cannot be JIT.

  Updating takes a previously allocated cell list and places a new set of
  particles in the cells. Updating cannot resize the cell list and is therefore
  compatible with JIT. However, if the configuration has changed substantially
  it is possible that the existing cell list won't be large enough to
  accommodate all of the particles. In this case the `did_buffer_overflow` bit
  will be set to True.

  Args:
    box_size: A float or an ndarray of shape [spatial_dimension] specifying the
      size of the system. Note, this code is written for the case where the
      boundaries are periodic. If this is not the case, then the current code
      will be slightly less efficient.
    minimum_cell_size: A float specifying the minimum side length of each cell.
      Cells are enlarged so that they exactly fill the box.
    buffer_size_multiplier: A floating point multiplier that multiplies the
      estimated cell capacity to allow for fluctuations in the maximum cell
      occupancy.

  Returns:
    A CellListFns object that contains two methods, one to allocate the cell
    list and one to update the cell list. The update function can be called
    with either a cell list from which the capacity can be inferred or with
    an explicit integer denoting the capacity. Note that an existing cell list
    can also be updated by calling `cell_list.update(position)`.
  """
  if util.is_array(box_size):
    box_size = onp.array(box_size)
    if len(box_size.shape) == 1:
      # Promote a per-dimension box to shape [1, spatial_dimension].
      box_size = jnp.reshape(box_size, (1, -1))

  if util.is_array(minimum_cell_size):
    minimum_cell_size = onp.array(minimum_cell_size)

  def cell_list_fn(position: Array,
                   capacity_overflow_update: Optional[
                       Tuple[int, bool, Callable[..., CellList]]] = None,
                   extra_capacity: int = 0, **kwargs) -> CellList:
    """Bins `position` and any per-particle `kwargs` into the cell grid.

    Args:
      position: An ndarray indexed as [particle, spatial_dimension]. The
        spatial dimension must be 2 or 3.
      capacity_overflow_update: None to allocate, in which case the cell
        capacity is estimated from `position`. Otherwise a tuple
        `(cell_capacity, did_buffer_overflow, update_fn)` that is reused
        instead of estimating a capacity.
      extra_capacity: Added to the estimated capacity. Only used when
        allocating.
      **kwargs: Per-particle ndarrays whose leading dimension equals the
        number of particles. Each is binned alongside the positions.

    Returns:
      A CellList holding the binned positions, particle ids and kwargs, the
      overflow flag, the cell capacity and the update function.

    Raises:
      ValueError: If the spatial dimension is not 2 or 3, if a kwarg is not an
        array, or if a kwarg's leading dimension differs from the number of
        particles.
    """
    N = position.shape[0]
    dim = position.shape[1]

    if dim != 2 and dim != 3:
      # NOTE(schsam): Do we want to check this in compute_fn as well?
      raise ValueError(
          f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    if capacity_overflow_update is None:
      # Allocation: size the per-cell buffers from the current configuration.
      cell_capacity = _estimate_cell_capacity(position, box_size, cell_size,
                                              buffer_size_multiplier)
      cell_capacity += extra_capacity
      overflow = False
      update_fn = cell_list_fn
    else:
      # Update: reuse the capacity (and any earlier overflow) handed to us.
      cell_capacity, overflow, update_fn = capacity_overflow_update

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(i32, N)
    # NOTE(schsam): We use the convention that particles that are successfully
    # copied have their true id whereas empty slots have id = N.
    # Then when we copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                              dtype=position.dtype)
    cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10 ** 5
    cell_kwargs = {}
    # pytype: disable=attribute-error
    for k, v in kwargs.items():
      if not util.is_array(v):
        raise ValueError((f'Data must be specified as an ndarray. Found "{k}" '
                          f'with type {type(v)}.'))
      if v.shape[0] != position.shape[0]:
        raise ValueError(('Data must be specified per-particle (an ndarray '
                          f'with shape ({N}, ...)). Found "{k}" with '
                          f'shape {v.shape}.'))
      # Scalar-per-particle data is stored with a trailing axis of size 1.
      kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
      cell_kwargs[k] = empty_kwarg_value * jnp.ones(
          (cell_count * cell_capacity,) + kwarg_shape, v.dtype)
    # pytype: enable=attribute-error

    # Integer grid coordinates of each particle, flattened to one hash per
    # particle.
    indices = jnp.array(position / cell_size, dtype=i32)
    hashes = jnp.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where
    # grid_id is a flat list that repeats 0, ..., cell_capacity - 1. So long as
    # there are at most cell_capacity particles per cell, each particle is
    # guaranteed to get a cell id that is unique.
    sort_map = jnp.argsort(hashes)
    sorted_position = position[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_kwargs = {k: v[sort_map] for k, v in kwargs.items()}

    sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
    sorted_id = jnp.reshape(sorted_id, (N, 1))
    cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
    cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
      if v.ndim == 1:
        v = jnp.reshape(v, v.shape + (1,))
      cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
      cell_kwargs[k] = _unflatten_cell_buffer(
          cell_kwargs[k], cells_per_side, dim)

    occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
    max_occupancy = jnp.max(occupancy)
    # Particles that share a hash are contiguous after the sort, so
    # `iota % cell_capacity` hands out distinct slots to as many as
    # `cell_capacity` of them. A cell that is exactly full is therefore stored
    # correctly; data is only lost once the occupancy strictly exceeds the
    # capacity.
    overflow = overflow | (max_occupancy > cell_capacity)

    return CellList(cell_position, cell_id, cell_kwargs,
                    overflow, cell_capacity, update_fn)  # pytype: disable=wrong-arg-count

  def allocate_fn(position: Array, extra_capacity: int = 0, **kwargs
                  ) -> CellList:
    """Allocates a new cell list, estimating its capacity from `position`."""
    return cell_list_fn(position, extra_capacity=extra_capacity, **kwargs)

  def update_fn(position: Array, cl_or_capacity: Union[CellList, int], **kwargs
                ) -> CellList:
    """Rebuilds a cell list with a fixed capacity.

    Args:
      position: An ndarray of particle positions.
      cl_or_capacity: Either an existing CellList, whose capacity, overflow
        flag and update function are reused, or an integer capacity (Python
        or NumPy integer), in which case the overflow flag starts at False.
      **kwargs: Per-particle ndarrays to bin alongside the positions.

    Returns:
      A CellList built from `position` with the given capacity.
    """
    # Accept NumPy integer scalars as well as Python ints; otherwise they
    # would be mistaken for a CellList below.
    if isinstance(cl_or_capacity, (int, onp.integer)):
      capacity = int(cl_or_capacity)
      return cell_list_fn(position, (capacity, False, cell_list_fn), **kwargs)
    cl = cl_or_capacity
    cl_data = (cl.cell_capacity, cl.did_buffer_overflow, cl.update_fn)
    return cell_list_fn(position, cl_data, **kwargs)

  return CellListFns(allocate_fn, update_fn)  # pytype: disable=wrong-arg-count