# xformers/sparse/csr_tensor.py
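# NOTE: excerpt of the __torch_function__ override from the sparse CSR tensor
# class in this file. It assumes `torch` and the `masked_matmul` op are in
# scope at module level, and attributes such as `x.__values` rely on name
# mangling inside the enclosing class.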
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
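    # All matrix-multiplication variants route to the sparse bmm kernel.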
    if func in [
        torch.Tensor.bmm,
        torch.bmm,
        torch.Tensor.__matmul__,
        torch.matmul,
        torch.Tensor.matmul,
    ]:
        assert len(args) == 2
        return cls._bmm(args[0], args[1])
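    # Softmax over the stored (non-zero) values along the given dimension.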
    if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]:
        # `dim` may arrive positionally (torch.softmax(x, dim)) or as a keyword.
        dim = kwargs["dim"] if "dim" in kwargs else args[1]
        return cls._softmax(args[0], dim)
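    # Transpose two dimensions of the sparse tensor.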
    if func in [torch.Tensor.transpose, torch.transpose]:
        assert len(kwargs) == 0
        return cls._transpose(args[0], args[1], args[2])
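    # Fused masked matmul: a @ b restricted to the sparsity pattern of the mask.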
    if func == masked_matmul:
        assert len(args) == 3
        return cls._masked_matmul(args[0], args[1], args[2])
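    # Elementwise addition is only supported between two sparse CSR tensors.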
    if func in [
        torch.Tensor.add,
        torch.add,
        torch.Tensor.__add__,
    ]:
        assert len(args) == 2
        if not (isinstance(args[0], cls) and isinstance(args[1], cls)):
            raise NotImplementedError(
                f"{func} with {type(args[0])} and {type(args[1])} not implemented"
            )
        return cls._binary_op(func, args[0], args[1])
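    # Elementwise multiplication.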
    if func in [
        torch.Tensor.mul,
        torch.mul,
        torch.Tensor.__mul__,
    ]:
        assert len(args) == 2
        return cls._binary_op(func, args[0], args[1])
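    # Logical AND goes through the slower generic binary-op path.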
    if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]:
        assert len(args) == 2
        return cls._binary_op_slow(func, args[0], args[1])
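    # Dropout applies to the stored values only; the sparsity pattern is unchanged.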
    if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]:
        x = args[0]
        values = x.__values.clone()
        values = func(values, *args[1:], **kwargs)
        return cls._wrap(
            x.shape,
            values,
            x.__row_indices,
            x.__row_offsets,
            x.__column_indices,
            x.__transp_info,
        )
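    # .to() conversion, e.g. moving to another device.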
    if func == torch.Tensor.to:
        # The target may be passed positionally or as a keyword argument.
        if len(args) >= 2:
            return cls._to(args[0], args[1])
        return cls._to(args[0], kwargs["device"])
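    # In-place copy from another sparse tensor.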
    if func in [torch.Tensor.copy_]:
        assert len(args) == 2
        return cls._copy(args[0], args[1])
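    # Exact equality of two sparse tensors.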
    if func in [torch.Tensor.equal, torch.equal]:
        assert len(args) == 2
        return cls._equal(args[0], args[1])
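    # Materialize the sparse tensor as a dense one.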
    if func == torch.Tensor.to_dense:
        assert len(args) == 1
        return cls._to_dense(args[0])
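    # detach() only needs to detach the values; the index tensors are integer
    # tensors and carry no gradients.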
    if func == torch.Tensor.detach:
        x = args[0]
        return cls._wrap(
            x.shape,
            x.__values.detach(),
            x.__row_indices,
            x.__row_offsets,
            x.__column_indices,
            x.__transp_info,
        )
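    # Deep copy every component tensor, honoring the memo dict.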
    if func == torch.Tensor.__deepcopy__:
        x = args[0]
        memo = args[1]
        return cls._wrap(
            x.shape,
            x.__values.__deepcopy__(memo),
            x.__row_indices.__deepcopy__(memo),
            x.__row_offsets.__deepcopy__(memo),
            x.__column_indices.__deepcopy__(memo),
            tuple(v.__deepcopy__(memo) for v in x.__transp_info),
        )
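    # .grad accessor: expose the gradient of the values, wrapped with the same
    # sparsity layout as the tensor itself.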
    if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]:
        assert len(args) == 1
        assert len(kwargs) == 0
        x = args[0]
        return cls._wrap(
            x.shape,
            x.__values.grad,
            x.__row_indices,
            x.__row_offsets,
            x.__column_indices,
            x.__transp_info,
        )
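    # requires_grad_ must also be set on the values; note there is no return
    # here, so execution falls through to the generic handler below, which
    # updates the wrapper tensor itself as well.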
    if func == torch.Tensor.requires_grad_:
        func(args[0].__values)
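    # Default path: run func with __torch_function__ dispatch disabled and
    # re-wrap plain tensor results as cls, mirroring the behavior of
    # torch.Tensor.__torch_function__.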
    with torch._C.DisableTorchFunction():
        ret = func(*args, **kwargs)
        # TODO: check this
        if func in torch.overrides.get_default_nowrap_functions():
            return ret
        return torch._tensor._convert(ret, cls)
    return NotImplemented
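
# Usage sketch (hypothetical names; the actual constructor on this class may
# differ, so treat `from_dense` as an assumption for illustration only):
#
#   mask = torch.rand(8, 64, 64) > 0.9
#   a = SparseCSRTensor.from_dense(torch.randn(8, 64, 64) * mask)
#   b = torch.randn(8, 64, 32)
#   out = torch.bmm(a, b)   # intercepted above and routed to cls._bmm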