community-content/vertex_model_garden/model_oss/detectron2/handler.py

"""Custom handler for Detectron2 serving.""" import io import json import os from typing import Any, List, Tuple import cv2 from detectron2.config import get_cfg from detectron2.engine import DefaultPredictor from google.cloud import storage import numpy as np import pycocotools.mask as mask_util import torch def get_bucket_and_blob_name(gcs_filepath: str) -> Tuple[str, str]: """Gets bucket and blob name from gcs path.""" # The gcs path is of the form gs://<bucket-name>/<blob-name> gs_suffix = gcs_filepath.split("gs://", 1)[1] return tuple(gs_suffix.split("/", 1)) def download_gcs_file(src_file_path: str, dst_file_path: str): """Downloads gcs-file to local folder.""" src_bucket_name, src_blob_name = get_bucket_and_blob_name(src_file_path) client = storage.Client() src_bucket = client.get_bucket(src_bucket_name) src_blob = src_bucket.blob(src_blob_name) src_blob.download_to_filename(dst_file_path) class ModelHandler: """Custom model handler for Detectron2.""" def __init__(self): self.error = None self._batch_size = 0 self.initialized = False self.predictor = None self.test_threshold = 0.5 def initialize(self, context: Any): """Initialize.""" print("context.system_properties: ", context.system_properties) print("context.manifest: ", context.manifest) self.manifest = context.manifest properties = context.system_properties # Get threshold from environment variable. # This will be set by customer. self.test_threshold = float(os.environ.get("TEST_THRESHOLD")) print("test_threshold: ", self.test_threshold) # Get model and config file location from environment variables. # These will be set by customer when doing model upload. gcs_model_file = os.environ["MODEL_PTH_FILE"] gcs_config_file = os.environ["CONFIG_YAML_FILE"] print("Copying gcs_model_file: ", gcs_model_file) print("Copying gcs_config_file: ", gcs_config_file) # Copy these files from GCS location to local file. # Note(lavrai): GCSFuse path does not seem to work here for now. model_file = "./model.pth" config_file = "./cfg.yaml" download_gcs_file(src_file_path=gcs_model_file, dst_file_path=model_file) if not os.path.exists(model_file): raise RuntimeError("Missing model_file: %s" % model_file) download_gcs_file(src_file_path=gcs_config_file, dst_file_path=config_file) if not os.path.exists(config_file): raise RuntimeError("Missing config_file: %s" % config_file) # Set up config file. cfg = get_cfg() cfg.merge_from_file(config_file) cfg.MODEL.WEIGHTS = model_file cfg.MODEL.DEVICE = ( cfg.MODEL.DEVICE + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu" ) cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = self.test_threshold # Build predictor from config. 
    self.predictor = DefaultPredictor(cfg)

    self._batch_size = context.system_properties["batch_size"]
    self.initialized = True

  def preprocess(self, batch: List[Any]) -> List[Any]:
    """Preprocesses raw input and returns a list of images."""
    print("Running pre-processing.")
    images = []
    for request in batch:
      request_data = request.get("data")
      input_bytes = io.BytesIO(request_data)
      # Decode the raw bytes into a BGR image array.
      img = cv2.imdecode(
          np.frombuffer(input_bytes.read(), np.uint8), cv2.IMREAD_COLOR)
      images.append(img)
    return images

  def inference(self, model_input: List[Any]) -> List[Any]:
    """Runs inference."""
    print("Running model-inference.")
    return [self.predictor(image) for image in model_input]

  def postprocess(self, inference_result: List[Any]) -> List[Any]:
    """Post-processes the inference result."""
    response_list = []
    print("Num inference_items are:", len(inference_result))
    for inference_item in inference_result:
      predictions = inference_item["instances"].to("cpu")
      print("Predictions are:", predictions)
      boxes = None
      if predictions.has("pred_boxes"):
        boxes = predictions.pred_boxes.tensor.numpy().tolist()
      scores = None
      if predictions.has("scores"):
        scores = predictions.scores.numpy().tolist()
      classes = None
      if predictions.has("pred_classes"):
        classes = predictions.pred_classes.numpy().tolist()
      masks_rle = None
      if predictions.has("pred_masks"):
        # Do run-length encoding, else the mask output becomes huge.
        # pycocotools expects a Fortran-ordered uint8 array.
        masks_rle = [
            mask_util.encode(
                np.asfortranarray(mask.numpy().astype(np.uint8)))
            for mask in predictions.pred_masks
        ]
        for rle in masks_rle:
          rle["counts"] = rle["counts"].decode("utf-8")
      response = {
          "classes": classes,
          "scores": scores,
          "boxes": boxes,
          "masks_rle": masks_rle,
      }
      response_list.append(json.dumps(response))
    print("response_list: ", response_list)
    return response_list

  def handle(self, data: Any, context: Any) -> List[Any]:  # pylint: disable=unused-argument
    """Runs preprocess, inference, and post-processing."""
    model_input = self.preprocess(data)
    model_out = self.inference(model_input)
    output = self.postprocess(model_out)
    print("Done handling input.")
    return output


_service = ModelHandler()


def handle(data: Any, context: Any) -> List[Any]:
  """Entry point invoked by the serving framework for each request batch."""
  if not _service.initialized:
    _service.initialize(context)
  if data is None:
    return None
  return _service.handle(data, context)
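

# Illustrative local smoke test, kept commented out. It is a minimal sketch,
# not part of the handler itself: it assumes a TorchServe-style context object
# exposing `system_properties` and `manifest`, and that the TEST_THRESHOLD,
# MODEL_PTH_FILE, and CONFIG_YAML_FILE environment variables point at valid
# values and GCS objects. `_FakeContext` and `sample.jpg` are hypothetical
# placeholder names.
#
# if __name__ == "__main__":
#   class _FakeContext:
#     system_properties = {"gpu_id": 0, "batch_size": 1}
#     manifest = {}
#
#   with open("sample.jpg", "rb") as f:
#     requests = [{"data": f.read()}]
#   print(handle(requests, _FakeContext()))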