def _binary_op()

in xformers/sparse/csr_tensor.py [0:0]


    def _binary_op(cls, func, arg0, arg1):
        """Apply an elementwise binary ``func`` to CSR tensors and/or scalars.

        Each operand may be an instance of ``cls`` or a Python ``int``/``float``.
        For ``cls`` operands, ``func`` is applied to the stored values
        (``__values``) and the result is wrapped back into ``cls`` using the
        shape and sparsity metadata of a ``cls`` operand.

        Args:
            func: Binary callable, invoked as ``func(v0, v1)`` where each ``v`` is
                either the operand's stored values or the scalar itself.
            arg0: First operand (``cls`` instance, ``int`` or ``float``).
            arg1: Second operand (``cls`` instance, ``int`` or ``float``).

        Returns:
            A new instance built via ``cls._wrap``, or ``NotImplemented`` if an
            operand has an unsupported type or neither operand is a ``cls``
            instance.

        Raises:
            NotImplementedError: If both operands are ``cls`` instances that do
                not share the same sparsity pattern (same ``values`` shape and
                the very same ``row_offsets`` / ``column_indices`` objects).
        """
        supported = (cls, int, float)
        if not (isinstance(arg0, supported) and isinstance(arg1, supported)):
            return NotImplemented
        is_sparse0 = isinstance(arg0, cls)
        is_sparse1 = isinstance(arg1, cls)
        if not (is_sparse0 or is_sparse1):
            # Two plain scalars: there is no sparsity pattern to wrap a result in.
            return NotImplemented
        v0 = arg0.__values if is_sparse0 else arg0
        v1 = arg1.__values if is_sparse1 else arg1
        if is_sparse0 and is_sparse1:
            # NOTE: arg0.shape == arg1.shape is not checked here.
            # Object identity of the index tensors is a fast but conservative
            # approximation of "same sparsity pattern" (TODO: equal-but-distinct
            # index tensors are rejected too). Identity also implies equal index
            # shapes, so only the values' shape needs an explicit comparison.
            if (
                arg0.__values.shape != arg1.__values.shape
                or arg0.__row_offsets is not arg1.__row_offsets
                or arg0.__column_indices is not arg1.__column_indices
            ):
                raise NotImplementedError(
                    f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)"
                )
        # arg0 may be a scalar (e.g. func(2, csr)), so take the output's shape and
        # sparsity metadata from whichever operand is actually a CSR tensor.
        template = arg0 if is_sparse0 else arg1
        out = func(v0, v1)
        return cls._wrap(
            template.shape,
            out,
            template.__row_indices,
            template.__row_offsets,
            template.__column_indices,
            template.__transp_info,
        )