in lm_eval/models/openai_completions.py [0:0]
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
    """Generate one chat completion per request, truncated at its stop terms.

    Requests are grouped by ``str(req.args[1])`` (their generation kwargs) so
    that requests with different sampling settings are never mixed. Within a
    group they are processed longest context first. Each context is sent to
    ``oa_completion`` on its own, as a single ``user`` message.

    Args:
        requests: Request objects whose ``args`` are ``(context, gen_kwargs)``.
            ``gen_kwargs`` must be a dict. An optional ``"until"`` entry
            (str or list of str) gives stop sequences. An optional
            ``"max_gen_toks"`` entry overrides ``self.max_gen_toks``.
        disable_tqdm: Hide the progress bar. It is also hidden on any rank
            other than 0.

    Returns:
        The generated strings, restored to the original request order via
        ``grouper.get_original``.

    Raises:
        ValueError: If ``gen_kwargs`` is not a dict, or if its ``"until"``
            entry is neither a str nor a list.
    """
    res = defaultdict(list)
    re_ords = {}
    # We group requests by their generation_kwargs so that we don't try to
    # execute, e.g., greedy sampling and temp=0.8 sampling in the same batch.
    grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
    for key, reqs in grouper.get_grouped().items():
        # Within each group, reorder by context length (len of the context),
        # descending.
        re_ords[key] = utils.Reorderer(
            [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
        )

    # Context manager so the bar is closed even if a request raises.
    with tqdm(
        total=len(requests), disable=(disable_tqdm or (self.rank != 0))
    ) as pbar:
        for key, re_ord in re_ords.items():
            # n must be 1: the messages of one chat-completion call form a
            # single conversation, not a batch of independent prompts.
            chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
            for chunk in chunks:
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]

                # Deep-copy so the pops/renames below never mutate the request's
                # own gen_kwargs, which is also used as the cache key further down.
                kwargs = copy.deepcopy(all_gen_kwargs[0])
                if not isinstance(kwargs, dict):
                    raise ValueError(
                        f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
                    )
                # "do_sample" is not forwarded to the completion call.
                kwargs.pop("do_sample", None)
                until = None
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
                        )
                    # The API parameter for stop sequences is named "stop".
                    kwargs["stop"] = until
                kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)

                response = oa_completion(
                    client=self.client,
                    chat=True,
                    messages=inps,
                    model=self.model,
                    **kwargs,
                )

                for resp, (context, args_) in zip(response.choices, chunk):
                    # message.content is Optional in the OpenAI client's types.
                    # Treat None as an empty generation instead of crashing on .split.
                    s = resp.message.content or ""
                    if until is not None:
                        # Also truncate client-side at the first occurrence of
                        # each non-empty stop term.
                        for term in until:
                            if term:
                                s = s.split(term)[0]
                    res[key].append(s)
                    # Key the partial cache entry on the request's own
                    # (context, gen_kwargs) args rather than a reduced
                    # {"until": ...} dict, so it matches a lookup keyed on
                    # req.args. NOTE: assumes the caching wrapper hashes
                    # req.args, as lm_eval's CachingLM does.
                    self.cache_hook.add_partial("generate_until", (context, args_), s)
                    pbar.update(1)
            # Reorder this group's results back to their original unsorted form.
            res[key] = re_ord.get_original(res[key])
    return grouper.get_original(res)