# Excerpt from text-generation-inference/server/text_generation_server/jetstream_pt_support/logits_process.py
def __call__(self, logits: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Warp a batch of logits with temperature, top-k and top-p filtering.

    Args:
        logits: `(batch, vocab)` array of raw scores.

    Returns:
        A `(logits, indices)` pair where `indices` maps each position of the
        returned `logits` back into the vocabulary axis of the input. With
        top-k active (and top-p off) the pair is in descending score order;
        otherwise it is in ascending order. Tokens rejected by top-p are set
        to ``-inf`` so downstream comparisons ignore them.
    """
    if self.temperature != 1.0:
        logits = logits / self.temperature

    vocab_size = logits.shape[-1]
    use_top_k = 0 < self.top_k < vocab_size
    use_top_p = 0.0 < self.top_p < 1.0

    # Rank the whole vocabulary (ascending); with top-k, keep only the k
    # best entries, which flips the kept slice to descending order.
    ranked = jnp.argsort(logits, axis=-1)
    if use_top_k:
        ranked = ranked[..., ::-1][:, : self.top_k]
    picked = jnp.take_along_axis(logits, ranked, axis=-1)

    if use_top_p:
        if use_top_k:
            # The top-k slice above is descending, but the cumulative-sum
            # trick below needs ascending order — flip both arrays back.
            picked = jnp.flip(picked, axis=-1)
            ranked = jnp.flip(ranked, axis=-1)
        # Keep every token whose cumulative probability strictly exceeds
        # 1 - top_p; the last (best) token is always retained.
        cumulative = jax.nn.softmax(picked, axis=-1).cumsum(axis=-1)
        keep = cumulative > (1 - self.top_p)
        keep = keep.at[:, -1].set(True)
        # Rejected logits become -inf so they lose all later comparisons.
        picked = jnp.where(keep, picked, float("-Inf"))

    return picked, ranked