def _sample()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py [0:0]


    def _sample(self, scores: jnp.ndarray, key) -> jnp.ndarray:
        """Sample from ``scores`` with the strategy configured on ``self.logits_warper``.

        The strategy is picked from the warper settings, in this order:

        1. top-k, when ``top_k`` is strictly between 0 and the size of the last
           axis of ``scores``;
        2. top-p (nucleus), when ``top_p`` is strictly between 0.0 and 1.0;
        3. otherwise a plain categorical draw over ``scores / temperature``.

        Top-k takes precedence: when both top-k and top-p are active, only the
        top-k helper is called and ``top_p`` is not used.

        Args:
            scores: Scores to sample from; the last axis is the one compared
                against ``top_k`` (presumably the vocabulary axis - confirm
                against the caller).
            key: Random key forwarded unchanged to the selected sampler.

        Returns:
            Whatever the selected sampler returns (the ``sampling_utils`` helper
            result, or the output of ``jax.random.categorical``).
        """
        warper = self.logits_warper
        last_axis_size = scores.shape[-1]

        # Both switches are evaluated up front, before any sampler is chosen.
        use_top_k = 0 < warper.top_k < last_axis_size
        use_top_p = 0.0 < warper.top_p < 1.0

        if use_top_k:
            # A top_k of 0 (or less), or one that covers the whole last axis,
            # would not restrict anything, so it never reaches this branch.
            return sampling_utils.sample_topk_logits(scores, warper.top_k, warper.temperature, key)

        if use_top_p:
            return sampling_utils.sample_nucleus_topp_logits(scores, warper.top_p, warper.temperature, key)

        # Neither filter is active: scale by temperature and draw directly.
        tempered_scores = scores / warper.temperature
        return jax.random.categorical(key, tempered_scores)