community-content/vertex_cpr_samples/torch/predictor_resnet.py (27 lines of code) (raw):

import os
from typing import Any, Dict, List

import torch
from google.cloud.aiplatform.prediction.predictor import Predictor
from google.cloud.aiplatform.utils import prediction_utils
from torchvision.models import ResNet50_Weights, detection, resnet50


def _to_jsonable(value: Any) -> Any:
    """Recursively convert tensors inside ``value`` to plain Python objects.

    Tensors become (nested) lists via ``Tensor.tolist()``; dicts keep their keys
    with converted values; lists and tuples become lists of converted items;
    anything else is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.tolist()
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


class ResNetPredictor(Predictor):
    """Vertex AI custom-prediction-routine predictor built on torchvision models.

    ``load`` downloads the model artifacts. If a ``model.pth.tar`` checkpoint is
    present, its ``state_dict`` is loaded into a Faster R-CNN ResNet-50 FPN model;
    otherwise a ResNet-50 with torchvision's default pretrained weights is used.
    """

    # Checkpoint file looked up (relative to the working directory) after the
    # artifacts have been downloaded.
    _CHECKPOINT_PATH = "model.pth.tar"

    def __init__(self):
        # The model is built lazily in ``load``; nothing to do here.
        self.model = None

    def load(self, artifacts_uri: str) -> None:
        """Download the model artifacts and build the model in eval mode.

        Args:
            artifacts_uri: Location of the model artifacts, passed unchanged to
                ``prediction_utils.download_model_artifacts``.
        """
        prediction_utils.download_model_artifacts(artifacts_uri)
        if os.path.exists(self._CHECKPOINT_PATH):
            # NOTE(review): this branch builds a Faster R-CNN *detector* while the
            # fallback builds a ResNet-50 *classifier*; their outputs differ
            # (list of per-image dicts vs. a logits tensor). Confirm which
            # architecture the checkpoint was saved from.
            # COCO_V1 is exactly what the deprecated ``pretrained=True`` resolved
            # to, so the architecture and state-dict keys are unchanged.
            self.model = detection.fasterrcnn_resnet50_fpn(
                weights=detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1
            )
            # This predictor never moves the model to a GPU, so map tensors to
            # CPU; a GPU-saved checkpoint then still loads on CPU-only hosts.
            checkpoint = torch.load(self._CHECKPOINT_PATH, map_location="cpu")
            self.model.load_state_dict(checkpoint["state_dict"])
        else:
            self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model.eval()

    def preprocess(self, prediction_input: dict) -> torch.Tensor:
        """Convert the request's ``instances`` field into a float32 tensor.

        ``torch.as_tensor`` is used instead of the ``torch.Tensor`` constructor,
        which would treat a bare integer as a *size* and return uninitialized
        memory.
        """
        instances = prediction_input["instances"]
        return torch.as_tensor(instances, dtype=torch.float32)

    @torch.inference_mode()
    def predict(self, instances: torch.Tensor) -> Any:
        """Run the loaded model on ``instances`` with autograd disabled."""
        return self.model(instances)

    def postprocess(self, prediction_results: Any) -> Dict:
        """Wrap the model output in a JSON-serializable response dict.

        Tensors (also inside dicts/lists, as returned by detection models) are
        converted to plain lists so the response can be JSON-encoded.
        """
        return {"predictions": _to_jsonable(prediction_results)}