docker_images/peft/app/pipelines/text_generation.py

import logging
import os
from typing import Dict, List

import torch
from app import idle, timing
from app.pipelines import Pipeline
from huggingface_hub import model_info
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


logger = logging.getLogger(__name__)


class TextGenerationPipeline(Pipeline):
    def __init__(self, model_id: str):
        use_auth_token = os.getenv("HF_API_TOKEN")
        model_data = model_info(model_id, token=use_auth_token)

        config_dict = model_data.config.get("peft")
        if config_dict:
            base_model_id = config_dict["base_model_name_or_path"]
            if base_model_id:
                self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
                model = AutoModelForCausalLM.from_pretrained(
                    base_model_id, device_map="auto"
                )
                # Wrap the base model with the PEFT adapter weights.
                self.model = PeftModel.from_pretrained(model, model_id)
            else:
                raise ValueError("There is no base model ID in the configuration file.")
        else:
            raise ValueError(
                "Config file for this model does not exist or is invalid."
            )

    def __call__(self, inputs: str, **kwargs) -> List[Dict[str, str]]:
        """
        Args:
            inputs (:obj:`str`): a string of text to be completed
        Returns:
            A list with a single dict containing the generated text.
        """
        if idle.UNLOAD_IDLE:
            with idle.request_witnesses():
                self._model_to_gpu()
                resp = self._process_req(inputs, **kwargs)
        else:
            resp = self._process_req(inputs, **kwargs)
        return [{"generated_text": resp[0]}]

    @timing.timing
    def _model_to_gpu(self):
        if torch.cuda.is_available():
            self.model.to("cuda")

    def _process_req(self, inputs: str, **kwargs) -> List[str]:
        """
        Args:
            inputs (:obj:`str`): a string of text to be completed
        Returns:
            A list of decoded completions (one per generated sequence).
        """
        tokenized_inputs = self.tokenizer(inputs, return_tensors="pt")
        self._model_to_gpu()
        if torch.cuda.is_available():
            # Move the input tensors to the same device as the model.
            device = "cuda"
            tokenized_inputs = {
                "input_ids": tokenized_inputs["input_ids"].to(device),
                "attention_mask": tokenized_inputs["attention_mask"].to(device),
            }
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=tokenized_inputs["input_ids"],
                attention_mask=tokenized_inputs["attention_mask"],
                max_new_tokens=10,
                eos_token_id=3,
            )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
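

# --- Usage sketch ---
# A minimal, hedged illustration of how this pipeline is invoked. The adapter
# repo id below ("some-user/opt-350m-lora") is a hypothetical placeholder; it
# assumes the repo's model info exposes a "peft" config section with a valid
# "base_model_name_or_path", as __init__ above requires.
#
# if __name__ == "__main__":
#     pipe = TextGenerationPipeline("some-user/opt-350m-lora")
#     print(pipe("Hello, my name is"))
#     # -> [{"generated_text": "Hello, my name is ..."}]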