def initialize()

in community-content/vertex_model_garden/model_oss/transformers/sam_handler.py [0:0]


  def initialize(self, context: Any):
    """Custom initialize: builds the mask-generation pipeline.

    Picks the device from the handler context and resolves the model location
    from the `MODEL_ID` environment variable. If that location starts with
    `GCS_PREFIX`, the model is downloaded to a local directory first. It then
    builds the transformers pipeline for the task named by `TASK`.

    Args:
      context: The handler context. Its `system_properties` (the optional
        `gpu_id` entry) and `manifest` are read here.

    Raises:
      ValueError: If the `TASK` environment variable names a task other than
        `MASK_GENERATION`.
    """

    # vv-docker:google3-begin(internal)
    # TODO(b/287051908): Move handler functions to common utils for
    # everyone to use.
    # vv-docker:google3-end
    properties = context.system_properties
    gpu_id = properties.get("gpu_id")
    # Run on the GPU only when CUDA is present AND a GPU id was assigned;
    # otherwise fall back to the CPU.
    use_gpu = torch.cuda.is_available() and gpu_id is not None
    self.map_location = "cuda" if use_gpu else "cpu"
    self.device = torch.device(
        f"{self.map_location}:{gpu_id}" if use_gpu else self.map_location
    )
    self.manifest = context.manifest

    # The model id can be either:
    # 1) a Hugging Face model card id (e.g. "org/model-name"), which the
    #    transformers loading APIs fetch from the Hugging Face Hub, or
    # 2) a GCS path to the model files (e.g. "gs://foo/bar"), which is first
    #    downloaded to a local directory because transformers cannot read
    #    directly from GCS.
    # Falls back to DEFAULT_MODEL_ID when MODEL_ID is not set.
    self.model_id = os.environ.get("MODEL_ID", DEFAULT_MODEL_ID)
    if self.model_id.startswith(GCS_PREFIX):
      # Mirror the bucket-relative path under DOWNLOAD_DIR.
      gcs_path = self.model_id[len(GCS_PREFIX) :]
      local_model_dir = os.path.join(DOWNLOAD_DIR, gcs_path)
      logging.info(f"Download {self.model_id} to {local_model_dir}")
      download_gcs_dir(self.model_id, local_model_dir)
      # From here on, load the model from the local copy.
      self.model_id = local_model_dir

    # Falls back to MASK_GENERATION when TASK is not set.
    self.task = os.environ.get("TASK", MASK_GENERATION)
    logging.info(
        f"Handler initializing task:{self.task}, model:{self.model_id}"
    )

    # Only the mask-generation task is supported by this handler.
    if self.task != MASK_GENERATION:
      raise ValueError(f"Invalid TASK: {self.task}")
    self.pipeline = pipeline(
        task="mask-generation", model=self.model_id, device=self.device
    )

    self.initialized = True
    logging.info("Handler initialization done.")