in threestudio/data/images.py [0:0]
def setup(self, cfg, split):
    """Initialise the single-image dataset state for ``split``.

    Steps, in order:
      1. Optionally build a random-camera pose generator from
         ``cfg.random_camera``.
      2. Build the reference camera from the ``default_*`` config values:
         elevation, azimuth, camera distance and fovy.
      3. Resolve the per-stage heights/widths, resolution milestones, unit-focal
         ray directions and focal lengths.
      4. Call ``set_rays`` and ``load_images``.

    Args:
        cfg: The ``SingleImageDataModuleConfig`` for this data module.
        split: Dataset split name. ``"train"`` uses a
            ``RandomCameraIterableDataset`` for random poses. Any other value
            uses a ``RandomCameraDataset`` constructed with ``split``.
    """
    self.split = split
    self.rank = get_rank()
    self.cfg: SingleImageDataModuleConfig = cfg

    # Always define this attribute so it exists regardless of configuration.
    # It only becomes True below, when the mixed camera config pins this
    # rank's camera distance / fovy to the reference camera's values.
    self.fixed_camera_intrinsic = False

    if self.cfg.use_random_camera:
        random_camera_cfg = parse_structured(
            RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
        )
        # FIXME: mixed camera config.
        # Even ranks collapse the random camera's distance and fovy ranges to
        # the single default values. Odd ranks keep the ranges from
        # ``random_camera`` unchanged.
        if self.cfg.use_mixed_camera_config and self.rank % 2 == 0:
            default_distance = self.cfg.default_camera_distance
            default_fovy = self.cfg.default_fovy_deg
            random_camera_cfg.camera_distance_range = [
                default_distance,
                default_distance,
            ]
            random_camera_cfg.fovy_range = [default_fovy, default_fovy]
            self.fixed_camera_intrinsic = True

        if split == "train":
            self.random_pose_generator = RandomCameraIterableDataset(
                random_camera_cfg
            )
        else:
            self.random_pose_generator = RandomCameraDataset(
                random_camera_cfg, split
            )

    # Reference camera, specified in spherical coordinates (degrees).
    elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg])
    azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg])
    camera_distance = torch.FloatTensor([self.cfg.default_camera_distance])

    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180
    # Spherical -> Cartesian, with z as the elevation (up) axis.
    camera_position: Float[Tensor, "1 3"] = torch.stack(
        [
            camera_distance * torch.cos(elevation) * torch.cos(azimuth),
            camera_distance * torch.cos(elevation) * torch.sin(azimuth),
            camera_distance * torch.sin(elevation),
        ],
        dim=-1,
    )

    center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position)
    up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None]

    # The light is placed at the camera position.
    light_position: Float[Tensor, "1 3"] = camera_position

    # Orthonormal camera frame looking from the camera position at the origin.
    # ``dim=-1`` is passed explicitly: calling torch.cross without ``dim`` is
    # deprecated, and the implicit default (first size-3 dim) is fragile.
    lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1)
    right: Float[Tensor, "1 3"] = F.normalize(
        torch.cross(lookat, up, dim=-1), dim=-1
    )
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)

    # Camera-to-world: rotation columns are (right, up, -lookat), and the
    # last column is the camera position.
    self.c2w: Float[Tensor, "1 3 4"] = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
        dim=-1,
    )
    # Homogeneous 4x4 version: append a [0, 0, 0, 1] row.
    self.c2w4x4: Float[Tensor, "1 4 4"] = torch.cat(
        [self.c2w, torch.zeros_like(self.c2w[:, :1])], dim=1
    )
    self.c2w4x4[:, 3, 3] = 1.0

    self.camera_position = camera_position
    self.light_position = light_position
    self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
    self.camera_distance = camera_distance
    self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg]))

    # height/width may be a single int or a per-stage list.
    self.heights: List[int] = (
        [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
    )
    self.widths: List[int] = (
        [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
    )
    assert len(self.heights) == len(
        self.widths
    ), "height and width must list the same number of resolution stages"

    self.resolution_milestones: List[int]
    if len(self.heights) == 1 and len(self.widths) == 1:
        if len(self.cfg.resolution_milestones) > 0:
            threestudio.warn(
                "Ignoring resolution_milestones since height and width are not changing"
            )
        self.resolution_milestones = [-1]
    else:
        assert len(self.heights) == len(self.cfg.resolution_milestones) + 1, (
            "resolution_milestones must have exactly one fewer entry than "
            "the number of resolution stages"
        )
        self.resolution_milestones = [-1] + self.cfg.resolution_milestones

    # Per-stage ray directions at focal length 1, plus the matching focal
    # length (in pixels) derived from the vertical field of view.
    self.directions_unit_focals = [
        get_ray_directions(H=height, W=width, focal=1.0)
        for (height, width) in zip(self.heights, self.widths)
    ]
    self.focal_lengths = [
        0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights
    ]

    # Start at the first resolution stage.
    self.height: int = self.heights[0]
    self.width: int = self.widths[0]
    self.directions_unit_focal = self.directions_unit_focals[0]
    self.focal_length = self.focal_lengths[0]
    self.set_rays()
    self.load_images()
    self.prev_height = self.height