text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
def _sample(self, scores: jnp.ndarray, key) -> jnp.ndarray:
    """Sample next token ids from the logits, honoring the warper settings.

    Top-k takes precedence over top-p; when neither is active, fall back to
    plain temperature-scaled categorical sampling.
    """
    do_top_k = 0 < self.logits_warper.top_k < scores.shape[-1]
    do_top_p = 0.0 < self.logits_warper.top_p < 1.0
    if do_top_k:
        return sampling_utils.sample_topk_logits(
            scores,
            self.logits_warper.top_k,
            self.logits_warper.temperature,
            key,
        )
    elif do_top_p:
        return sampling_utils.sample_nucleus_topp_logits(
            scores,
            self.logits_warper.top_p,
            self.logits_warper.temperature,
            key,
        )
    # Neither top-k nor top-p is active: temperature-scaled categorical sampling.
    return jax.random.categorical(key, scores / self.logits_warper.temperature)
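
For reference, a minimal standalone sketch of the two sampling ideas used above, assuming only JAX is available: the fallback path is temperature-scaled categorical sampling over the full vocabulary, and the top-k path restricts sampling to the k highest logits before mapping the choice back to vocabulary indices. The function name topk_sample and the toy inputs are hypothetical illustrations, not part of the JetStream sampling_utils API.

    import jax
    import jax.numpy as jnp

    def topk_sample(key, logits, top_k, temperature):
        # Keep the top_k highest logits per row, sample among them with the
        # given temperature, then map the choice back to vocabulary indices.
        topk_logits, topk_indices = jax.lax.top_k(logits, top_k)
        choice = jax.random.categorical(key, topk_logits / temperature, axis=-1)
        return jnp.take_along_axis(topk_indices, choice[..., None], axis=-1).squeeze(-1)

    key = jax.random.PRNGKey(0)
    scores = jnp.array([[2.0, 1.0, 0.5, -1.0]])  # (batch=1, vocab=4) unnormalized logits

    # Fallback path: temperature-scaled categorical sampling over the full vocabulary.
    next_token = jax.random.categorical(key, scores / 0.7, axis=-1)

    # Top-k path: sample only among the 2 most likely tokens.
    next_token_topk = topk_sample(key, scores, top_k=2, temperature=0.7)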