in threestudio/data/uncond.py [0:0]
def collate(self, batch) -> Dict[str, Any]:
    """Sample a random batch of cameras and build the corresponding rays.

    The ``batch`` argument is ignored; every quantity is drawn from the
    sampling ranges configured on ``self`` (``elevation_range``,
    ``azimuth_range``, ``camera_distance_range``, ``fovy_range`` and the
    perturbation settings in ``self.cfg``).

    Returns:
        Dict with per-batch tensors: ``rays_o``/``rays_d`` (camera rays),
        ``mvp_mtx``, ``c2w``, ``camera_positions``, ``light_positions``,
        sampled ``elevation``/``azimuth`` (in degrees) and
        ``camera_distances``, plus the current render ``height``/``width``,
        ``fovy`` (radians) and ``proj_mtx``.

    Raises:
        ValueError: if ``self.cfg.light_sample_strategy`` is unknown.
    """
    # sample elevation angles
    elevation_deg: Float[Tensor, "B"]
    elevation: Float[Tensor, "B"]
    if random.random() < 0.5:
        # with probability 0.5, sample elevation uniformly in degrees
        # (this is biased towards the poles relative to a uniform
        # distribution on the sphere)
        elevation_deg = (
            torch.rand(self.batch_size)
            * (self.elevation_range[1] - self.elevation_range[0])
            + self.elevation_range[0]
        )
        elevation = elevation_deg * math.pi / 180
    else:
        # otherwise sample uniformly on the sphere
        # NOTE(review): despite the name, these bounds are in radians,
        # not percent — kept for compatibility with the original code
        elevation_range_percent = [
            self.elevation_range[0] / 180.0 * math.pi,
            self.elevation_range[1] / 180.0 * math.pi,
        ]
        # inverse transform sampling: a uniform sample in
        # [sin(lo), sin(hi)] pushed through asin is uniform w.r.t.
        # surface area on the spherical band
        elevation = torch.asin(
            (
                torch.rand(self.batch_size)
                * (
                    math.sin(elevation_range_percent[1])
                    - math.sin(elevation_range_percent[0])
                )
                + math.sin(elevation_range_percent[0])
            )
        )
        elevation_deg = elevation / math.pi * 180.0
    # sample azimuth angles from a uniform distribution bounded by azimuth_range
    azimuth_deg: Float[Tensor, "B"]
    if self.cfg.batch_uniform_azimuth:
        # stratified sampling: one uniform sample per equal-width bin,
        # so the batch as a whole covers the full azimuth range
        azimuth_deg = (
            torch.rand(self.batch_size) + torch.arange(self.batch_size)
        ) / self.batch_size * (
            self.azimuth_range[1] - self.azimuth_range[0]
        ) + self.azimuth_range[
            0
        ]
    else:
        # simple random sampling
        azimuth_deg = (
            torch.rand(self.batch_size)
            * (self.azimuth_range[1] - self.azimuth_range[0])
            + self.azimuth_range[0]
        )
    azimuth = azimuth_deg * math.pi / 180
    # sample distances from a uniform distribution bounded by distance_range
    camera_distances: Float[Tensor, "B"] = (
        torch.rand(self.batch_size)
        * (self.camera_distance_range[1] - self.camera_distance_range[0])
        + self.camera_distance_range[0]
    )
    # convert spherical coordinates to cartesian coordinates
    # right hand coordinate system, x back, y right, z up
    # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
    camera_positions: Float[Tensor, "B 3"] = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )
    # default scene center at origin
    center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
    # default camera up direction as +z
    up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
        None, :
    ].repeat(self.batch_size, 1)
    # sample camera perturbations from a uniform distribution [-camera_perturb, camera_perturb]
    camera_perturb: Float[Tensor, "B 3"] = (
        torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb
        - self.cfg.camera_perturb
    )
    camera_positions = camera_positions + camera_perturb
    # sample center perturbations from a normal distribution with mean 0 and std center_perturb
    center_perturb: Float[Tensor, "B 3"] = (
        torch.randn(self.batch_size, 3) * self.cfg.center_perturb
    )
    center = center + center_perturb
    # sample up perturbations from a normal distribution with mean 0 and std up_perturb
    up_perturb: Float[Tensor, "B 3"] = (
        torch.randn(self.batch_size, 3) * self.cfg.up_perturb
    )
    up = up + up_perturb
    # sample fovs from a uniform distribution bounded by fov_range
    fovy_deg: Float[Tensor, "B"] = (
        torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0])
        + self.fovy_range[0]
    )
    fovy = fovy_deg * math.pi / 180
    # sample light distance from a uniform distribution bounded by light_distance_range
    light_distances: Float[Tensor, "B"] = (
        torch.rand(self.batch_size)
        * (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0])
        + self.cfg.light_distance_range[0]
    )
    if self.cfg.light_sample_strategy == "dreamfusion":
        # sample light direction from a normal distribution with mean
        # camera_position and std light_position_perturb
        light_direction: Float[Tensor, "B 3"] = F.normalize(
            camera_positions
            + torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb,
            dim=-1,
        )
        # get light position by scaling light direction by light distance
        light_positions: Float[Tensor, "B 3"] = (
            light_direction * light_distances[:, None]
        )
    elif self.cfg.light_sample_strategy == "magic3d":
        # sample light direction within restricted angle range (pi/3)
        # by building a local frame whose z axis points at the camera
        local_z = F.normalize(camera_positions, dim=-1)
        local_x = F.normalize(
            torch.stack(
                [local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])],
                dim=-1,
            ),
            dim=-1,
        )
        local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1)
        rot = torch.stack([local_x, local_y, local_z], dim=-1)
        light_azimuth = (
            torch.rand(self.batch_size) * math.pi * 2 - math.pi
        )  # [-pi, pi]
        light_elevation = (
            torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6
        )  # [pi/6, pi/2]
        # spherical -> cartesian in the local frame, then rotate to world
        light_positions_local = torch.stack(
            [
                light_distances
                * torch.cos(light_elevation)
                * torch.cos(light_azimuth),
                light_distances
                * torch.cos(light_elevation)
                * torch.sin(light_azimuth),
                light_distances * torch.sin(light_elevation),
            ],
            dim=-1,
        )
        light_positions = (rot @ light_positions_local[:, :, None])[:, :, 0]
    else:
        raise ValueError(
            f"Unknown light sample strategy: {self.cfg.light_sample_strategy}"
        )
    # build the camera frame (look-at): lookat points from camera to center
    lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
    # FIX: pass dim=-1 explicitly. Without it, torch.cross picks the FIRST
    # dimension of size 3, so for batch_size == 3 the cross product would
    # silently be taken along the batch dimension instead of the xyz axis.
    right: Float[Tensor, "B 3"] = F.normalize(
        torch.cross(lookat, up, dim=-1), dim=-1
    )
    # re-orthogonalize up against the (perturbed) lookat/right axes
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
    # camera-to-world: columns are [right, up, -lookat] plus the translation
    c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    c2w: Float[Tensor, "B 4 4"] = torch.cat(
        [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
    )
    c2w[:, 3, 3] = 1.0
    # get directions by dividing directions_unit_focal by focal length
    focal_length: Float[Tensor, "B"] = 0.5 * self.height / torch.tan(0.5 * fovy)
    directions: Float[Tensor, "B H W 3"] = self.directions_unit_focal[
        None, :, :, :
    ].repeat(self.batch_size, 1, 1, 1)
    directions[:, :, :, :2] = (
        directions[:, :, :, :2] / focal_length[:, None, None, None]
    )
    # Important note: the returned rays_d MUST be normalized!
    rays_o, rays_d = get_rays(
        directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize
    )
    self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
        fovy, self.width / self.height, 0.1, 1000.0
    )  # FIXME: hard-coded near and far
    mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx)
    self.fovy = fovy
    return {
        "rays_o": rays_o,
        "rays_d": rays_d,
        "mvp_mtx": mvp_mtx,
        "camera_positions": camera_positions,
        "c2w": c2w,
        "light_positions": light_positions,
        "elevation": elevation_deg,
        "azimuth": azimuth_deg,
        "camera_distances": camera_distances,
        "height": self.height,
        "width": self.width,
        "fovy": self.fovy,
        "proj_mtx": self.proj_mtx,
    }