def __torch_function__()

in xformers/sparse/csr_tensor.py [0:0]


    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route torch operations on this sparse CSR tensor to sparse-aware code.

        Operations with a dedicated implementation are dispatched to the
        class's private helpers:

        - matmul / bmm
        - softmax
        - transpose
        - ``masked_matmul``
        - add / mul / logical_and
        - dropout
        - ``Tensor.to``, ``copy_``, ``equal``, ``to_dense``
        - ``detach``, ``__deepcopy__``, and the ``.grad`` getter

        Any other function runs with torch-function dispatch disabled. Its
        result is converted back to ``cls`` unless torch lists the function
        as a default "nowrap" function.

        Args:
            func: the torch function or ``Tensor`` method being invoked.
            types: tensor subclasses participating in the call (not used here).
            args: positional arguments of the original call.
            kwargs: keyword arguments of the original call, or ``None``.

        Returns:
            One of three things: the result of the sparse-aware helper;
            ``None`` from the ``.grad`` getter when the stored values have no
            gradient; or the (possibly converted) result of the default torch
            implementation.
        """
        if kwargs is None:
            kwargs = {}

        if func in [
            torch.Tensor.bmm,
            torch.bmm,
            torch.Tensor.__matmul__,
            torch.matmul,
            torch.Tensor.matmul,
        ]:
            assert len(args) == 2
            return cls._bmm(args[0], args[1])

        if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]:
            # F.softmax forwards ``dim`` as a keyword, but calls such as
            # ``x.softmax(-1)`` or ``torch.softmax(x, -1)`` pass it positionally.
            dim = kwargs["dim"] if "dim" in kwargs else args[1]
            return cls._softmax(args[0], dim)

        if func in [torch.Tensor.transpose, torch.transpose]:
            assert len(kwargs) == 0
            return cls._transpose(args[0], args[1], args[2])

        if func == masked_matmul:
            assert len(args) == 3
            return cls._masked_matmul(args[0], args[1], args[2])

        if func in [
            torch.Tensor.add,
            torch.add,
            torch.Tensor.__add__,
        ]:
            assert len(args) == 2
            # Addition is only supported between two tensors of this class.
            if not (isinstance(args[0], cls) and isinstance(args[1], cls)):
                raise NotImplementedError(
                    f"{func} with {type(args[0])} and {type(args[1])} not implemented"
                )
            return cls._binary_op(func, args[0], args[1])

        if func in [
            torch.Tensor.mul,
            torch.mul,
            torch.Tensor.__mul__,
        ]:
            assert len(args) == 2
            return cls._binary_op(func, args[0], args[1])

        if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]:
            assert len(args) == 2
            return cls._binary_op_slow(func, args[0], args[1])

        if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]:
            x = args[0]
            # Clone first so an in-place variant (e.g. torch.dropout_) does not
            # mutate the values of the input tensor.
            values = x.__values.clone()
            values = func(values, *args[1:], **kwargs)
            # The sparsity pattern is unchanged; only the values are replaced.
            return cls._wrap(
                x.shape,
                values,
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.to:
            # Forward the first conversion argument to ``_to``; accept it either
            # positionally (``x.to("cuda")``) or as ``device=``.
            # NOTE(review): ``_to`` presumably expects a device — confirm.
            if len(args) >= 2:
                target = args[1]
            else:
                assert "device" in kwargs
                target = kwargs["device"]
            return cls._to(args[0], target)

        if func in [torch.Tensor.copy_]:
            assert len(args) == 2
            return cls._copy(args[0], args[1])

        if func in [torch.Tensor.equal, torch.equal]:
            assert len(args) == 2
            return cls._equal(args[0], args[1])

        if func == torch.Tensor.to_dense:
            assert len(args) == 1
            return cls._to_dense(args[0])

        if func == torch.Tensor.detach:
            x = args[0]
            # Detach only the values; the index tensors are shared.
            return cls._wrap(
                x.shape,
                x.__values.detach(),
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.__deepcopy__:
            x = args[0]
            memo = args[1]
            # Deep-copy values and every index tensor, sharing one memo dict.
            return cls._wrap(
                x.shape,
                x.__values.__deepcopy__(memo),
                x.__row_indices.__deepcopy__(memo),
                x.__row_offsets.__deepcopy__(memo),
                x.__column_indices.__deepcopy__(memo),
                tuple(v.__deepcopy__(memo) for v in x.__transp_info),
            )

        if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]:
            assert len(args) == 1
            assert len(kwargs) == 0
            x = args[0]
            grad = x.__values.grad
            if grad is None:
                # No gradient yet (e.g. before backward): match dense-tensor
                # semantics rather than wrapping ``None`` as the values.
                return None
            return cls._wrap(
                x.shape,
                grad,
                x.__row_indices,
                x.__row_offsets,
                x.__column_indices,
                x.__transp_info,
            )

        if func == torch.Tensor.requires_grad_:
            # Mirror the call onto the stored values, forwarding the caller's
            # arguments so ``requires_grad_(False)`` is honoured. Then fall
            # through so the default implementation below also updates the
            # wrapper itself.
            func(args[0].__values, *args[1:], **kwargs)

        # Default path: run the original op without re-entering this override.
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            # TODO: check this
            if func in torch.overrides.get_default_nowrap_functions():
                return ret
            return torch._tensor._convert(ret, cls)