in xformers/sparse/csr_tensor.py [0:0]
def _binary_op(cls, func, arg0, arg1):
if not (
isinstance(arg0, (cls, int, float)) and isinstance(arg1, (cls, int, float))
):
return NotImplemented
v0, v1 = arg0, arg1
if isinstance(arg0, cls):
v0 = arg0.__values
if isinstance(arg1, cls):
v1 = arg1.__values
# assert arg0.shape == arg1.shape
if isinstance(arg0, cls) and isinstance(arg1, cls):
msg = f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)"
if not arg0.__row_offsets.shape == arg1.__row_offsets.shape:
raise NotImplementedError(msg)
if not arg0.__column_indices.shape == arg1.__column_indices.shape:
raise NotImplementedError(msg)
if not arg0.__values.shape == arg1.__values.shape:
raise NotImplementedError(msg)
# TODO this is not always true, but is a fast approximation for now
if arg0.__row_offsets is not arg1.__row_offsets:
raise NotImplementedError(msg)
if arg0.__column_indices is not arg1.__column_indices:
raise NotImplementedError(msg)
out = func(v0, v1)
return cls._wrap(
arg0.shape,
out,
arg0.__row_indices,
arg0.__row_offsets,
arg0.__column_indices,
arg0.__transp_info,
)