in free_viewpoint_rendering.py [0:0]
def _setup_nonrigid_nerf_network(results_folder, checkpoint="latest"):
    """Load a trained non-rigid NeRF run and build convenience helpers around it.

    The code snapshot of the run is expected in ``<results_folder>/backup/``;
    the modules ``train``, ``run_nerf_helpers`` and ``load_llff`` are imported
    from there. That folder is appended to ``sys.path`` for the duration of
    this call and removed again afterwards -- also when setup raises, so a
    failed setup does not leave the folder on the import path.

    Args:
        results_folder: Root folder of a training run. ``logs/args.txt`` is
            parsed as the training config and ``logs/<checkpoint>.tar`` is
            loaded as the network weights.
        checkpoint: Checkpoint name inside ``logs/`` (without ``.tar``).

    Returns:
        Tuple ``(render_kwargs_train, render_kwargs_test, start, grad_vars,
        load_weights_into_network, checkpoint_dict,
        get_training_ray_bending_latents, load_llff_dataset, raw_render_path,
        render_convenient, convert_rgb_to_saveable,
        convert_disparity_to_saveable, convert_disparity_to_jet,
        convert_disparity_to_phong, store_ray_bending_mesh_visualization,
        to8b)``. ``start`` and ``grad_vars`` are passed through unchanged from
        ``train.create_nerf``.
    """
    import sys

    extra_sys_folder = os.path.join(results_folder, "backup/")
    sys.path.append(extra_sys_folder)
    try:
        from train import (
            config_parser,
            create_nerf,
            render_path,
            get_parallelized_render_function,
            _get_multi_view_helper_mappings,
        )

        args = config_parser().parse_args(
            ["--config", os.path.join(results_folder, "logs", "args.txt")]
        )
        print(args, flush=True)
        # The fifth return value (the optimizer) is not used here.
        render_kwargs_train, render_kwargs_test, start, grad_vars, _ = create_nerf(
            args, autodecoder_variables=None, ignore_optimizer=True
        )

        def load_weights_into_network(render_kwargs_train, checkpoint=None, path=None):
            """Load network weights from a checkpoint file into ``render_kwargs_train``.

            Exactly one source must be given: ``checkpoint`` (a name inside
            ``<results_folder>/logs/``, without ``.tar``) or ``path`` (an
            explicit file path). The coarse network is always loaded; the fine
            network and the ray bender only if they are not ``None``.

            Returns:
                The full checkpoint dictionary returned by ``torch.load``.

            Raises:
                RuntimeError: if both or neither of ``checkpoint`` and ``path``
                    are given.
            """
            if path is not None and checkpoint is not None:
                raise RuntimeError("trying to load weights from two sources")
            if checkpoint is not None:
                path = os.path.join(results_folder, "logs", checkpoint + ".tar")
            if path is None:
                raise RuntimeError("no checkpoint name or path to load weights from")
            checkpoint_dict = torch.load(path)
            render_kwargs_train["network_fn"].load_state_dict(
                checkpoint_dict["network_fn_state_dict"]
            )
            if render_kwargs_train["network_fine"] is not None:
                render_kwargs_train["network_fine"].load_state_dict(
                    checkpoint_dict["network_fine_state_dict"]
                )
            if render_kwargs_train["ray_bender"] is not None:
                render_kwargs_train["ray_bender"].load_state_dict(
                    checkpoint_dict["ray_bender_state_dict"]
                )
            return checkpoint_dict

        checkpoint_dict = load_weights_into_network(
            render_kwargs_train, checkpoint=checkpoint
        )

        def get_training_ray_bending_latents(checkpoint="latest"):
            """Return the ray-bending latent codes stored in a checkpoint of this run."""
            checkpoint_path = os.path.join(results_folder, "logs", checkpoint + ".tar")
            # shape: frames x latent_size
            return torch.load(checkpoint_path)["ray_bending_latent_codes"]

        from run_nerf_helpers import determine_nerf_volume_extent
        from load_llff import load_llff_data

        def load_llff_dataset(
            render_kwargs_train_=None,
            render_kwargs_test_=None,
            return_nerf_volume_extent=False,
        ):
            """Load the LLFF dataset configured in the run's ``args``.

            The near/far bounds derived from the data are written into
            ``render_kwargs_train_`` / ``render_kwargs_test_`` when given.
            If ``args.train_block_size`` and ``args.test_block_size`` are both
            positive, timesteps are split into alternating training and test
            blocks; otherwise the test/validation sets are empty.

            Args:
                render_kwargs_train_: Optional kwargs dict updated with near/far.
                render_kwargs_test_: Optional kwargs dict updated with near/far.
                return_nerf_volume_extent: If true, also estimate the NeRF volume
                    extent and store it in ``extras`` under
                    ``"min_nerf_volume_point"`` / ``"max_nerf_volume_point"``.

            Returns:
                ``(images, poses, all_rotations, all_translations, bds,
                render_poses, render_rotations, render_translations, i_train,
                i_val, i_test, near, far, extras)``.
            """
            datadir = args.datadir
            # The loader's own test indices are discarded; the split is computed below.
            images, poses, bds, render_poses, _ = load_llff_data(
                datadir,
                factor=args.factor,
                recenter=True,
                bd_factor=args.bd_factor,
                spherify=args.spherify,
            )
            extras = _get_multi_view_helper_mappings(images.shape[0], datadir)

            # poses
            poses = poses[:, :3, :4]  # N x 3 x 4
            all_rotations = poses[:, :3, :3]  # N x 3 x 3
            all_translations = poses[:, :3, 3]  # N x 3
            render_poses = render_poses[:, :3, :4]
            render_rotations = render_poses[:, :3, :3]
            render_translations = render_poses[:, :3, 3]

            # splits
            i_test = []
            if args.test_block_size > 0 and args.train_block_size > 0:
                print(
                    f"splitting timesteps into training ({args.train_block_size}) "
                    f"and test ({args.test_block_size}) blocks"
                )
                num_timesteps = len(extras["raw_timesteps"])
                block_period = args.train_block_size + args.test_block_size
                # Each period starts with train_block_size training timesteps
                # followed by test_block_size test timesteps.
                test_timesteps = np.concatenate(
                    [
                        np.arange(
                            min(num_timesteps, blocks_start + args.train_block_size),
                            min(num_timesteps, blocks_start + block_period),
                        )
                        for blocks_start in np.arange(0, num_timesteps, block_period)
                    ]
                )
                # Set membership keeps the split linear in the number of images.
                test_timestep_set = set(test_timesteps.tolist())
                i_test = [
                    imageid
                    for imageid, timestep in enumerate(extras["imageid_to_timestepid"])
                    if timestep in test_timestep_set
                ]
            i_test = np.array(i_test)
            i_val = i_test
            held_out = set(i_test.tolist()) | set(i_val.tolist())
            i_train = np.array(
                [i for i in np.arange(int(images.shape[0])) if i not in held_out]
            )

            # near, far: always taken from the data bounds (the NDC variant with
            # near=0 / far=1 is disabled in this code).
            near = np.ndarray.min(bds) * 0.9
            far = np.ndarray.max(bds) * 1.0
            bds_dict = {
                "near": near,
                "far": far,
            }
            if render_kwargs_train_ is not None:
                render_kwargs_train_.update(bds_dict)
            if render_kwargs_test_ is not None:
                render_kwargs_test_.update(bds_dict)

            if return_nerf_volume_extent:
                intrinsics = checkpoint_dict["intrinsics"]
                min_point, max_point = determine_nerf_volume_extent(
                    get_parallelized_render_function(),
                    poses,
                    [
                        intrinsics[extras["imageid_to_viewid"][imageid]]
                        for imageid in range(poses.shape[0])
                    ],
                    render_kwargs_test,
                    args,
                )
                extras["min_nerf_volume_point"] = min_point.detach()
                extras["max_nerf_volume_point"] = max_point.detach()
            extras["intrinsics"] = checkpoint_dict["intrinsics"]

            return (
                images,
                poses,
                all_rotations,
                all_translations,
                bds,
                render_poses,
                render_rotations,
                render_translations,
                i_train,
                i_val,
                i_test,
                near,
                far,
                extras,
            )

        raw_render_path = render_path

        def render_convenient(
            rotations=None,
            translations=None,
            poses=None,
            detailed_output=None,
            ray_bending_latents=None,
            render_factor=None,
            with_ray_bending=None,
            custom_checkpoint_dict=None,
            intrinsics=None,
            chunk=None,
            custom_ray_params=None,
            custom_render_kwargs_test=None,
            rigidity_test_time_cutoff=None,
            motion_factor=None,
            foreground_removal=None,
        ):
            """Render images for the given cameras with the loaded network.

            Cameras are given either as ``poses`` (N x 3 x 4; an N x 3 x 5 LLFF
            array with hwf in the last column is accepted and the hwf column is
            dropped) or as ``rotations`` (N x 3 x 3) plus ``translations``
            (N x 3, N x 3 x 1, N x 1 x 3 or 3), never both.

            Args:
                ray_bending_latents: List of latent codes or array of shape
                    N x latent_size, passed through to ``render_path``.
                intrinsics: List of dicts with entries height, width, center_x,
                    center_y, focal_x, focal_y. Defaults to the first camera's
                    intrinsics from the (custom) checkpoint for every pose.
                chunk: Chunk size for ``render_path``; defaults to ``args.chunk``.
                render_factor: Defaults to 0.
                detailed_output: If true, also return ``render_path``'s extra
                    outputs. Defaults to False.
                with_ray_bending: Defaults to True. If false, ray bending is
                    temporarily disabled on the networks for this call.
                custom_checkpoint_dict: Used instead of the loaded checkpoint
                    when looking up default intrinsics.
                custom_render_kwargs_test: Used instead of ``render_kwargs_test``.
                rigidity_test_time_cutoff: Forced background stabilization;
                    temporarily set on the ray bender (with ray bending only).
                motion_factor: Motion exaggeration/dampening; temporarily set as
                    the ray bender's ``test_time_scaling`` (with ray bending only).
                foreground_removal: Temporarily set as each network's
                    ``test_time_nonrigid_object_removal_threshold`` (with ray
                    bending only).
                custom_ray_params: Accepted for backward compatibility; unused.

            All temporary overrides are restored before returning, also if
            rendering raises.

            Returns:
                ``(rgbs, disps)`` or, with ``detailed_output``,
                ``(rgbs, disps, details_and_rest)``.

            Raises:
                RuntimeError: on an invalid combination or shape of pose inputs.
            """
            # poses
            if poses is None:
                if rotations is None or translations is None:
                    raise RuntimeError(
                        "either poses or both rotations and translations are required"
                    )
                rotations = torch.Tensor(rotations).reshape(-1, 3, 3)
                translations = torch.Tensor(translations).reshape(-1, 3, 1)
                # NOTE(review): these poses stay on the CPU while explicit poses
                # below are moved to CUDA -- confirm render_path handles both.
                poses = torch.cat([rotations, translations], -1)  # N x 3 x 4
            else:
                if rotations is not None or translations is not None:
                    raise RuntimeError(
                        "poses cannot be combined with rotations/translations"
                    )
                if len(poses.shape) > 3:
                    raise RuntimeError("poses must have at most 3 dimensions")
                if poses.shape[-1] == 5:
                    # Poses loaded by load_llff carry hwf in the last column,
                    # which is ignored later on anyway.
                    poses = poses[..., :4]
                poses = torch.Tensor(poses).cuda().reshape(-1, 3, 4)

            # other parameters/arguments
            checkpoint_dict_ = (
                checkpoint_dict if custom_checkpoint_dict is None else custom_checkpoint_dict
            )
            render_kwargs_test_ = (
                render_kwargs_test
                if custom_render_kwargs_test is None
                else custom_render_kwargs_test
            )
            if intrinsics is None:
                intrinsics = [
                    checkpoint_dict_["intrinsics"][0] for _ in range(len(poses))
                ]
            if chunk is None:
                chunk = args.chunk
            if render_factor is None:
                render_factor = 0
            if detailed_output is None:
                detailed_output = False
            if with_ray_bending is None:
                with_ray_bending = True

            coarse_model = render_kwargs_test_["network_fn"]
            # The fine network may be absent (None), cf. load_weights_into_network.
            fine_model = render_kwargs_test_.get("network_fine")
            networks = [net for net in (coarse_model, fine_model) if net is not None]
            backup_dict_ray_bender = render_kwargs_test_["ray_bender"]

            # Apply temporary overrides, remembering every value we touch.
            if with_ray_bending:
                bender = backup_dict_ray_bender
                backup_rigidity_test_time_cutoff = bender.rigidity_test_time_cutoff
                backup_test_time_scaling = bender.test_time_scaling
                backup_removal_thresholds = [
                    net.test_time_nonrigid_object_removal_threshold for net in networks
                ]
                # forced background stabilization
                bender.rigidity_test_time_cutoff = rigidity_test_time_cutoff
                # motion exaggeration/dampening
                bender.test_time_scaling = motion_factor
                # foreground removal
                for net in networks:
                    net.test_time_nonrigid_object_removal_threshold = foreground_removal
            else:
                backup_network_ray_benders = [net.ray_bender for net in networks]
                # The networks appear to hold their ray bender wrapped in a
                # 1-tuple (the original code unwraps it with [0]), hence (None,).
                for net in networks:
                    net.ray_bender = (None,)
                render_kwargs_test_["ray_bender"] = None

            try:
                parallel_render = get_parallelized_render_function(
                    coarse_model=coarse_model,
                    fine_model=fine_model,
                    ray_bender=render_kwargs_test_["ray_bender"],
                )
                with torch.no_grad():
                    returned_outputs = render_path(
                        poses,
                        intrinsics,
                        chunk,
                        render_kwargs_test_,
                        render_factor=render_factor,
                        detailed_output=detailed_output,
                        ray_bending_latents=ray_bending_latents,
                        parallelized_render_function=parallel_render,
                    )
            finally:
                # Undo every override so later calls see the original state.
                if with_ray_bending:
                    bender.rigidity_test_time_cutoff = backup_rigidity_test_time_cutoff
                    bender.test_time_scaling = backup_test_time_scaling
                    for net, threshold in zip(networks, backup_removal_thresholds):
                        net.test_time_nonrigid_object_removal_threshold = threshold
                else:
                    for net, net_ray_bender in zip(networks, backup_network_ray_benders):
                        net.ray_bender = net_ray_bender
                    render_kwargs_test_["ray_bender"] = backup_dict_ray_bender

            # N x height x width x 3, N x height x width. RGB values in [0,1]
            if detailed_output:
                rgbs, disps, details_and_rest = returned_outputs
                return rgbs, disps, details_and_rest
            rgbs, disps = returned_outputs
            return rgbs, disps

        from run_nerf_helpers import (
            to8b,
            visualize_disparity_with_jet_color_scheme,
            visualize_disparity_with_blinn_phong,
            visualize_ray_bending,
        )

        def _normalized_disparity(disparity, normalize):
            """Return ``disparity`` divided by its maximum if ``normalize``, else a copy.

            When the maximum is zero the map is returned as an unscaled copy
            instead of dividing by zero.
            """
            if normalize:
                max_disparity = np.max(disparity)
                if max_disparity != 0:
                    return disparity / max_disparity
            return disparity.copy()

        def convert_rgb_to_saveable(rgb):
            """Convert float RGB values in [0,1] to int values in [0,255] via ``to8b``."""
            return to8b(rgb)

        def convert_disparity_to_saveable(disparity, normalize=True):
            """Convert a single height x width disparity map to 8-bit grayscale.

            Can be saved via:
            imageio.imwrite(filename, convert_disparity_to_saveable(disparity))
            """
            # height x width. int values in [0,255].
            return to8b(_normalized_disparity(disparity, normalize))

        def convert_disparity_to_jet(disparity, normalize=True):
            """Colorize a disparity map with the jet color scheme (8-bit output)."""
            # height x width x 3. int values in [0,255].
            return to8b(
                visualize_disparity_with_jet_color_scheme(
                    _normalized_disparity(disparity, normalize)
                )
            )

        def convert_disparity_to_phong(disparity, normalize=True):
            """Shade a disparity map with Blinn-Phong shading (8-bit output)."""
            # height x width x 3. int values in [0,255].
            return to8b(
                visualize_disparity_with_blinn_phong(
                    _normalized_disparity(disparity, normalize)
                )
            )

        def store_ray_bending_mesh_visualization(
            initial_input_pts, input_pts, filename_prefix, subsampled_target=None
        ):
            """Forward to ``run_nerf_helpers.visualize_ray_bending``.

            initial_input_pts and input_pts: rays x samples_per_ray x 3.
            """
            return visualize_ray_bending(
                initial_input_pts,
                input_pts,
                filename_prefix,
                subsampled_target=subsampled_target,
            )

        return (
            render_kwargs_train,
            render_kwargs_test,
            start,
            grad_vars,
            load_weights_into_network,
            checkpoint_dict,
            get_training_ray_bending_latents,
            load_llff_dataset,
            raw_render_path,
            render_convenient,
            convert_rgb_to_saveable,
            convert_disparity_to_saveable,
            convert_disparity_to_jet,
            convert_disparity_to_phong,
            store_ray_bending_mesh_visualization,
            to8b,
        )
    finally:
        # Always undo the sys.path change, even if setup failed.
        sys.path.remove(extra_sys_folder)