def initialize()

in community-content/vertex_model_garden/model_oss/timm/handler.py [0:0]


  def initialize(self, context: Any):
    """Loads the timm model (and optional label mapping) for serving.

    Configuration is read from environment variables:
      MODEL_NAME: timm model name passed to `timm.create_model` (required).
      TIMM_PRETRAINED: if set to a truthy value, load timm's pretrained
        weights. Empty, "0", "false", "no" and "off" (case-insensitive) are
        treated as false. Ignored when a MODEL_PT_PATH checkpoint is given.
      MODEL_PT_PATH: optional custom checkpoint (local or GCS path). A `.pt`
        file is loaded via `self._load_torchscript_model`; a `.pth` /
        `.pth.tar` file is loaded as weights into the timm model.
      INDEX_TO_NAME_FILE: optional label-mapping file (local or GCS path).

    Args:
      context: Serving context; its `system_properties` (read for `gpu_id`)
        and `manifest` attributes are used.

    Raises:
      KeyError: If MODEL_NAME is not set.
      ValueError: If MODEL_PT_PATH has an unsupported file extension.
    """
    properties = context.system_properties
    gpu_id = properties.get("gpu_id")
    # Use the assigned GPU only when CUDA exists and the server gave us one.
    if torch.cuda.is_available() and gpu_id is not None:
      self.map_location = "cuda"
      self.device = torch.device(f"cuda:{gpu_id}")
    else:
      self.map_location = "cpu"
      self.device = torch.device("cpu")
    self.manifest = context.manifest

    # Load timm model by model name.
    self.model_name = os.environ["MODEL_NAME"]
    # Parse the flag explicitly: a plain truthiness check would treat
    # TIMM_PRETRAINED="false" or "0" as true.
    pretrained_flag = os.environ.get("TIMM_PRETRAINED", "")
    timm_pretrained = pretrained_flag.strip().lower() not in (
        "",
        "0",
        "false",
        "no",
        "off",
    )
    # Load custom checkpoint, it overrides TIMM_PRETRAINED model.
    self.model_pt_path = os.environ.get("MODEL_PT_PATH")
    if self.model_pt_path and self.model_pt_path.startswith(GCS_PREFIX):
      self.model_pt_path = download_gcs_file(self.model_pt_path, DOWNLOAD_DIR)

    if self.model_pt_path and self.model_pt_path.endswith(".pt"):
      logging.info(
          "Load model with .pt in jit mode, not working for all timm models"
          " yet."
      )
      self.model = self._load_torchscript_model(self.model_pt_path)
    else:
      has_checkpoint = bool(self.model_pt_path)
      # Fail fast instead of silently serving a model without the
      # checkpoint the deployment asked for.
      if has_checkpoint and not self.model_pt_path.endswith(
          (".pth", ".pth.tar")
      ):
        raise ValueError(
            f"Unsupported MODEL_PT_PATH extension: {self.model_pt_path}; "
            "expected .pt, .pth or .pth.tar."
        )
      logging.info("Load model with .pth in eager mode.")
      # A provided checkpoint replaces every weight, so skip downloading
      # timm's pretrained weights in that case.
      self.model = timm.create_model(
          self.model_name, pretrained=timm_pretrained and not has_checkpoint
      )
      if has_checkpoint:
        # NOTE: torch.load unpickles the file; only point MODEL_PT_PATH at
        # trusted checkpoints.
        checkpoint = torch.load(self.model_pt_path, map_location=self.device)
        # timm training checkpoints wrap weights under "state_dict"; a bare
        # `torch.save(model.state_dict())` file is the state dict itself.
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
          state_dict = checkpoint["state_dict"]
        else:
          state_dict = checkpoint
        self.model.load_state_dict(state_dict)
      self.model.to(self.device)
    self.model.eval()

    mapping_file_path = os.environ.get("INDEX_TO_NAME_FILE")
    if mapping_file_path:
      if mapping_file_path.startswith(GCS_PREFIX):
        mapping_file_path = download_gcs_file(mapping_file_path, DOWNLOAD_DIR)
      self.mapping = load_label_mapping(mapping_file_path)

    self.initialized = True