lmms_eval/api/model.py (97 lines of code) (raw):

import abc
import hashlib
import json
import logging
import os
from typing import List, Optional, Tuple, Type, TypeVar, Union

from sqlitedict import SqliteDict
from tqdm import tqdm

from lmms_eval.api.instance import Instance
from lmms_eval import utils

eval_logger = logging.getLogger("lmms-eval")

T = TypeVar("T", bound="lmms")


class lmms(abc.ABC):
    """Abstract base class for large multimodal models (LMMs)."""

    def __init__(self) -> None:
        """Defines the interface that should be implemented by all lmms subclasses.

        LMMs are assumed to take image-text as input and yield strings as output
        (inputs/outputs should be tokenization-agnostic.)
        """
        # set rank and world size to a single process, by default.
        self._rank = 0
        self._world_size = 1
        # A hook built from None has no backing store, so add_partial is a no-op
        # until set_cache_hook installs a real one.
        self.cache_hook = CacheHook(None)
        self.task_dict = {}

    @abc.abstractmethod
    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.

        Downstream tasks should attempt to use loglikelihood instead of other
        LMM calls whenever possible.

        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a
            tuple (context, continuation).
            `context: str`
                Context string. Implementations of LMM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated.
                If there is a word boundary, the space should be in the
                continuation. For example, context="hello" continuation=" world"
                is correct.
            `visual_list: list[dict]`
                Visual input to the model. Can be None.
        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy`:
                Whether `continuation` would be generated by greedy sampling
                from `context`.
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence.

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a
            tuple whose leading elements are (context, generation_kwargs).
            `context: str`
                Context string.
            `generation_kwargs: dict`
                Generation kwargs.
            `visual_list: list[dict]`
                Visual input to the model. Can be None.
        :return: list[str]
            A list of string continuations.
            `continuation: str`
                The generated continuation.
        """
        pass

    @classmethod
    def create_from_arg_string(cls: Type[T], arg_string: str, additional_config: Optional[dict] = None) -> T:
        """Create an instance of the LMM class from an argument string and additional config.

        Parameters:
        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration
          parameters. Entries whose value is None are dropped.

        Returns:
        - Instance of the LMM class.
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    @property
    def rank(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return self._rank

    @property
    def world_size(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return self._world_size

    def set_cache_hook(self, cache_hook) -> None:
        """Replace the hook used to record partial results in the cache."""
        self.cache_hook = cache_hook


### SQLite-based caching of LMM responses


def hash_args(attr, args):
    """Return the SHA-256 hex digest identifying a call to `attr` with `args`.

    :param attr: str
        Name of the model method being called.
    :param args: iterable
        The request arguments. Together with `attr` they are serialized with
        `json.dumps`, so they must be JSON-serializable.
    :return: str
        Hex digest used as the cache key.
    """
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    """Lets a model write individual results into a CachingLMM's cache store."""

    def __init__(self, cachinglm) -> None:
        """
        :param cachinglm: CachingLMM or None
            The caching wrapper whose `dbdict` receives the results. With None
            the hook stores nothing.
        """
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        """Store `res` under the key derived from (`attr`, `req`); no-op without a store."""
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLMM:
    """Proxy around an LMM that serves request results from an SQLite cache."""

    def __init__(self, lm, cache_db) -> None:
        """LMM wrapper that returns cached results if they exist, and uses the underlying LMM if not.

        :param lm: LMM
            Underlying LMM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        # A bare filename has no directory component; only create one when present.
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        """Forward attribute access to the wrapped LMM, adding caching to callables.

        Non-callable attributes are returned unchanged. Every callable attribute
        is replaced by a function that takes a list of requests, answers the
        ones already in the cache and sends only the rest to the wrapped LMM.
        """
        # __getattr__ only runs after normal lookup has failed. If `lm` itself
        # is missing (an instance created without __init__, e.g. by copy or
        # unpickling), reading self.lm below would re-enter this method until
        # RecursionError, so fail the lookup cleanly instead.
        if attr == "lm":
            raise AttributeError(attr)

        lm_attr = getattr(self.lm, attr)
        if not callable(lm_attr):
            return lm_attr

        # NOTE(review): this wraps *every* callable, including ones that do not
        # take a request list (e.g. set_cache_hook) — confirm callers only
        # reach request-style methods through the wrapper.
        def fn(requests):
            res = []
            remaining_reqs = []
            # Positions in `res` still waiting for a model result. Tracking them
            # explicitly (rather than scanning for None) keeps results aligned
            # even when the model legitimately returns None for a request.
            pending_slots = []
            warned = False
            # figure out which ones are cached and which ones are new
            eval_logger.info(f"Loading '{attr}' responses from cache '{self.cache_db}' where possible...")
            for req in tqdm(requests):
                hsh = hash_args(attr, req.args)
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests.")
                        warned = True
                    pending_slots.append(len(res))
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    pending_slots.append(len(res))
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LMM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            # (sampled results are written too, but the branch above never reads them back)
            for slot, req, r in zip(pending_slots, remaining_reqs, rem_res):
                res[slot] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        """Return a CacheHook that writes into this wrapper's cache store."""
        return CacheHook(self)