in habitat/datasets/rearrange/rearrange_generator.py [0:0]
def generate_single_episode(self) -> Optional[RearrangeEpisode]:
    """
    Generate a single episode, sampling the scene.

    Steps, in order:
      1. Reset the samplers and the per-episode bookkeeping stored in
         ``self.episode_data``, then set up the scene via ``generate_scene``.
      2. Sample articulated-object (AO) joint states with every AO state
         sampler and merge them per AO instance handle.
      3. Sample rigid-object placements with every object sampler.
      4. Run ``settle_sim`` to validate the placements.
      5. Build the target samplers and record one target transform per
         sampled target. The temporary target objects are removed from the
         scene once their transform has been cached.
      7. Package everything into a ``RearrangeEpisode``.

    :return: The generated episode, or ``None`` if ``settle_sim`` reports an
        unstable state after object placement.
    :raises AssertionError: If an AO state sampler returns ``None``, or if two
        target samplers produce a target for the same object instance. These
        are raised explicitly (not via ``assert``) so they still fire under
        ``python -O``.
    """
    self._reset_samplers()
    self.episode_data: Dict[str, Dict[str, Any]] = {
        # object sampler name -> list of sampled object instances
        "sampled_objects": {},
        # object instance handle -> target transform (np.ndarray)
        "sampled_targets": {},
    }
    ep_scene_handle = self.generate_scene()

    # Sample AO states for objects in the scene.
    # ao_instance_handle -> {link_ix: joint_state}
    ao_states: Dict[str, Dict[int, float]] = {}
    for sampler_name, ao_state_sampler in self._ao_state_samplers.items():
        sampler_states = ao_state_sampler.sample(self.sim)
        if sampler_states is None:
            # Explicit raise instead of `assert` so the check survives -O;
            # AssertionError is kept for compatibility with existing callers.
            raise AssertionError(f"AO sampler '{sampler_name}' failed")
        for sampled_instance, link_states in sampler_states.items():
            # Several samplers may touch the same AO: merge their link states.
            ao_states.setdefault(sampled_instance.handle, {}).update(
                link_states
            )

    # Visualize after setting AO states so the rendered scene reflects them.
    if self._render_debug_obs:
        self.visualize_scene_receptacles()
        self.vdb.make_debug_video(prefix="receptacles_")

    # Only hand the debug visualizer to samplers when debug rendering is on.
    # The same rule applies to object samplers and target samplers.
    sampler_vdb = self.vdb if self._render_debug_obs else None

    # Sample object placements.
    sampled_objects = self.episode_data["sampled_objects"]
    for sampler_name, obj_sampler in self._obj_samplers.items():
        new_objects = obj_sampler.sample(
            self.sim,
            snap_down=True,
            vdb=sampler_vdb,
        )
        # Extend a list owned by episode_data (rather than storing the
        # sampler's own list), so repeated names accumulate without mutating
        # what the sampler returned.
        sampled_objects.setdefault(sampler_name, []).extend(new_objects)
        self.ep_sampled_objects += new_objects
        logger.info(
            f"Sampler {sampler_name} generated {len(new_objects)} new object placements."
        )
        # Debug visualization showing each newly added object.
        if self._render_debug_obs:
            for new_object in new_objects:
                self.vdb.look_at(new_object.translation)
                self.vdb.get_observation()

    # Simulate the world for a few seconds to validate the placements.
    if not self.settle_sim():
        logger.warning("Aborting episode generation due to unstable state.")
        return None

    # Generate the target samplers.
    self._get_object_target_samplers()

    # object instance handle -> "<target sampler name>|<target sampler index>"
    target_refs: Dict[str, str] = {}
    sampled_targets = self.episode_data["sampled_targets"]
    # Loop-invariant: fetch the rigid object manager once.
    rom = self.sim.get_rigid_object_manager()
    debug_red = mn.Color4(1.0, 0.0, 0.0, 1.0)

    # Sample targets.
    for target_idx, (sampler_name, target_sampler) in enumerate(
        self._target_samplers.items()
    ):
        new_target_objects = target_sampler.sample(
            self.sim, snap_down=True, vdb=sampler_vdb
        )
        # Cache transforms and add visualizations.
        for instance_handle, target_object in new_target_objects.items():
            if instance_handle in sampled_targets:
                raise AssertionError(
                    f"Duplicate target for instance '{instance_handle}'."
                )
            # Read everything needed from the temporary target object before
            # it is removed from the scene below.
            target_bb_size = target_object.root_scene_node.cumulative_bb.size()
            target_transform = target_object.transformation
            sampled_targets[instance_handle] = np.array(target_transform)
            target_refs[instance_handle] = f"{sampler_name}|{target_idx}"
            rom.remove_object_by_handle(target_object.handle)
            if self._render_debug_obs:
                # Draw the target's bounding box, then a red line from the
                # target position to the object registered under
                # instance_handle.
                sutils.add_transformed_wire_box(
                    self.sim,
                    size=target_bb_size / 2.0,
                    transform=target_transform,
                )
                self.vdb.look_at(target_transform.translation)
                self.vdb.debug_line_render.set_line_width(2.0)
                self.vdb.debug_line_render.draw_transformed_line(
                    target_transform.translation,
                    rom.get_object_by_handle(instance_handle).translation,
                    debug_red,
                    debug_red,
                )
                self.vdb.get_observation()

    def _short_object_name(obj) -> str:
        # Strip everything up to the config directory, then keep the last
        # path component.
        attrs = obj.creation_attributes
        handle = attrs.handle
        if attrs.file_directory:
            # str.split("") raises ValueError, so only split on a non-empty
            # directory.
            handle = handle.split(attrs.file_directory)[-1]
        return handle.split("/")[-1]

    # Collect final object states and serialize the episode.
    # TODO: creating shortened names should be automated and embedded in the
    # objects to be done in a uniform way
    sampled_rigid_object_states = [
        (_short_object_name(x), np.array(x.transformation))
        for x in self.ep_sampled_objects
    ]

    # Episode ids are the zero-based count of previously generated episodes.
    episode_id = str(self.num_ep_generated)
    self.num_ep_generated += 1
    return RearrangeEpisode(
        scene_dataset_config=self.cfg.dataset_path,
        additional_obj_config_paths=self.cfg.additional_object_paths,
        episode_id=episode_id,
        start_position=[0, 0, 0],
        start_rotation=[0, 0, 0, 1],
        scene_id=ep_scene_handle,
        ao_states=ao_states,
        rigid_objs=sampled_rigid_object_states,
        targets=sampled_targets,
        markers=self.cfg.markers,
        info={"object_labels": target_refs},
    )