def generate_until()

in lm_eval/models/nemo_lm.py
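
Greedy batched generation for "generate until" requests: contexts are length-sorted and batched by shared gen_kwargs, left-truncated so they fit the model window alongside the generation budget, passed to NeMo's generate(), and the returned sentences are trimmed back to continuations cut at the requested stop strings.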


    def generate_until(self, requests):
        if not requests:
            return []
        res = []

        def get_until(req_args):
            until = req_args.get("until", [])
            # deepcopy so we don't mutate req_args, which the cache key is built from
            until = deepcopy(until)
            # always stop on the end-of-text token
            eot = self.tokenizer.ids_to_tokens([self.eot_token_id])[0]
            if eot not in until:
                until.append(eot)
            return until

        def _collate(x):
            # sort key: tokenized context length, tie-broken by the raw string,
            # so similarly sized requests end up in the same batch
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ords = Collator(
            [reg.args for reg in requests], sort_fn=_collate, group_by="gen_kwargs"
        )
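        # `re_ords` serves requests sorted by tokenized context length and
        # grouped by gen_kwargs, so every batch below shares one set of
        # kwargs; `get_original` restores the caller's ordering at the end.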
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        for chunk in chunks:
            # the contexts are re-encoded (and truncated) below, so only the
            # gen_kwargs are taken from the raw chunk here
            _, all_gen_kwargs = zip(*chunk)
            # all gen_kwargs in the batch are identical; this is safe to assume
            # because the Collator groups requests by gen_kwargs
            req_args = all_gen_kwargs[0]
            # unpack our keyword arguments
            until = get_until(req_args)
            max_gen_toks = req_args.get("max_gen_toks", self.max_gen_toks)

            # reserve room for generation, then left-truncate each context to
            # its most recent `remaining_length` tokens
            remaining_length = self.max_length - max_gen_toks
            contexts = []
            for context, _ in chunk:
                encoded_context = self.tok_encode(context)
                encoded_context = encoded_context[-remaining_length:]
                contexts.append(self.tok_decode(encoded_context))

            output = self.generate(
                self.model,
                inputs=contexts,
                tokens_to_generate=max_gen_toks,
                end_strings=until,
                greedy=True,
            )

            answers = output["sentences"]

            # the backend echoes the prompt, so strip it by character offset
            continuations = []
            for context, answer in zip(contexts, answers):
                continuations.append(answer[len(context) :])

            # truncate each continuation at the first stop string that appears
            for term in until:
                continuations = [answer.split(term)[0] for answer in continuations]

            for request, answer in zip(chunk, continuations):
                self.cache_hook.add_partial("greedy_until", request, answer)
                res.append(answer)

        # undo the length-sorted ordering so results line up with `requests`
        return re_ords.get_original(res)
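
For orientation, a minimal, self-contained sketch of the post-processing step above: the echoed prompt is stripped by character offset, and the continuation is cut at the first occurrence of any stop string. The helper name postprocess_generation is hypothetical, not part of nemo_lm.py, and no NeMo install is needed to run it.

    # Pure-Python sketch of the post-processing above (no NeMo required).
    # `postprocess_generation` is a hypothetical name, not part of nemo_lm.py.
    def postprocess_generation(context: str, sentence: str, until: list[str]) -> str:
        # the backend echoes the prompt, so drop it by character offset ...
        continuation = sentence[len(context):]
        # ... then truncate at the first stop string that appears
        for term in until:
            continuation = continuation.split(term)[0]
        return continuation

    print(postprocess_generation(
        context="Q: What is the capital of France?\nA:",
        sentence="Q: What is the capital of France?\nA: Paris.\nQ: What is",
        until=["\nQ:", "</s>"],
    ))  # -> " Paris."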