# community-content/vertex_model_garden/model_oss/vot/handler.py
def inference(self, data: Any, *args, **kwargs) -> List[Any]:
  """Runs object detection and tracking inference on a video frame by frame.

  If YOLO detection is selected, the function uses the Ultralytics YOLO
  models for IOD; otherwise it calls the provided IOD endpoint and applies
  the selected tracking method to the resulting detections.

  Args:
    data: List of video files.
    *args: Additional arguments.
    **kwargs: Additional keyword arguments.

  Returns:
    List of video frame annotations and/or output decorated video URIs.
  """
  gcs_video_files = data
  video_preds = []
  for gcs_video_file in gcs_video_files:
    results_info = {}
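    # Per-frame annotations are buffered in a local temp file and uploaded
    # to GCS once the whole video has been processed.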
    temp_text_file = tempfile.NamedTemporaryFile(delete=False, mode="w+t")
    local_video_file_name, remote_video_file_name = (
        fileutils.download_video_from_gcs_to_local(gcs_video_file)
    )
    remote_text_file_name = remote_video_file_name.replace(
        "overlay.mp4", "annotations.txt"
    )
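    # Decode the downloaded video with OpenCV; its frame rate is reused for
    # both the output writer and the tracker.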
    cap = cv2.VideoCapture(local_video_file_name)
    fps = cap.get(cv2.CAP_PROP_FPS)
    temp_local_video_file_name = fileutils.get_output_video_file(
        local_video_file_name
    )
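    # Optionally encode a decorated copy of the video via imageio's FFMPEG
    # plugin; frames are appended one at a time inside the loop below.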
    if self.save_video_results:
      self.video_writer = iio.get_writer(
          temp_local_video_file_name,
          format="FFMPEG",
          mode="I",
          fps=float(fps),
          codec="h264",
      )
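    # A fresh ByteTrack instance per video keeps track IDs from leaking
    # across videos.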
    self.tracker = BYTETracker(
        track_thresh=self.track_thresh,
        track_buffer=self.track_buffer,
        match_thresh=self.match_thresh,
        frame_rate=fps,
    )
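    # Frame indices are 1-based in the emitted annotations.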
    frame_idx = 1
    while cap.isOpened():
      ret, frame = cap.read()
      if not ret:
        break
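      # Run the deployed object detection endpoint on the raw frame.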
      dets_np = commons.get_object_detection_endpoint_predictions(
          self.detection_endpoint, frame
      )
      dets_tf = tf.convert_to_tensor(dets_np)
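      # Associate the new detections with existing tracks; the second
      # (image-info) argument is not needed here, hence None.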
      online_targets = self.tracker.update(dets_tf, None)
      if online_targets.size > 0:
        frame = visualization_utils.overlay_tracking_results(
            frame_idx,
            frame,
            online_targets,
            label_map=self.label_map,
            temp_text_file_path=temp_text_file.name,
        )
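      # OpenCV decodes frames as BGR; convert to RGB before handing them to
      # the imageio writer.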
      frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
      if self.save_video_results:
        self.video_writer.append_data(frame)
      logging.info(
          "Finished processing frame %s for video %s.",
          frame_idx,
          gcs_video_file,
      )
      frame_idx += 1
    if self.save_video_results:
      self.video_writer.close()
    cap.release()
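    # Publish results: upload the decorated video (if one was written) and
    # the annotations file, and record their GCS URIs.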
    if self.save_video_results:
      fileutils.upload_video_from_local_to_gcs(
          self.output_bucket,
          local_video_file_name,
          remote_video_file_name,
          temp_local_video_file_name,
      )
      results_info["output_video"] = "{}/{}".format(
          self.output_bucket, remote_video_file_name
      )
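    # release_text_assets presumably flushes the buffered annotations and
    # uploads them under the derived remote name.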
    fileutils.release_text_assets(
        self.output_bucket,
        temp_text_file.name,
        remote_text_file_name,
    )
    results_info["annotations"] = "{}/{}".format(
        self.output_bucket, remote_text_file_name
    )
    video_preds.append(results_info)
  return video_preds