scripts/render/worker.py (368 lines of code) (raw):

#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Consumption side of render message queue.

Provides the interface for performing computations based on messages received
from the queue. Running a worker node only has semantic meaning in the context
of subscribing to a master running a queue, but once it establishes connection,
it will continue to poll for messages until explicitly terminated or the
connection is closed.

Example:
    To run a single worker node subscribed to 192.168.1.100:

        $ python worker.py --master=192.168.1.100

Attributes:
    FLAGS (absl.flags._flagvalues.FlagValues): Globally defined flags for worker.py.
"""

import functools
import json
import os
import shutil
import sys
import threading
import time
import traceback
from copy import copy

import pika
from absl import app, flags

dir_scripts = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
dir_root = os.path.dirname(dir_scripts)
sys.path.append(dir_root)
sys.path.append(os.path.join(dir_scripts, "util"))

import config
from network import (
    copy_image_level,
    download,
    download_image_type,
    download_image_types,
    download_rig,
    get_cameras,
    get_frame_fns,
    get_frame_name,
    get_frame_range,
    local_image_type_path,
    local_rig_path,
    remote_image_type_path,
    upload,
    upload_image_type,
)
from resize import resize_frames
from scripts.render.network import Address
from scripts.util.system_util import run_command
from setup import bin_to_flags

FLAGS = flags.FLAGS

# Seconds to wait before reconnecting after a recoverable connection error, so a
# worker does not spin in a tight loop while the master is unreachable.
_RECONNECT_DELAY_S = 5


def _run_bin(msg):
    """Runs the binary associated with the message.

    The execution assumes the worker is running in a configured Docker container.
    Only the flags that the binary declares (per ``bin_to_flags``) and that have
    a non-empty value in the message are forwarded on the command line.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    msg_cp = copy(msg)

    # The binary flag convention includes the "last" frame
    msg_cp["last"] = get_frame_name(int(msg["last"]))
    app_name = msg_cp["app"].split(":")[0]
    relevant_flags = [flag["name"] for flag in bin_to_flags[app_name]]
    cmd_flags = " ".join(
        f"--{flag}={msg_cp[flag]}"
        for flag in relevant_flags
        if flag in msg_cp and msg_cp[flag] != ""
    )

    # Order is determined to prevent substrings from being accidentally replaced:
    # if one root contains the other, the longer one is rewritten first.
    input_root = msg_cp["input_root"].rstrip("/")
    output_root = msg_cp["output_root"].rstrip("/")
    root_order = (
        [output_root, input_root]
        if input_root in output_root
        else [input_root, output_root]
    )
    root_to_docker = {
        input_root: config.DOCKER_INPUT_ROOT,
        output_root: config.DOCKER_OUTPUT_ROOT,
    }
    # Roots that do not exist on this machine are mapped to their Docker mounts
    for root in root_order:
        if not os.path.exists(root):
            cmd_flags = cmd_flags.replace(root, root_to_docker[root])

    bin_path = os.path.join(config.DOCKER_BUILD_ROOT, "bin", app_name)
    cmd = f"GLOG_alsologtostderr=1 GLOG_stderrthreshold=0 {bin_path} {cmd_flags}"
    run_command(cmd)


def _clean_worker(ran_download, ran_upload):
    """Deletes any files that were downloaded or uploaded.

    Args:
        ran_download (bool): Whether or not a download was performed.
        ran_upload (bool): Whether or not an upload was performed.
    """
    if ran_download and os.path.exists(config.DOCKER_INPUT_ROOT):
        shutil.rmtree(config.DOCKER_INPUT_ROOT)
    # Check the same directory that is about to be removed
    if ran_upload and os.path.exists(config.DOCKER_OUTPUT_ROOT):
        shutil.rmtree(config.DOCKER_OUTPUT_ROOT)


def generate_foreground_masks_callback(msg):
    """Runs foreground mask generation according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running foreground mask generation...")

    image_types_to_level = [("color", msg["level"])]
    ran_download = download_rig(msg)
    ran_download |= download_image_types(msg, image_types_to_level)
    ran_download |= download_image_type(
        msg, "background_color", [msg["background_frame"]], msg["level"]
    )

    msg_cp = copy(msg)
    msg_cp["color"] = local_image_type_path(msg, "color", msg["level"])
    msg_cp["background_color"] = local_image_type_path(
        msg, "background_color", msg["level"]
    )
    msg_cp["foreground_masks"] = local_image_type_path(
        msg, "foreground_masks", msg["dst_level"]
    )

    _run_bin(msg_cp)
    ran_upload = upload_image_type(msg, "foreground_masks", level=msg["dst_level"])
    _clean_worker(ran_download, ran_upload)


def resize_images_callback(msg):
    """Runs image resizing according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running image resizing...")

    image_types_to_level = [(msg["image_type"], None)]
    ran_download = download_rig(msg)
    ran_download |= download_image_types(msg, image_types_to_level)

    with open(local_rig_path(msg), "r") as f:
        rig = json.load(f)

    local_src_dir = local_image_type_path(msg, msg["image_type"])
    local_dst_dir = local_image_type_path(
        msg, config.type_to_levels_type[msg["image_type"]]
    )
    resize_frames(
        local_src_dir, local_dst_dir, rig, msg["first"], msg["last"], msg["threshold"]
    )

    # Accumulate across levels so that one successful upload is enough to trigger
    # cleanup, and so an empty level list does not leave ran_upload unbound.
    ran_upload = False
    for level in msg["dst_level"]:
        ran_upload |= upload_image_type(msg, msg["image_type"], level=level)

    # Clean up workspace to prevent using too much disk space on workers
    _clean_worker(ran_download, ran_upload)


def depth_estimation_callback(msg):
    """Runs depth estimation according to parameters read from the message.

    Handles both foreground ("disparity") and background depth estimation. If the
    starting level is not the coarsest one, the next coarser level's output is
    also downloaded.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running depth estimation...")

    ran_download = False
    msg_cp = copy(msg)
    level_start = msg["level_start"]
    has_coarser_level = level_start < msg["num_levels"] - 1

    if msg["image_type"] == "disparity":
        image_types_to_level = [("color", level_start)]
        if msg["use_foreground_masks"]:
            ran_download |= download_image_type(
                msg, "background_disp", [msg["background_frame"]], level_start
            )
            image_types_to_level.append(("foreground_masks", level_start))
        if has_coarser_level:
            image_types_to_level.append(("disparity", level_start + 1))
            if msg["use_foreground_masks"]:
                image_types_to_level.append(("foreground_masks", level_start + 1))
    else:
        image_types_to_level = [("background_color", level_start)]
        if has_coarser_level:
            image_types_to_level.append(("background_disp", level_start + 1))
        msg_cp["color"] = local_image_type_path(msg, "background_color_levels")
        msg_cp["output_root"] = os.path.join(msg["input_root"], "background")

    ran_download |= download_rig(msg)
    ran_download |= download_image_types(msg, image_types_to_level)

    _run_bin(msg_cp)
    ran_upload = upload_image_type(msg, msg["image_type"], level=msg["level_end"])
    _clean_worker(ran_download, ran_upload)


def temporal_filter_callback(msg):
    """Runs temporal filtering according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running temporal filtering...")

    msg_cp = copy(msg)
    frames = get_frame_range(msg["filter_first"], msg["filter_last"])
    image_types_to_level = [("color", msg["level"]), ("disparity", msg["level"])]
    if msg["use_foreground_masks"]:
        image_types_to_level.append(("foreground_masks", msg["level"]))

    # If such frames do not exist, S3 simply does not download them
    ran_download = download_rig(msg)
    ran_download |= download_image_types(msg, image_types_to_level, frames)

    msg_cp["disparity"] = ""  # disparity_level is automatically populated by app
    _run_bin(msg_cp)

    processed_frames = get_frame_range(msg["first"], msg["last"])
    ran_upload = upload_image_type(
        msg, "disparity_time_filtered", processed_frames, level=msg["level"]
    )
    _clean_worker(ran_download, ran_upload)


def transfer_callback(msg):
    """Runs transfer according to parameters read from the message.

    Copies one image type/level to another for every rig camera and frame in
    the requested range. No local files are produced, so no cleanup is needed.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running rearranging...")

    rig_cameras = get_cameras(msg, "cameras")
    frames = get_frame_range(msg["first"], msg["last"])
    copy_image_level(
        msg,
        msg["src_image_type"],
        msg["dst_image_type"],
        rig_cameras,
        frames,
        msg["src_level"],
        msg["dst_level"],
    )


def _run_upsample(msg, run_upload=True):
    """Runs disparity upsampling according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
        run_upload (bool, optional): Whether or not to upload the upsampled result.

    Returns:
        tuple(bool, bool): Respectively whether or not a download and upload were
            performed.

    Raises:
        ValueError: If the message's image type cannot be upsampled.
    """
    image_type = msg["image_type"]
    image_types_to_level = [(image_type, msg["level"])]
    msg_cp = copy(msg)
    ran_download = False

    if image_type == "disparity":
        color_image_type = "color"
        image_types_to_level += [
            ("foreground_masks", msg["level"]),
            ("foreground_masks", msg["dst_level"]),
        ]
        msg_cp["foreground_masks_in"] = local_image_type_path(
            msg, "foreground_masks", msg["level"]
        )
        msg_cp["foreground_masks_out"] = local_image_type_path(
            msg, "foreground_masks", msg["dst_level"]
        )
        # Track this download too so its files are removed during cleanup
        ran_download |= download_image_type(
            msg, "background_disp", [msg["background_frame"]], msg["dst_level"]
        )
        msg_cp["background_disp"] = local_image_type_path(
            msg, "background_disp", msg["dst_level"]
        )
    elif image_type == "background_disp":
        color_image_type = "background_color"
        msg_cp["foreground_masks_in"] = ""  # Background upsampling doesn't use masks
        msg_cp["foreground_masks_out"] = ""
    else:
        raise ValueError(f"Unsupported image type for upsampling: {image_type}")

    image_types_to_level.append((color_image_type, msg["dst_level"]))
    ran_download |= download_image_types(msg, image_types_to_level)
    ran_download |= download_rig(msg)

    upsample_type = config.type_to_upsample_type[image_type]
    msg_cp["disparity"] = local_image_type_path(msg, image_type, msg["level"])
    msg_cp["output"] = local_image_type_path(msg, upsample_type)
    msg_cp["color"] = local_image_type_path(msg, color_image_type, msg["dst_level"])
    _run_bin(msg_cp)

    ran_upload = False
    if run_upload:
        ran_upload = upload_image_type(msg, upsample_type)
    return ran_download, ran_upload


def upsample_disparity_callback(msg):
    """Runs disparity upsampling according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running disparity upsampling...")
    ran_download, ran_upload = _run_upsample(msg)
    _clean_worker(ran_download, ran_upload)


def upsample_layer_disparity_callback(msg):
    """Runs disparity upsampling and layering according to parameters read from the message.

    First runs the UpsampleDisparity binary without uploading. Then runs
    LayerDisparities on the upsampled foreground and background disparities.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running disparity upsampling and layering...")

    msg_cp = copy(msg)
    msg_cp["app"] = "UpsampleDisparity"
    # Pass the copy so the upsample step runs the UpsampleDisparity binary rather
    # than the binary named by the incoming message's app
    ran_download, _ = _run_upsample(msg_cp, run_upload=False)
    ran_download |= download_image_type(
        msg, config.type_to_upsample_type["background_disp"], [msg["background_frame"]]
    )

    msg_cp["app"] = "LayerDisparities"
    msg_cp["background_disp"] = local_image_type_path(
        msg, config.type_to_upsample_type["background_disp"]
    )
    msg_cp["foreground_disp"] = local_image_type_path(
        msg, config.type_to_upsample_type["disparity"]
    )
    msg_cp["output"] = config.DOCKER_OUTPUT_ROOT
    _run_bin(msg_cp)

    ran_upload = upload_image_type(msg, "disparity")
    _clean_worker(ran_download, ran_upload)


def convert_to_binary_callback(msg):
    """Runs binary conversion according to parameters read from the message.

    When ``run_conversion`` is set, color and disparity frames are converted into
    binaries. Otherwise, previously converted binaries and the fused rig JSON are
    downloaded, and the fused output is uploaded.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Converting to binary...")

    msg_cp = copy(msg)
    ran_download = download_rig(msg)

    # "<name>.<ext>" -> "<name>_fused.<ext>", splitting on the first dot
    rig_json = os.path.basename(msg["rig"])
    stem, dot, ext = rig_json.partition(".")
    fused_json = f"{stem}_fused{dot}{ext}"

    if msg["run_conversion"]:
        image_types_to_level = [
            (msg["color_type"], None),
            (msg["disparity_type"], msg["level"]),
        ]
        msg_cp["disparity"] = local_image_type_path(
            msg, msg["disparity_type"], msg["level"]
        )
        msg_cp["color"] = local_image_type_path(msg, msg["color_type"])
        msg_cp["fused"] = ""  # fusion is done independently from conversion
        ran_download |= download_image_types(msg, image_types_to_level)

        # If we only have color levels uploaded to S3, we fall back to level_0
        color_dir = msg_cp["color"]
        if not os.path.isdir(color_dir) or not os.listdir(color_dir):
            ran_download |= download_image_types(msg, [(msg["color_type"], 0)])
            msg_cp["color"] = local_image_type_path(msg, msg["color_type"], 0)
    else:
        image_types_to_level = [("bin", None)]
        local_fused_dir = local_image_type_path(msg, "fused")

        # Paths are explicitly emptied to avoid path verifications
        msg_cp["color"] = ""
        msg_cp["disparity"] = ""
        msg_cp["foreground_masks"] = ""
        msg_cp["fused"] = local_fused_dir

        ran_download |= download_image_types(msg, image_types_to_level)
        ran_download |= download(
            src=os.path.join(remote_image_type_path(msg, "bin"), fused_json),
            dst=os.path.join(local_image_type_path(msg, "bin"), fused_json),
        )

    # Create the local directory that is actually passed to the binary
    msg_cp["bin"] = local_image_type_path(msg, "bin")
    os.makedirs(msg_cp["bin"], exist_ok=True)
    _run_bin(msg_cp)

    if msg["run_conversion"]:
        ran_upload = upload_image_type(msg, "bin")
        ran_upload |= upload(
            src=os.path.join(local_image_type_path(msg, "bin"), fused_json),
            dst=os.path.join(remote_image_type_path(msg, "bin"), fused_json),
        )
    else:
        # We use a raw upload since upload_image_type only handles frames but we want to
        # also upload the fused json here
        ran_upload = upload(
            src=local_image_type_path(msg, "fused"),
            dst=remote_image_type_path(msg, "fused"),
            filters=["*"],
        )

    _clean_worker(ran_download, ran_upload)


def simple_mesh_renderer_callback(msg):
    """Renders exports with the mesh renderer according to the message parameters.

    Downloads the rig plus the color and disparity frames in [first, last]. It
    then renders from a fixed camera pose and uploads the resulting frames.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Generating exports...")

    msg_cp = copy(msg)
    frames = get_frame_range(msg_cp["first"], msg_cp["last"])
    ran_download = download_rig(msg)
    ran_download |= download_image_type(msg, msg_cp["color_type"], frames)
    ran_download |= download_image_type(msg, msg_cp["disparity_type"], frames)

    msg_cp["color"] = local_image_type_path(msg, msg_cp["color_type"])
    msg_cp["disparity"] = local_image_type_path(msg, msg_cp["disparity_type"])
    msg_cp["output"] = local_image_type_path(msg, msg_cp["dst_image_type"])
    # Values are pre-quoted because they contain spaces and go into a shell command
    msg_cp["position"] = '"0.0 0.0 0.0"'
    msg_cp["forward"] = '"-1.0 0.0 0.0"'
    msg_cp["up"] = '"0.0 0.0 1.0"'

    _run_bin(msg_cp)
    ran_upload = upload_image_type(msg, msg_cp["dst_image_type"], frames)
    _clean_worker(ran_download, ran_upload)


def success(channel, delivery_tag):
    """Acknowledges a processed message and publishes a completion notice.

    Intended to be scheduled via ``connection.add_callback_threadsafe``.

    Args:
        channel (pika.channel.Channel): Channel the message was received on.
        delivery_tag (int): Delivery tag of the processed message.
    """
    if not channel.is_open:
        print(f"Channel closed; could not acknowledge delivery {delivery_tag}")
        return

    channel.basic_ack(delivery_tag)
    channel.queue_declare(config.RESPONSE_QUEUE_NAME)
    channel.basic_publish(
        exchange="", routing_key=config.RESPONSE_QUEUE_NAME, body="Completed!"
    )


def failure(channel, delivery_tag, msg):
    """Rejects a failed message and republishes it to the back of the work queue.

    Intended to be scheduled via ``connection.add_callback_threadsafe``.

    Args:
        channel (pika.channel.Channel): Channel the message was received on.
        delivery_tag (int): Delivery tag of the failed message.
        msg (dict[str, str] | None): Decoded message to republish. If None (the
            body could not be decoded), the message is dropped instead.
    """
    if not channel.is_open:
        print(f"Channel closed; could not reject delivery {delivery_tag}")
        return

    # requeue=False: the message is republished explicitly below. Letting the
    # broker requeue it as well would duplicate the job on every failure.
    channel.basic_reject(delivery_tag, requeue=False)
    if msg is None:
        return

    channel.queue_declare(config.QUEUE_NAME)
    channel.basic_publish(
        exchange="",
        routing_key=config.QUEUE_NAME,
        body=json.dumps(msg),
        properties=pika.BasicProperties(delivery_mode=2),  # make message persistent
    )


def handle_message(connection, channel, delivery_tag, body):
    """Decodes a message, runs the matching callback, and reports the outcome.

    Runs on a worker thread. The ack/reject is scheduled back onto the
    connection's thread with ``add_callback_threadsafe``.

    Args:
        connection (pika.BlockingConnection): Connection owning ``channel``.
        channel (pika.channel.Channel): Channel the message was received on.
        delivery_tag (int): Delivery tag of the message.
        body (bytes): utf-8 encoded JSON message.
    """
    try:
        msg = json.loads(body.decode("utf-8"))
    except ValueError:  # covers UnicodeDecodeError and json.JSONDecodeError
        # An undecodable body can never succeed: reject it without republishing,
        # otherwise it stays unacked and blocks this worker (prefetch_count=1)
        traceback.print_exc()
        connection.add_callback_threadsafe(
            functools.partial(failure, channel, delivery_tag, None)
        )
        return

    app_name_to_callback = {
        "GenerateForegroundMasks": generate_foreground_masks_callback,
        "DerpCLI": depth_estimation_callback,
        "TemporalBilateralFilter": temporal_filter_callback,
        "Transfer": transfer_callback,
        "UpsampleDisparity": upsample_disparity_callback,
        "UpsampleLayer": upsample_layer_disparity_callback,
        "ConvertToBinary": convert_to_binary_callback,
        "SimpleMeshRenderer": simple_mesh_renderer_callback,
        "Resize": resize_images_callback,
    }

    try:
        print(f"Received {msg}")
        for app_name, app_callback in app_name_to_callback.items():
            if msg["app"].startswith(app_name):
                app_callback(msg)
                break
        else:
            print(f"No callback registered for app {msg['app']}; skipping")
    except Exception:
        traceback.print_exc()
        failure_callback = functools.partial(failure, channel, delivery_tag, msg)
        connection.add_callback_threadsafe(failure_callback)
    else:
        # Sends response of job completion
        success_callback = functools.partial(success, channel, delivery_tag)
        connection.add_callback_threadsafe(success_callback)


def callback(ch, method, properties, body, connection):
    """Hands a received message off to a new thread running ``handle_message``.

    Processing happens off the connection's thread so that the blocking
    connection can keep servicing I/O while a job runs.

    Args:
        Only the body argument is necessary. The others are imposed by Pika.
        ch (pika.Channel): N/a
        method (pika.spec.Basic): N/a
        properties (pika.spec.BasicProperties): N/a
        body (bytes): utf-8 encoded message published to the message queue.
        connection (pika.BlockingConnection): Connection bound via functools.partial,
            used to schedule the ack/reject thread-safely.
    """
    handle = threading.Thread(
        target=handle_message, args=(connection, ch, method.delivery_tag, body)
    )
    handle.start()


def main_loop(argv):
    """Sets up the callback loop for the worker.

    Args:
        argv (list[str]): List of arguments (used internally by abseil).
    """
    while True:
        try:
            connection = pika.BlockingConnection(
                pika.ConnectionParameters(FLAGS.master)
            )
            on_message_callback = functools.partial(callback, connection=connection)
            channel = connection.channel()
            channel.queue_declare(queue=config.QUEUE_NAME)
            # Take one job at a time
            channel.basic_qos(prefetch_count=1)
            channel.basic_consume(
                queue=config.QUEUE_NAME,
                auto_ack=False,
                on_message_callback=on_message_callback,
            )
            channel.start_consuming()

        # Don't recover if connection was closed by broker
        except pika.exceptions.ConnectionClosedByBroker:
            break

        # Don't recover on channel errors
        except pika.exceptions.AMQPChannelError:
            break

        # Recover on all other connection errors, after a short back-off
        except pika.exceptions.AMQPConnectionError:
            time.sleep(_RECONNECT_DELAY_S)
            continue


if __name__ == "__main__":
    # Abseil entry point app.run() expects all flags to be already defined
    flags.DEFINE_string("master", None, "master IP")

    # Required FLAGS.
    flags.mark_flag_as_required("master")
    app.run(main_loop)