scripts/render/worker.py (368 lines of code) (raw):
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Consumption side of render message queue.
Provides the interface for performing computations based on messages received
from the queue. Running a worker node only has semantic meaning in the context of
subscribing to a master running a queue, but once it establishes connection, it will
continue to poll for messages until explicitly terminated or the connection is closed.
Example:
To run a single worker node subscribed to 192.168.1.100:
$ python worker.py \
--master=192.168.1.100
Attributes:
FLAGS (absl.flags._flagvalues.FlagValues): Globally defined flags for worker.py.
"""
import functools
import json
import os
import shutil
import sys
import threading
import traceback
from copy import copy

import pika
from absl import app, flags
dir_scripts = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
dir_root = os.path.dirname(dir_scripts)
sys.path.append(dir_root)
sys.path.append(os.path.join(dir_scripts, "util"))
import config
from network import (
copy_image_level,
download,
download_image_type,
download_image_types,
download_rig,
get_cameras,
get_frame_fns,
get_frame_name,
get_frame_range,
local_image_type_path,
local_rig_path,
remote_image_type_path,
upload,
upload_image_type,
)
from resize import resize_frames
from scripts.render.network import Address
from scripts.util.system_util import run_command
from setup import bin_to_flags
FLAGS = flags.FLAGS
def _run_bin(msg):
    """Executes the binary named in the message inside the configured Docker container.

    Builds a command line from the subset of the binary's declared flags that are
    present (and non-empty) in the message, rewriting host roots to their Docker
    mount points when they do not exist locally.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    local_msg = copy(msg)
    # The binary flag convention includes the "last" frame
    local_msg["last"] = get_frame_name(int(msg["last"]))
    app_name = local_msg["app"].split(":")[0]
    flag_parts = []
    for flag in bin_to_flags[app_name]:
        name = flag["name"]
        if name in local_msg and local_msg[name] != "":
            flag_parts.append(f"--{name}={local_msg[name]}")
    cmd_flags = " ".join(flag_parts)
    input_root = local_msg["input_root"].rstrip("/")
    output_root = local_msg["output_root"].rstrip("/")
    # Order is determined to prevent substrings from being accidentally replaced:
    # replace the containing root first when one root nests inside the other
    if input_root in output_root:
        ordered_roots = [output_root, input_root]
    else:
        ordered_roots = [input_root, output_root]
    docker_roots = {
        input_root: config.DOCKER_INPUT_ROOT,
        output_root: config.DOCKER_OUTPUT_ROOT,
    }
    for root in ordered_roots:
        # A root missing on this host must refer to the Docker-mounted path
        if not os.path.exists(root):
            cmd_flags = cmd_flags.replace(root, docker_roots[root])
    binary = os.path.join(config.DOCKER_BUILD_ROOT, "bin", app_name)
    run_command(f"GLOG_alsologtostderr=1 GLOG_stderrthreshold=0 {binary} {cmd_flags}")
def _clean_worker(ran_download, ran_upload):
    """Deletes any files that were downloaded or uploaded.

    Frees local disk space between jobs by removing the Docker input/output
    trees when they were populated.

    Args:
        ran_download (bool): Whether or not a download was performed.
        ran_upload (bool): Whether or not an upload was performed.
    """
    if ran_download and os.path.exists(config.DOCKER_INPUT_ROOT):
        shutil.rmtree(config.DOCKER_INPUT_ROOT)
    # BUG FIX: previously this guarded on DOCKER_INPUT_ROOT existing, which could
    # skip cleanup of the output tree or raise when the output tree was absent
    if ran_upload and os.path.exists(config.DOCKER_OUTPUT_ROOT):
        shutil.rmtree(config.DOCKER_OUTPUT_ROOT)
def generate_foreground_masks_callback(msg):
    """Runs foreground mask generation according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running foreground mask generation...")
    # Fetch the rig, the color frames at the requested level, and the single
    # background color frame the masks are computed against
    ran_download = download_rig(msg)
    ran_download |= download_image_types(msg, [("color", msg["level"])])
    ran_download |= download_image_type(
        msg, "background_color", [msg["background_frame"]], msg["level"]
    )
    bin_msg = copy(msg)
    bin_msg["color"] = local_image_type_path(msg, "color", msg["level"])
    bin_msg["background_color"] = local_image_type_path(
        msg, "background_color", msg["level"]
    )
    bin_msg["foreground_masks"] = local_image_type_path(
        msg, "foreground_masks", msg["dst_level"]
    )
    _run_bin(bin_msg)
    ran_upload = upload_image_type(msg, "foreground_masks", level=msg["dst_level"])
    _clean_worker(ran_download, ran_upload)
def resize_images_callback(msg):
    """Runs image resizing according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running image resizing...")
    image_types_to_level = [(msg["image_type"], None)]
    ran_download = download_rig(msg)
    ran_download |= download_image_types(msg, image_types_to_level)
    with open(local_rig_path(msg), "r") as f:
        rig = json.load(f)
    local_src_dir = local_image_type_path(msg, msg["image_type"])
    local_dst_dir = local_image_type_path(
        msg, config.type_to_levels_type[msg["image_type"]]
    )
    resize_frames(
        local_src_dir, local_dst_dir, rig, msg["first"], msg["last"], msg["threshold"]
    )
    # Clean up workspace to prevent using too much disk space on workers.
    # BUG FIX: ran_upload was previously only bound inside the loop, causing a
    # NameError in _clean_worker when msg["dst_level"] is empty; it also kept
    # only the last iteration's result instead of "any upload happened".
    ran_upload = False
    for level in msg["dst_level"]:
        ran_upload |= upload_image_type(msg, msg["image_type"], level=level)
    _clean_worker(ran_download, ran_upload)
def depth_estimation_callback(msg):
    """Runs depth estimation according to parameters read from the message.

    Handles both foreground ("disparity") and background depth estimation; the
    set of images downloaded depends on the image type, whether foreground
    masks are used, and whether a coarser pyramid level already exists.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running depth estimation...")
    ran_download = False
    msg_cp = copy(msg)
    if msg["image_type"] == "disparity":
        image_types_to_level = [("color", msg["level_start"])]
        if msg["use_foreground_masks"]:
            ran_download |= download_image_type(
                msg, "background_disp", [msg["background_frame"]], msg["level_start"]
            )
            image_types_to_level.append(("foreground_masks", msg["level_start"]))
        # level_start below the coarsest level means a previous pass already ran;
        # pull its (coarser) results to seed this level
        if msg["level_start"] < msg["num_levels"] - 1:
            image_types_to_level.append(("disparity", msg["level_start"] + 1))
            if msg["use_foreground_masks"]:
                image_types_to_level.append(
                    ("foreground_masks", msg["level_start"] + 1)
                )
    else:
        # Background estimation works off background color frames
        image_types_to_level = [("background_color", msg["level_start"])]
        if msg["level_start"] < msg["num_levels"] - 1:
            image_types_to_level.append(("background_disp", msg["level_start"] + 1))
        msg_cp["color"] = local_image_type_path(msg, "background_color_levels")
        # Background outputs go to a dedicated subtree under the input root
        msg_cp["output_root"] = os.path.join(msg["input_root"], "background")
    ran_download |= download_rig(msg)
    ran_download |= download_image_types(msg, image_types_to_level)
    _run_bin(msg_cp)
    ran_upload = upload_image_type(msg, msg["image_type"], level=msg["level_end"])
    _clean_worker(ran_download, ran_upload)
def temporal_filter_callback(msg):
    """Runs temporal filtering according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running temporal filtering...")
    filter_frames = get_frame_range(msg["filter_first"], msg["filter_last"])
    wanted_types = [("color", msg["level"]), ("disparity", msg["level"])]
    if msg["use_foreground_masks"]:
        wanted_types.append(("foreground_masks", msg["level"]))
    ran_download = download_rig(msg)
    # If such frames do not exist, S3 simply does not download them
    ran_download |= download_image_types(msg, wanted_types, filter_frames)
    bin_msg = copy(msg)
    bin_msg["disparity"] = ""  # disparity_level is automatically populated by app
    _run_bin(bin_msg)
    ran_upload = upload_image_type(
        msg,
        "disparity_time_filtered",
        get_frame_range(msg["first"], msg["last"]),
        level=msg["level"],
    )
    _clean_worker(ran_download, ran_upload)
def transfer_callback(msg):
    """Runs transfer according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running rearranging...")
    copy_image_level(
        msg,
        msg["src_image_type"],
        msg["dst_image_type"],
        get_cameras(msg, "cameras"),
        get_frame_range(msg["first"], msg["last"]),
        msg["src_level"],
        msg["dst_level"],
    )
def _run_upsample(msg, run_upload=True):
    """Runs disparity upsampling according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
        run_upload (bool, optional): Whether to upload results after the run.

    Returns:
        tuple(bool, bool): Respectively whether or not a download and upload were performed.
    """
    ran_download = False
    image_types_to_level = [(msg["image_type"], msg["level"])]
    msg_cp = copy(msg)
    if msg["image_type"] == "disparity":
        color_image_type = "color"
        image_types_to_level += [
            ("foreground_masks", msg["level"]),
            ("foreground_masks", msg["dst_level"]),
        ]
        msg_cp["foreground_masks_in"] = local_image_type_path(
            msg, "foreground_masks", msg["level"]
        )
        msg_cp["foreground_masks_out"] = local_image_type_path(
            msg, "foreground_masks", msg["dst_level"]
        )
        # BUG FIX: track this download so cleanup knows files were fetched
        # (previously the return value was discarded); also removed a duplicated
        # msg_cp["background_disp"] assignment
        ran_download |= download_image_type(
            msg, "background_disp", [msg["background_frame"]], msg["dst_level"]
        )
        msg_cp["background_disp"] = local_image_type_path(
            msg, "background_disp", msg["dst_level"]
        )
    elif msg["image_type"] == "background_disp":
        color_image_type = "background_color"
        msg_cp["foreground_masks_in"] = ""  # Background upsampling doesn't use masks
        msg_cp["foreground_masks_out"] = ""
    image_types_to_level.append((color_image_type, msg["dst_level"]))
    ran_download |= download_image_types(msg, image_types_to_level)
    ran_download |= download_rig(msg)
    msg_cp["disparity"] = local_image_type_path(msg, msg["image_type"], msg["level"])
    msg_cp["output"] = local_image_type_path(
        msg, config.type_to_upsample_type[msg["image_type"]]
    )
    msg_cp["color"] = local_image_type_path(msg, color_image_type, msg["dst_level"])
    _run_bin(msg_cp)
    # BUG FIX: ran_upload was previously unbound when run_upload=False, raising
    # UnboundLocalError at the return below
    ran_upload = False
    if run_upload:
        ran_upload = upload_image_type(
            msg, config.type_to_upsample_type[msg["image_type"]]
        )
    return ran_download, ran_upload
def upsample_disparity_callback(msg):
    """Runs disparity upsampling according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running disparity upsampling...")
    did_download, did_upload = _run_upsample(msg)
    _clean_worker(did_download, did_upload)
def upsample_layer_disparity_callback(msg):
    """Runs disparity upsampling and layering according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Running disparity upsampling and layering...")
    layer_msg = copy(msg)
    # NOTE(review): this assignment appears unused, since _run_upsample copies
    # the original msg rather than layer_msg — confirm intent
    layer_msg["app"] = "UpsampleDisparity"
    ran_download, _ = _run_upsample(msg, run_upload=False)
    upsampled_bg = config.type_to_upsample_type["background_disp"]
    ran_download |= download_image_type(msg, upsampled_bg, [msg["background_frame"]])
    # Second pass: layer the upsampled foreground over the background disparity
    layer_msg["app"] = "LayerDisparities"
    layer_msg["background_disp"] = local_image_type_path(msg, upsampled_bg)
    layer_msg["foreground_disp"] = local_image_type_path(
        msg, config.type_to_upsample_type["disparity"]
    )
    layer_msg["output"] = config.DOCKER_OUTPUT_ROOT
    _run_bin(layer_msg)
    ran_upload = upload_image_type(msg, "disparity")
    _clean_worker(ran_download, ran_upload)
def convert_to_binary_callback(msg):
    """Runs binary conversion according to parameters read from the message.

    Two modes:
      - run_conversion truthy: download color/disparity images, convert them to
        binary form, and upload the results (including the fused rig json).
      - otherwise: download existing binaries plus the fused rig json, run
        fusion, and upload the fused output.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Converting to binary...")
    msg_cp = copy(msg)
    ran_download = download_rig(msg)
    rig_json = os.path.basename(msg["rig"])
    ext_index = rig_json.index(".")
    # e.g. "rig.json" -> "rig_fused.json"
    fused_json = f"{rig_json[:ext_index]}_fused{rig_json[ext_index:]}"
    if msg["run_conversion"]:
        image_types_to_level = [
            (msg["color_type"], None),
            (msg["disparity_type"], msg["level"]),
        ]
        msg_cp["disparity"] = local_image_type_path(
            msg, msg["disparity_type"], msg["level"]
        )
        msg_cp["color"] = local_image_type_path(msg, msg["color_type"])
        msg_cp["fused"] = ""  # fusion is done independently from conversion
        ran_download |= download_image_types(msg, image_types_to_level)
        # If we only have color levels uploaded to S3, we fall back to level_0
        if len(os.listdir(msg_cp["color"])) == 0:
            # BUG FIX: use |= so earlier downloads still trigger cleanup
            ran_download |= download_image_types(msg, [(msg["color_type"], 0)])
            msg_cp["color"] = local_image_type_path(msg, msg["color_type"], 0)
    else:
        image_types_to_level = [("bin", None)]
        local_fused_dir = local_image_type_path(msg, "fused")
        # Paths are explicitly emptied to avoid path verifications
        msg_cp["color"] = ""
        msg_cp["disparity"] = ""
        msg_cp["foreground_masks"] = ""
        msg_cp["fused"] = local_fused_dir
        ran_download |= download_image_types(msg, image_types_to_level)
        ran_download |= download(
            src=os.path.join(remote_image_type_path(msg, "bin"), fused_json),
            dst=os.path.join(local_image_type_path(msg, "bin"), fused_json),
        )
    msg_cp["bin"] = local_image_type_path(msg, "bin")
    # BUG FIX: create the local bin directory the binary actually writes to
    # (previously used msg["bin"], the unmodified host-side path)
    os.makedirs(msg_cp["bin"], exist_ok=True)
    _run_bin(msg_cp)
    if msg["run_conversion"]:
        ran_upload = upload_image_type(msg, "bin")
        ran_upload |= upload(
            src=os.path.join(local_image_type_path(msg, "bin"), fused_json),
            dst=os.path.join(remote_image_type_path(msg, "bin"), fused_json),
        )
    else:
        # We use a raw upload since upload_image_type only handles frames but we want to
        # also upload the fused json here
        ran_upload = upload(
            src=local_image_type_path(msg, "fused"),
            dst=remote_image_type_path(msg, "fused"),
            filters=["*"],
        )
    _clean_worker(ran_download, ran_upload)
def simple_mesh_renderer_callback(msg):
    """Runs SimpleMeshRenderer exports according to parameters read from the message.

    Args:
        msg (dict[str, str]): Message received from RabbitMQ publisher.
    """
    print("Generating exports...")
    msg_cp = copy(msg)
    frames = get_frame_range(msg_cp["first"], msg_cp["last"])
    ran_download = download_rig(msg)
    # BUG FIX: use |= so the rig download above is still reflected in cleanup
    # (previously this assignment discarded it)
    ran_download |= download_image_type(msg, msg_cp["color_type"], frames)
    ran_download |= download_image_type(msg, msg_cp["disparity_type"], frames)
    msg_cp["color"] = local_image_type_path(msg, msg_cp["color_type"])
    msg_cp["disparity"] = local_image_type_path(msg, msg_cp["disparity_type"])
    msg_cp["output"] = local_image_type_path(msg, msg_cp["dst_image_type"])
    # Fixed camera pose for the export; values are quoted for CLI parsing
    msg_cp["position"] = '"0.0 0.0 0.0"'
    msg_cp["forward"] = '"-1.0 0.0 0.0"'
    msg_cp["up"] = '"0.0 0.0 1.0"'
    _run_bin(msg_cp)
    ran_upload = upload_image_type(msg, msg_cp["dst_image_type"], frames)
    _clean_worker(ran_download, ran_upload)
def success(channel, delivery_tag):
    """Acks the completed job and publishes a completion notice on the response queue.

    No-op when the channel has already been closed.
    """
    if not channel.is_open:
        return
    channel.basic_ack(delivery_tag)
    channel.queue_declare(config.RESPONSE_QUEUE_NAME)
    channel.basic_publish(
        exchange="", routing_key=config.RESPONSE_QUEUE_NAME, body="Completed!"
    )
def failure(channel, delivery_tag, msg):
    """Rejects the failed job and re-publishes it to the work queue for a retry.

    No-op when the channel has already been closed.
    """
    if not channel.is_open:
        return
    channel.basic_reject(delivery_tag)
    channel.queue_declare(config.QUEUE_NAME)
    channel.basic_publish(
        exchange="",
        routing_key=config.QUEUE_NAME,
        body=json.dumps(msg),
        properties=pika.BasicProperties(delivery_mode=2),  # make message persistent
    )
def handle_message(connection, channel, delivery_tag, body):
    """Dispatches a queue message to its app callback and reports the outcome.

    Runs on a worker thread; the ack/reject must execute on the connection's
    I/O thread, hence add_callback_threadsafe.

    Args:
        connection (pika.BlockingConnection): Connection that owns the channel.
        channel (pika.channel.Channel): Channel the message arrived on.
        delivery_tag (int): Delivery tag used to ack/reject the message.
        body (bytes): utf-8 encoded JSON message published to the queue.
    """
    msg = json.loads(body.decode("utf-8"))
    try:
        print(f"Received {msg}")
        app_name_to_callback = {
            "GenerateForegroundMasks": generate_foreground_masks_callback,
            "DerpCLI": depth_estimation_callback,
            "TemporalBilateralFilter": temporal_filter_callback,
            "Transfer": transfer_callback,
            "UpsampleDisparity": upsample_disparity_callback,
            "UpsampleLayer": upsample_layer_disparity_callback,
            "ConvertToBinary": convert_to_binary_callback,
            "SimpleMeshRenderer": simple_mesh_renderer_callback,
            "Resize": resize_images_callback,
        }
        for app_name, app_callback in app_name_to_callback.items():
            if msg["app"].startswith(app_name):
                app_callback(msg)
                break
        # Sends response of job completion
        success_callback = functools.partial(success, channel, delivery_tag)
        connection.add_callback_threadsafe(success_callback)
    except Exception:
        # BUG FIX: previously the exception was swallowed silently; surface it
        # before re-queueing the message for retry
        traceback.print_exc()
        failure_callback = functools.partial(failure, channel, delivery_tag, msg)
        connection.add_callback_threadsafe(failure_callback)
def callback(ch, method, properties, body, connection):
    """Dispatches to different callbacks based on the contents of the message.

    Args:
        Only the body argument is necessary. The others are imposed by Pika.
        ch (pika.Channel): N/a
        method (pika.spec.Basic): N/a
        properties (pika.spec.BasicProperties): N/a
        body (bytes): utf-8 encoded message published to the message queue.
    """
    # Process off the I/O thread so the connection keeps servicing heartbeats
    worker = threading.Thread(
        target=handle_message,
        args=(connection, ch, method.delivery_tag, body),
    )
    worker.start()
def main_loop(argv):
    """Sets up the callback loop for the worker.

    Args:
        argv (list[str]): List of arguments (used interally by abseil).
    """
    while True:
        try:
            connection = pika.BlockingConnection(
                pika.ConnectionParameters(FLAGS.master)
            )
            channel = connection.channel()
            channel.queue_declare(queue=config.QUEUE_NAME)
            # Hand each worker at most one unacked message at a time
            channel.basic_qos(prefetch_count=1)
            channel.basic_consume(
                queue=config.QUEUE_NAME,
                auto_ack=False,
                on_message_callback=functools.partial(callback, connection=connection),
            )
            channel.start_consuming()
        except pika.exceptions.ConnectionClosedByBroker:
            # Don't recover if connection was closed by broker
            break
        except pika.exceptions.AMQPChannelError:
            # Don't recover on channel errors
            break
        except pika.exceptions.AMQPConnectionError:
            # Recover on all other connection errors
            continue
if __name__ == "__main__":
    # Abseil entry point app.run() expects all flags to be already defined
    # IP/hostname of the RabbitMQ master this worker subscribes to
    flags.DEFINE_string("master", None, "master IP")
    # Required FLAGS.
    flags.mark_flag_as_required("master")
    app.run(main_loop)