def inference()

in community-content/vertex_model_garden/model_oss/vot/handler.py [0:0]


  def inference(self, data: Any, *args, **kwargs) -> List[Any]:
    """Runs object detection and tracking inference on a video frame by frame.

    If using YOLO detection, the function uses the Ultralytics YOLO models for
    IOD; otherwise it uses the provided IOD endpoint and applies the selected
    tracking method to the detections.

    Args:
        data: List of GCS video file paths.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        List of video frame annotations and/or output decorated video URIs.
    """
    gcs_video_files = data
    video_preds = []
    for gcs_video_file in gcs_video_files:
      results_info = {}
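      # Collect per-frame annotations in a local temp file, stage the video
      # locally, and derive the remote annotation file name from the
      # decorated-video name.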
      temp_text_file = tempfile.NamedTemporaryFile(delete=False, mode="w+t")
      local_video_file_name, remote_video_file_name = (
          fileutils.download_video_from_gcs_to_local(gcs_video_file)
      )
      remote_text_file_name = remote_video_file_name.replace(
          "overlay.mp4", "annotations.txt"
      )

      cap = cv2.VideoCapture(local_video_file_name)
      fps = cap.get(cv2.CAP_PROP_FPS)
      temp_local_video_file_name = fileutils.get_output_video_file(
          local_video_file_name
      )
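      # Only create the ffmpeg writer when a decorated output video is requested.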
      if self.save_video_results:
        self.video_writer = iio.get_writer(
            temp_local_video_file_name,
            format="FFMPEG",
            mode="I",
            fps=float(fps),
            codec="h264",
        )
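      # A fresh tracker per video keeps track IDs from carrying over between files.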
      self.tracker = BYTETracker(
          track_thresh=self.track_thresh,
          track_buffer=self.track_buffer,
          match_thresh=self.match_thresh,
          frame_rate=fps,
      )
      frame_idx = 1
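      # Process the video frame by frame until the capture is exhausted.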
      while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
          break

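        # Query the detection endpoint for this frame and hand the detections
        # to the tracker as a tensor.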
        dets_np = commons.get_object_detection_endpoint_predictions(
            self.detection_endpoint, frame
        )
        dets_tf = tf.convert_to_tensor(dets_np)
        online_targets = self.tracker.update(dets_tf, None)
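        # Overlay tracking results only when the tracker returned active targets,
        # passing the temp annotation file so per-frame results can be recorded.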
        if online_targets.size > 0:
          frame = visualization_utils.overlay_tracking_results(
              frame_idx,
              frame,
              online_targets,
              label_map=self.label_map,
              temp_text_file_path=temp_text_file.name,
          )

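        # OpenCV frames are BGR; convert to RGB before appending to the imageio writer.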
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if self.save_video_results:
          self.video_writer.append_data(frame)
        logging.info(
            "Finished processing frame %s for video %s.",
            frame_idx,
            gcs_video_file,
        )
        frame_idx += 1

      if self.save_video_results:
        self.video_writer.close()
      cap.release()
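      # Upload the decorated video to GCS and record its location, if requested.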
      if self.save_video_results:
        fileutils.upload_video_from_local_to_gcs(
            self.output_bucket,
            local_video_file_name,
            remote_video_file_name,
            temp_local_video_file_name,
        )
        results_info["output_video"] = "{}/{}".format(
            self.output_bucket, remote_video_file_name
        )

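      # Publish the annotation text file to the output bucket and record its location.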
      fileutils.release_text_assets(
          self.output_bucket,
          temp_text_file.name,
          remote_text_file_name,
      )
      results_info["annotations"] = "{}/{}".format(
          self.output_bucket, remote_text_file_name
      )
      video_preds.append(results_info)

    return video_preds
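
A minimal usage sketch (not from the source): it assumes `handler` is an already-initialized instance of the class that defines this method, configured with a detection endpoint, tracker thresholds, and an output bucket. The bucket and file names below are hypothetical.

# Hypothetical usage; `handler` and the GCS paths are placeholders.
gcs_videos = ["gs://example-bucket/videos/traffic.mp4"]

preds = handler.inference(gcs_videos)
for pred in preds:
  # "output_video" is present only when save_video_results is enabled.
  print(pred.get("output_video"))
  # "annotations" points at the uploaded per-frame annotation text file.
  print(pred["annotations"])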