lmms_eval/models/qwen_vl.py (227 lines of code) (raw):

import logging
import os
import tempfile
import uuid
import warnings
from typing import List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, DistributedType
from tqdm import tqdm

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.qwen.qwen_generate_utils import make_context

warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")

eval_logger = logging.getLogger("lmms-eval")

# Imported after the warning filters are installed, preserving the original module order.
from transformers import AutoModelForCausalLM, AutoTokenizer


def _save_images_to_tmp(visuals) -> List[str]:
    """Write each image in ``visuals`` to its own temporary ``.png`` file.

    Qwen-VL prompts reference images by file path (``{"image": path}`` entries passed to
    ``tokenizer.from_list_format``), so in-memory images must be written to disk first.
    ``tempfile.mkstemp`` guarantees unique names (a short random prefix in a fixed
    directory can collide across documents or processes).

    Returns the created paths in the same order as ``visuals``. If saving any image
    fails, the files created so far are removed before the exception propagates.
    """
    paths: List[str] = []
    try:
        for visual in visuals:
            fd, path = tempfile.mkstemp(suffix=".png")
            os.close(fd)  # only the unique name is needed; ``visual.save`` reopens the path
            paths.append(path)
            visual.save(path)
    except Exception:
        _remove_files(paths)
        raise
    return paths


def _remove_files(paths: List[str]) -> None:
    """Best-effort deletion of temporary files; files that cannot be removed are skipped."""
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            pass


@register_model("qwen_vl")
class Qwen_VL(lmms):
    """
    Qwen_VL Model
    https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/evaluate_vqa.py
    """

    def __init__(
        self,
        pretrained: str = "Qwen/Qwen-VL",
        device: Optional[str] = "cuda",
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = True,
        use_cache=True,
        **kwargs,
    ) -> None:
        """Load the Qwen-VL model and tokenizer and set up (optional) data parallelism.

        Args:
            pretrained: HF hub id or local path passed to ``from_pretrained``.
            device: device used when running in a single process; with several
                accelerate processes each one uses ``cuda:<local_process_index>``.
            batch_size: number of requests generated together per process.
            trust_remote_code: forwarded to ``from_pretrained`` (Qwen-VL ships custom code).
            use_cache: forwarded to ``model.generate``.

        Raises:
            AssertionError: on unexpected extra kwargs, or an unsupported distributed type.
        """
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
        else:
            self._device = device
        self._model = AutoModelForCausalLM.from_pretrained(pretrained, device_map=self._device, trust_remote_code=trust_remote_code).eval()
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
        # Left padding so that, in a batch, every prompt ends right where generation starts.
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token_id = self.tokenizer.eod_id
        # Kept for backward compatibility; prompts are built with ``tokenizer.from_list_format``.
        self.prompt = "<img>{}</img>{}"
        self._config = self._model.config
        self.model.tie_weights()
        self.batch_size_per_gpu = int(batch_size)
        self.use_cache = use_cache
        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            if accelerator.distributed_type == DistributedType.FSDP:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            # NOTE(review): rank is the *local* index while world size is the global process
            # count; these only agree on a single node — confirm multi-node is out of scope.
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self.model.to(self._device)
            self._rank = 0
            # Was misspelled ``_word_size``, which left ``world_size`` unset.
            self._world_size = 1

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eod_id

    @property
    def max_length(self):
        # NOTE(review): ``self._max_length`` is never assigned in this class, so this property
        # raises AttributeError if accessed — confirm the intended source (e.g. a config field).
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score each request's target continuation given its context and images.

        Returns one ``(loss, is_greedy)`` pair per request, in request order. ``loss`` is
        the model's ``outputs.loss`` computed with every context position masked out via
        ``-100`` labels; ``is_greedy`` is True when each continuation token equals the
        argmax of the logits that predict it.

        NOTE(review): ``loss`` is whatever the remote-code model reports (for HF causal LMs
        usually the mean NLL over unmasked tokens), not a summed log-probability — confirm
        the aggregation/sign the caller expects.
        """
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        generation_config = self.model.generation_config
        make_context_kwargs = dict(
            history=None,
            system="You are a helpful assistant",
            max_window_size=generation_config.max_window_size,
            chat_format=generation_config.chat_format,
        )
        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            doc = self.task_dict[task][split][doc_id]
            continuation = doc_to_target if isinstance(doc_to_target, str) else doc_to_target(doc)
            visual_paths = _save_images_to_tmp(self.flatten([doc_to_visual(doc)]))
            try:
                image_query = [{"image": path} for path in visual_paths]
                # The context-only prompt decides how many leading tokens are excluded from the loss.
                context_query = self.tokenizer.from_list_format(image_query + [{"text": contexts}])
                full_query = self.tokenizer.from_list_format(image_query + [{"text": contexts + continuation}])
                _, context_tokens = make_context(self.tokenizer, context_query, **make_context_kwargs)
                _, full_tokens = make_context(self.tokenizer, full_query, **make_context_kwargs)
                # NOTE(review): if ``make_context`` appends template tokens after the query text
                # (e.g. for chat formats), ``context_tokens`` is not a strict prefix of
                # ``full_tokens`` and the masked span is misaligned — verify for the configured format.
                ctx_len = len(context_tokens)

                input_ids = torch.tensor([full_tokens]).to(self.model.device)
                attn_mask = torch.ones_like(input_ids)
                labels = input_ids.clone()
                labels[:, :ctx_len] = -100
                with torch.inference_mode():
                    outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attn_mask)
            finally:
                # Images are only needed while the model reads them; always clean up.
                _remove_files(visual_paths)

            # logits[:, t] predicts token t + 1, so the continuation tokens input_ids[:, start:]
            # are predicted by logits[:, start - 1 : seq_len - 1]. Token 0 has no prediction.
            start = max(ctx_len, 1)
            seq_len = input_ids.shape[1]
            greedy_tokens = outputs["logits"][:, start - 1 : seq_len - 1].argmax(dim=-1)
            cont_toks = input_ids[:, start:]
            max_equal = (greedy_tokens == cont_toks).all()
            res.append((float(outputs.loss.item()), bool(max_equal)))
            pbar.update(1)
        pbar.close()
        return res

    def flatten(self, input):
        """Concatenate an iterable of iterables into a single flat list (one level deep)."""
        return [item for sublist in input for item in sublist]

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a completion for every request, truncated at its ``until`` stop strings.

        Requests are grouped by generation kwargs and batched longest-prompt-first. Each
        request gets its own prompt, built with ``tokenizer.from_list_format`` from that
        document's image paths followed by its context (with ``<image>`` placeholders
        removed). Results are returned in the original request order.

        Raises:
            ValueError: if ``gen_kwargs["until"]`` is neither a string nor a list.
        """
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tokenizer.encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)

            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            # Work on a copy: "until" is popped and defaults are filled in below, and the
            # request's own dict should stay untouched.
            gen_kwargs = dict(all_gen_kwargs[0])

            # Default stop string is the decoded end-of-text token.
            until = [self.tokenizer.decode(self.eot_token_id)]
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")

            # preconfigure gen_kwargs with defaults
            gen_kwargs.setdefault("max_new_tokens", 1024)
            gen_kwargs.setdefault("temperature", 0)
            gen_kwargs.setdefault("top_p", None)
            gen_kwargs.setdefault("num_beams", 1)

            visual_paths: List[str] = []
            try:
                # One prompt per request. The images are referenced by path, so they must
                # exist on disk until generation finishes.
                questions = []
                for context, visual_fn, d_id, t, s in zip(contexts, doc_to_visual, doc_id, task, split):
                    doc_paths = _save_images_to_tmp(visual_fn(self.task_dict[t][s][d_id]) or [])
                    visual_paths.extend(doc_paths)
                    query = [{"image": path} for path in doc_paths]
                    query.append({"text": context.replace("<image>", "")})
                    questions.append(self.tokenizer.from_list_format(query))

                input_ids = self.tokenizer(questions, return_tensors="pt", padding="longest")
                pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eod_id
                cont = self.model.generate(
                    input_ids.input_ids.to(self.device),
                    attention_mask=input_ids.attention_mask.to(self.device),
                    eos_token_id=self.tokenizer.eod_id,
                    pad_token_id=pad_token_id,
                    do_sample=gen_kwargs["temperature"] > 0,
                    temperature=gen_kwargs["temperature"],
                    top_p=gen_kwargs["top_p"],
                    num_beams=gen_kwargs["num_beams"],
                    max_new_tokens=gen_kwargs["max_new_tokens"],
                    use_cache=self.use_cache,
                )
            finally:
                # remove visuals from tmp
                _remove_files(visual_paths)

            # With left padding every row shares the same prompt length; strip it off.
            prompt_len = input_ids.input_ids.shape[1]
            for cont_toks, context, request_gen_kwargs in zip(cont.tolist(), contexts, all_gen_kwargs):
                text_outputs = self.tokenizer.decode(cont_toks[prompt_len:], skip_special_tokens=True).strip()
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        text_outputs = text_outputs.split(term)[0]
                res.append(text_outputs)
                # Keyed on the request's own (context, gen_kwargs) arguments.
                self.cache_hook.add_partial("generate_until", (context, request_gen_kwargs), text_outputs)
            pbar.update(len(contexts))
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)
        pbar.close()
        return res