def __call__()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/logits_process.py


    def __call__(self, logits: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
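        """Apply temperature scaling and optional top-k / top-p filtering to a batch of logits.

        Returns the sorted (and possibly masked) logits together with the indices of the
        corresponding tokens in the original vocabulary, ready for categorical sampling.
        """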
        if self.temperature != 1.0:
            logits = logits / self.temperature

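        # Top-k is only meaningful for 0 < k < vocab_size; top-p only for 0 < p < 1.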
        do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1]
        do_top_p = self.top_p < 1.0 and self.top_p > 0.0

        if do_top_k:
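            # Keep only the indices of the k highest logits, sorted in descending order.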
            sorted_indices = jnp.argsort(logits, axis=-1)[..., ::-1][:, : self.top_k]
            sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1)
        else:
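            # No top-k restriction: sort the full vocabulary in ascending logit order.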
            sorted_indices = jnp.argsort(logits, axis=-1)
            sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1)

        if do_top_p:
            if do_top_k:
                # top-k sorted the logits in descending order; flip them (and the indices)
                # back to ascending order for the cumulative-probability computation below
                sorted_logits = jnp.flip(sorted_logits, axis=-1)
                sorted_indices = jnp.flip(sorted_indices, axis=-1)
            # Keep the tokens whose cumulative probability (in ascending order) is strictly above 1 - top_p,
            # i.e. the smallest set of highest-probability tokens whose total mass is at least top_p.
            cum_probs = jax.nn.softmax(sorted_logits, axis=-1).cumsum(axis=-1)
            keep_mask = cum_probs > (1 - self.top_p)
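            # After the ascending sort, the last position holds the highest logit; always keep it.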
            keep_mask = keep_mask.at[:, -1].set(True)
            # Set rejected logits to -inf so that they are ignored in downstream comparisons
            sorted_logits = jnp.where(keep_mask, sorted_logits, float("-Inf"))

        return sorted_logits, sorted_indices
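
The returned pair is laid out so a sampler can draw directly from the filtered distribution and then map the draw back to a vocabulary id. Below is a minimal sketch of that downstream step; the toy arrays and the use of jax.random.categorical are illustrative assumptions, not the repository's actual sampling code.

    import jax
    import jax.numpy as jnp

    # Toy values in the same layout __call__ returns: logits sorted in ascending order
    # with rejected entries masked to -inf, and the matching original vocabulary indices.
    sorted_logits = jnp.array([[-jnp.inf, -jnp.inf, 0.1, 1.2, 2.3]])
    sorted_indices = jnp.array([[7, 3, 42, 11, 5]])

    key = jax.random.PRNGKey(0)
    # categorical() samples proportionally to softmax(logits); -inf entries get zero probability.
    pos = jax.random.categorical(key, sorted_logits, axis=-1)
    # Map the sampled position back to the token id in the original vocabulary.
    next_token = jnp.take_along_axis(sorted_indices, pos[:, None], axis=-1)[:, 0]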