def initialize()

in community-content/vertex_model_garden/model_oss/peft/handler.py [0:0]


  def initialize(self, context: Any):
    """Initializes the handler."""
    logging.info("Start to initialize the PEFT handler.")
    properties = context.system_properties
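    # Use the assigned GPU when CUDA and a gpu_id are available, else CPU.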
    self.map_location = (
        "cuda"
        if torch.cuda.is_available() and properties.get("gpu_id") is not None
        else "cpu"
    )

    self.device = torch.device(
        self.map_location + ":" + str(properties.get("gpu_id"))
        if torch.cuda.is_available() and properties.get("gpu_id") is not None
        else self.map_location
    )
    self.manifest = context.manifest
    self.precision_mode = os.environ.get(
        "PRECISION_LOADING_MODE", constants.PRECISION_MODE_16
    )
    self.task = os.environ.get("TASK", CAUSAL_LANGUAGE_MODELING_LORA)
    trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", None)
    if trust_remote_code == "false":
      self.trust_remote_code = False
    else:
      self.trust_remote_code = True

    # If present, the path of the model in the container.
    aip_storage_dir = os.environ.get("AIP_STORAGE_DIR", None)

    # If present, the URI of the model in a google owned GCS bucket.
    aip_storage_uri = os.environ.get("AIP_STORAGE_URI", None)

    model_id = os.environ.get("MODEL_ID", None)
    base_model_id = os.environ.get("BASE_MODEL_ID", None)

    self.model_id = None
    if aip_storage_dir:
      self.model_id = aip_storage_dir
      logging.info(f"Loaded base model from AIP_STORAGE_DIR: {self.model_id}.")
    elif aip_storage_uri:
      self.model_id = aip_storage_uri
      logging.info(f"Loaded base model from AIP_STORAGE_URI: {self.model_id}.")
    elif model_id:
      self.model_id = model_id
      logging.info(f"Loaded base model from MODEL_ID: {self.model_id}.")
    elif base_model_id:
      # Note: BASE_MODEL_ID has been unified with MODEL_ID.
      # MODEL_ID should be used whenever possible.
      self.model_id = base_model_id
      logging.info(f"Loaded base model from BASE_MODEL_ID: {self.model_id}.")

    # Optional quantization scheme for the causal-LM path: AWQ, GPTQ, or unset.
    self.quantization = os.environ.get("QUANTIZATION", None)

    if not self.model_id:
      raise ValueError("Base model id is must be set.")
    if fileutils.is_gcs_path(self.model_id):
      fileutils.download_gcs_dir_to_local(
          self.model_id,
          constants.LOCAL_BASE_MODEL_DIR,
          skip_hf_model_bin=True,
      )
      self.model_id = constants.LOCAL_BASE_MODEL_DIR
    self.finetuned_lora_model_path = os.environ.get(
        "FINETUNED_LORA_MODEL_PATH", ""
    )
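    # Likewise, download GCS-hosted LoRA weights to a local directory.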
    if fileutils.is_gcs_path(self.finetuned_lora_model_path):
      fileutils.download_gcs_dir_to_local(
          self.finetuned_lora_model_path, constants.LOCAL_MODEL_DIR
      )
      self.finetuned_lora_model_path = constants.LOCAL_MODEL_DIR

    logging.info(
        f"Using task: {self.task}, base model: {self.model_id}, lora model:"
        f" {self.finetuned_lora_model_path}, precision: {self.precision_mode}."
    )

    self.pipeline = None
    self.model = None
    self.tokenizer = None
    start_time = time.perf_counter()
    logging.info("Started PEFT handler initialization at: %s", start_time)
    if self.task == TEXT_TO_IMAGE_LORA:
      pipeline = StableDiffusionPipeline.from_pretrained(
          self.model_id, torch_dtype=torch.float16
      )
      logging.debug("Initialized the base model for text to image.")
      pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
          pipeline.scheduler.config
      )
      logging.debug("Initialized the scheduler for text to image.")
      # This is to reduce GPU memory requirements.
      pipeline.enable_xformers_memory_efficient_attention()
      pipeline = pipeline.to(self.map_location)
      # Reduces memory footprint.
      pipeline.enable_attention_slicing()
      if self.finetuned_lora_model_path:
        pipeline.load_lora_weights(self.finetuned_lora_model_path)
      logging.debug("Initialized the LoRA model for text to image.")
      self.pipeline = pipeline
      logging.info("Initialized the text to image pipelines.")
    elif self.task == SEQUENCE_CLASSIFICATION_LORA:
      tokenizer = AutoTokenizer.from_pretrained(self.model_id)
      logging.debug("Initialized the tokenizer for sequence classification.")
      model = AutoModelForSequenceClassification.from_pretrained(
          self.model_id, torch_dtype=torch.float16
      )
      logging.debug("Initialized the base model for sequence classification.")
      if self.finetuned_lora_model_path:
        model = PeftModel.from_pretrained(model, self.finetuned_lora_model_path)
        logging.debug("Initialized the LoRA model for sequence classification.")
      model.to(self.map_location)
      self.model = model
      self.tokenizer = tokenizer
    elif self.task in (CAUSAL_LANGUAGE_MODELING_LORA, INSTRUCT_LORA):
      tokenizer = AutoTokenizer.from_pretrained(
          self.model_id,
          trust_remote_code=self.trust_remote_code,
      )

      logging.debug("Initialized the tokenizer.")
      if self.task == CAUSAL_LANGUAGE_MODELING_LORA:
        if self.quantization == constants.AWQ:
          model = AutoAWQForCausalLM.from_quantized(
              self.model_id,
              trust_remote_code=self.trust_remote_code,
          )
        elif self.quantization == constants.GPTQ or not self.quantization:
          if self.precision_mode == constants.PRECISION_MODE_32:
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                return_dict=True,
                torch_dtype=torch.float32,
                device_map="auto",
                trust_remote_code=self.trust_remote_code,
            )
          elif self.precision_mode == constants.PRECISION_MODE_16B:
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                return_dict=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=self.trust_remote_code,
            )
          elif self.precision_mode == constants.PRECISION_MODE_16:
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                return_dict=True,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=self.trust_remote_code,
            )
          elif self.precision_mode == constants.PRECISION_MODE_8:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True, llm_int8_threshold=0.0
            )
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                return_dict=True,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=quantization_config,
                trust_remote_code=self.trust_remote_code,
            )
          else:
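            # Any other precision mode defaults to 4-bit NF4 quantization.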
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                return_dict=True,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                quantization_config=quantization_config,
                trust_remote_code=self.trust_remote_code,
            )
        else:
          raise ValueError(f"Invalid QUANTIZATION value: {self.quantization}")
      else:
        # INSTRUCT_LORA: load in bfloat16; retry once if the first load fails.
        try:
          model = AutoModelForCausalLM.from_pretrained(
              self.model_id,
              torch_dtype=torch.bfloat16,
              trust_remote_code=self.trust_remote_code,
              device_map="auto",
          )
        except:  # pylint: disable=bare-except
          model = AutoModelForCausalLM.from_pretrained(
              self.model_id,
              torch_dtype=torch.bfloat16,
              trust_remote_code=self.trust_remote_code,
              device_map="auto",
          )
      logging.debug("Initialized the base model.")
      if self.finetuned_lora_model_path:
        model = PeftModel.from_pretrained(model, self.finetuned_lora_model_path)
        logging.debug("Initialized the LoRA model.")
      pipeline = transformers.pipeline(
          "text-generation",
          model=model,
          tokenizer=tokenizer,
      )
      self.tokenizer = tokenizer
      self.pipeline = pipeline
    else:
      raise ValueError(f"Invalid TASK: {self.task}")

    self.initialized = True
    end_time = time.perf_counter()
    logging.info("The PEFT handler was initialize at: %s", end_time)
    logging.info("Handler initiation took %s seconds", end_time - start_time)