in shap_e/models/transmitter/pc_encoder.py [0:0]
def encode_views(self, batch: AttrDict) -> torch.Tensor:
    """
    Embed every input view into patch tokens, modulated by its camera.

    Each view (with its depth map concatenated along dim 2 when
    ``self.use_depth`` is set) is passed through ``self.patch_emb``. The
    resulting tokens are then scaled and shifted per view by a
    ``(scale, shift)`` pair produced by ``self.camera_emb`` from that view's
    camera. While ``self.training`` is true, each batch element's camera
    conditioning is zeroed with probability ``self.pose_dropout``.

    :param batch: must provide ``views`` and ``cameras``, plus ``depths`` when
        ``self.use_depth`` is set.
    :return: [batch_size, num_views, n_patches, width]
    """
    all_views = self.views_to_tensor(batch.views).to(self.device)
    if self.use_depth:
        # Move depths to the model device just like views and cameras;
        # otherwise a depth tensor living on another device makes torch.cat
        # fail. `.to` is a no-op when the tensor is already on the device.
        all_depths = self.depths_to_tensor(batch.depths).to(self.device)
        all_views = torch.cat([all_views, all_depths], dim=2)
    all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)
    batch_size, num_views, _, _, _ = all_views.shape

    # Fold views into the batch axis so patch_emb sees one item per view.
    views_proj = self.patch_emb(
        all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
    )
    # Collapse everything after the width axis of patch_emb's output into a
    # single token axis, then move width last.
    views_proj = (
        views_proj.reshape([batch_size, num_views, self.width, -1])
        .permute(0, 1, 3, 2)
        .contiguous()
    )  # [batch_size x num_views x n_patches x width]

    # [batch_size, num_views, 1, 2 * width]
    camera_proj = self.camera_emb(all_cameras).reshape(
        [batch_size, num_views, 1, self.width * 2]
    )

    # Pose dropout: torch.rand is in [0, 1), so `rand >= p` is False with
    # probability p; those batch elements lose all camera conditioning.
    # Outside training p is 0 and the mask is always True.
    pose_dropout = self.pose_dropout if self.training else 0.0
    mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
    camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))

    # Scale-and-shift modulation; the +1.0 makes zeroed conditioning an
    # identity transform of the patch tokens.
    scale, shift = camera_proj.chunk(2, dim=3)
    views_proj = views_proj * (scale + 1.0) + shift
    return views_proj