def initialize()

in community-content/vertex_model_garden/model_oss/transformers/handler.py [0:0]


  def initialize(self, context: Any):
    """Custom handler initialization.

    Selects the device, resolves the model id and task from environment
    variables (downloading from GCS when needed), and loads the model or
    transformers pipeline.
    """

    properties = context.system_properties
    self.map_location = (
        "cuda"
        if torch.cuda.is_available() and properties.get("gpu_id") is not None
        else "cpu"
    )
    self.device = torch.device(
        self.map_location + ":" + str(properties.get("gpu_id"))
        if torch.cuda.is_available() and properties.get("gpu_id") is not None
        else self.map_location
    )
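    # Example: with CUDA available and system_properties["gpu_id"] == 0 this
    # yields map_location == "cuda" and device == torch.device("cuda:0"); on a
    # CPU-only host (or when gpu_id is absent) both fall back to "cpu".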

    self.manifest = context.manifest
    # The model id can be either:
    # 1) a huggingface model card id, like "Salesforce/blip", or
    # 2) a GCS path to the model files, like "gs://foo/bar".
    # If it's a model card id, the model will be loaded from huggingface.
    self.model_id = (
        DEFAULT_MODEL_ID
        if os.environ.get("MODEL_ID") is None
        else os.environ["MODEL_ID"]
    )
    # If it's a GCS path, the model files are downloaded to a local directory
    # first, since the transformers from_pretrained API can't read from GCS.
    if self.model_id.startswith(constants.GCS_URI_PREFIX):
      gcs_path = self.model_id[len(constants.GCS_URI_PREFIX) :]
      local_model_dir = os.path.join(constants.LOCAL_MODEL_DIR, gcs_path)
      logging.info("Download %s to %s", self.model_id, local_model_dir)
      fileutils.download_gcs_dir_to_local(self.model_id, local_model_dir)
      self.model_id = local_model_dir
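      # Example: "gs://foo/bar" is mirrored under
      # os.path.join(constants.LOCAL_MODEL_DIR, "foo/bar") and self.model_id is
      # repointed at that local copy so from_pretrained can read it.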

    self.task = (
        ZERO_CLASSIFICATION
        if os.environ.get("TASK") is None
        else os.environ["TASK"]
    )
    logging.info(
        "Handler initializing task:%s, model:%s", self.task, self.model_id
    )

    if SALESFORCE_BLIP in self.model_id:
      # pipeline() doesn't support Salesforce/blip models yet.
      self.salesforce_blip = True
      self._create_blip_model()
    else:
      self.salesforce_blip = False
      if self.task == FEATURE_EMBEDDING:
        self.model = CLIPModel.from_pretrained(self.model_id).to(
            self.map_location
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.processor = AutoProcessor.from_pretrained(self.model_id)
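        # The model, tokenizer and processor are presumably consumed later by
        # the handler's inference path (e.g. CLIPModel.get_text_features /
        # get_image_features) to produce embeddings; that code is outside this
        # excerpt.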
      elif self.task == SUMMARIZATION and FLAN_T5 in self.model_id:
        self.pipeline = pipeline(
            task=self.task,
            model=self.model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
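        # device_map="auto" lets transformers (via accelerate) spread the
        # potentially large Flan-T5 weights across the available devices
        # instead of pinning them to self.device; bfloat16 halves the weight
        # memory relative to float32.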
      else:
        self.pipeline = pipeline(
            task=self.task, model=self.model_id, device=self.device
        )

    self.initialized = True
    logging.info("Handler initialization done.")