def cumsum()

in optimum/neuron/models/inference/backend/modules/generation/sampling.py
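
Device-aware cumulative sum used by the sampling pipeline. On CPU it defers to torch.cumsum; otherwise it moves the requested dim to the last axis and dispatches to an NKI cumsum kernel for floating-point inputs, or to a matmul against an upper-triangular ones matrix for everything else.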


import logging

import torch

# The NKI pieces come from the AWS Neuron toolchain; the exact import paths
# below are assumptions, not confirmed against the module:
from torch_neuronx import nki_jit  # assumed path for the NKI JIT wrapper
from neuronxcc.nki._private_kernels.cumsum import cumsum as nki_cumsum  # assumed path for the kernel

logger = logging.getLogger(__name__)


def cumsum(tensor_in, dim, on_cpu: bool = False):
    if on_cpu:
        logger.debug("On CPU, using torch cumsum")
        return torch.cumsum(tensor_in, dim=dim)
    init_shape_len = len(tensor_in.shape)
    # Normalize negative dims, then move the cumsum dim to the last axis so
    # both backends below can always scan along the last dimension.
    cumsum_dim = dim % init_shape_len
    last_dim = init_shape_len - 1
    is_transposed = False
    if cumsum_dim != last_dim:
        tensor_in = torch.transpose(tensor_in, cumsum_dim, last_dim)
        is_transposed = True
    init_shape = tensor_in.shape
    cumsum_len = init_shape[last_dim]
    # Prefer the NKI kernel for floating-point inputs; fall back to the matmul-based cumsum otherwise
    if torch.is_floating_point(tensor_in):
        logger.debug("Using NKI cumsum")
        # Flatten to 2D so the kernel scans along axis 1; the kernel writes
        # its result into `output` in place.
        tensor_in = tensor_in.view(-1, cumsum_len)
        nki_cumsum_func = nki_jit()(nki_cumsum)
        output = torch.zeros_like(tensor_in)
        nki_cumsum_func(tensor_in, output, axis=1)
        # Restore the original shape and, if needed, the original dim order.
        output = output.view(init_shape)
        if is_transposed:
            output = torch.transpose(output, cumsum_dim, last_dim)
        return output
    else:
        logger.debug("Using matmul cumsum")
        # Multiplying by an upper-triangular matrix of ones computes a
        # cumulative sum along the last dim:
        # output[..., j] = sum(tensor_in[..., : j + 1]).
        triu = torch.triu(
            torch.ones(
                cumsum_len,
                cumsum_len,
                dtype=tensor_in.dtype,
                device=tensor_in.device,
            )
        )
        output = tensor_in @ triu
        if is_transposed:
            output = torch.transpose(output, cumsum_dim, last_dim)
        return output
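
Below is a minimal sketch, on plain PyTorch only, of the matmul fallback path; the NKI branch needs Neuron hardware and is not exercised here. The shapes and values are illustrative, not taken from the module.

import torch

# x @ triu(ones) equals torch.cumsum along the last dim: entry (i, j) of the
# triangular matrix is 1 exactly when i <= j, so column j sums x[..., : j + 1].
x = torch.randint(0, 10, (2, 3, 5))  # integer input, as the non-float branch expects
n = x.shape[-1]
triu = torch.triu(torch.ones(n, n, dtype=x.dtype))
assert torch.equal(x @ triu, torch.cumsum(x, dim=-1))

# The transpose trick from the function above: to scan along dim=1, swap
# dim 1 to the last position, scan, then swap back.
m = x.shape[1]
triu_m = torch.triu(torch.ones(m, m, dtype=x.dtype))
y = x.transpose(1, -1) @ triu_m
assert torch.equal(y.transpose(1, -1), torch.cumsum(x, dim=1))

The NKI branch performs the same dim normalization but flattens to 2D and hands the buffer to the compiled kernel instead of materializing the triangular matrix.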