in optimum/neuron/models/inference/backend/modules/generation/sampling.py [0:0]
def cumsum(tensor_in, dim, on_cpu: bool = False):
    """Compute the cumulative sum of ``tensor_in`` along dimension ``dim``.

    On CPU this simply delegates to ``torch.cumsum``. Otherwise the target
    dimension is moved to the last axis and one of two device-friendly
    strategies is used: an NKI kernel for floating-point inputs, or a matmul
    against an upper-triangular ones matrix for non-float inputs.

    Args:
        tensor_in: input tensor.
        dim: dimension along which to accumulate (negative values allowed).
        on_cpu: when True, bypass the device paths and use ``torch.cumsum``.

    Returns:
        A tensor with the same shape and dtype as ``tensor_in`` containing
        the cumulative sum along ``dim``.
    """
    if on_cpu:
        logger.debug("On CPU, using torch cumsum")
        return torch.cumsum(tensor_in, dim=dim)

    init_shape_len = len(tensor_in.shape)
    cumsum_dim = dim % init_shape_len  # normalize negative dims
    last_dim = init_shape_len - 1
    # Both device strategies below accumulate along the last axis, so move
    # the target dimension there first (and move it back before returning).
    is_transposed = cumsum_dim != last_dim
    if is_transposed:
        tensor_in = torch.transpose(tensor_in, cumsum_dim, last_dim)
    init_shape = tensor_in.shape
    cumsum_len = init_shape[last_dim]

    # Prioritize the NKI kernel for float dtypes; fall back to the matmul
    # cumsum when the input is not floating point.
    if torch.is_floating_point(tensor_in):
        logger.debug("Using NKI cumsum")
        # The kernel accumulates over axis 1 of a 2D view, so flatten all
        # leading dims into one batch dim.
        flat_in = tensor_in.view(-1, cumsum_len)
        nki_cumsum_func = nki_jit()(nki_cumsum)
        output = torch.zeros_like(flat_in, device=flat_in.device, dtype=flat_in.dtype)
        nki_cumsum_func(flat_in, output, axis=1)
        output = output.view(init_shape)
    else:
        logger.debug("Using matmul cumsum")
        # Multiplying by an upper-triangular ones matrix (diagonal included)
        # yields output[..., j] = sum_{i <= j} input[..., i].
        triu = torch.triu(
            torch.ones(
                cumsum_len,
                cumsum_len,
                dtype=tensor_in.dtype,
                device=tensor_in.device,
            )
        )
        output = tensor_in @ triu

    if is_transposed:
        output = torch.transpose(output, cumsum_dim, last_dim)
    return output