in community-content/vertex_model_garden/model_oss/timm/handler.py [0:0]
def initialize(self, context: Any):
    """Initializes the timm model handler from environment configuration.

    Environment variables read:
      MODEL_NAME: required timm model name.
      TIMM_PRETRAINED: load timm pretrained weights unless unset/empty or a
        false-like value ("0", "false", "no"). MODEL_PT_PATH overrides this.
      MODEL_PT_PATH: optional checkpoint path (local or gs://). A ".pt" file
        is loaded as TorchScript; ".pth"/".pth.tar" as an eager state dict.
      INDEX_TO_NAME_FILE: optional class-index-to-label mapping file
        (local or gs://).

    Args:
        context: TorchServe context carrying system_properties and manifest.
    """
    properties = context.system_properties
    # Compute the GPU condition once; the original evaluated it twice.
    gpu_id = properties.get("gpu_id")
    use_gpu = torch.cuda.is_available() and gpu_id is not None
    self.map_location = "cuda" if use_gpu else "cpu"
    self.device = torch.device(
        f"{self.map_location}:{gpu_id}" if use_gpu else self.map_location
    )
    self.manifest = context.manifest

    # Load timm model by model name.
    self.model_name = os.environ["MODEL_NAME"]
    # Whether to use timm pretrained weights, MODEL_PT_PATH overrides this.
    # Treat "0"/"false"/"no" as False so TIMM_PRETRAINED=false does not
    # silently enable pretrained weights (any non-empty string was truthy
    # in the original).
    timm_pretrained = os.environ.get(
        "TIMM_PRETRAINED", ""
    ).strip().lower() not in ("", "0", "false", "no")

    # Load custom checkpoint, it overrides TIMM_PRETRAINED model.
    self.model_pt_path = os.environ.get("MODEL_PT_PATH")
    if self.model_pt_path and self.model_pt_path.startswith(GCS_PREFIX):
        self.model_pt_path = download_gcs_file(self.model_pt_path, DOWNLOAD_DIR)
    if self.model_pt_path and self.model_pt_path.endswith(".pt"):
        logging.info(
            "Load model with .pt in jit mode, not working for all timm models"
            " yet."
        )
        self.model = self._load_torchscript_model(self.model_pt_path)
    else:
        logging.info("Load model with .pth in eager mode.")
        self.model = timm.create_model(
            self.model_name, pretrained=timm_pretrained
        )
        # endswith accepts a tuple of suffixes: one call covers both
        # ".pth" and ".pth.tar".
        if self.model_pt_path and self.model_pt_path.endswith(
            (".pth", ".pth.tar")
        ):
            # NOTE(review): assumes the checkpoint dict has a "state_dict"
            # key; a raw state-dict checkpoint would raise KeyError here.
            checkpoint = torch.load(self.model_pt_path, map_location=self.device)
            state_dict = checkpoint["state_dict"]
            self.model.load_state_dict(state_dict)
    self.model.to(self.device)
    self.model.eval()

    # Optional class-index -> human-readable label mapping.
    mapping_file_path = os.environ.get("INDEX_TO_NAME_FILE")
    if mapping_file_path:
        if mapping_file_path.startswith(GCS_PREFIX):
            mapping_file_path = download_gcs_file(mapping_file_path, DOWNLOAD_DIR)
        self.mapping = load_label_mapping(mapping_file_path)
    self.initialized = True