shap_e/models/transmitter/channels_encoder.py [636:667]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return gen()

    def encode_views(self, batch: AttrDict) -> torch.Tensor:
        """
        Embed each view as a set of patch tokens modulated by its camera embedding.

        :param batch: must provide ``views`` and ``cameras``; ``depths`` is also
            read when ``self.use_depth`` is set. Each is converted to a tensor by
            the matching ``*_to_tensor`` helper and moved to ``self.device``.
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            # Move depths to the same device as the views: torch.cat requires all
            # inputs on one device, and views/cameras are explicitly moved too.
            all_depths = self.depths_to_tensor(batch.depths).to(self.device)
            all_views = torch.cat([all_views, all_depths], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

        batch_size, num_views, _, _, _ = all_views.shape

        # Fold the view axis into the batch axis so patch_emb sees a flat batch
        # of images.
        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        # Assumes patch_emb returns [B*V, width, *spatial] — TODO confirm.
        # Flatten the spatial dims into a patch axis and put width last.
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # [batch_size, num_views, 1, 2 * width]
        camera_proj = self.camera_emb(all_cameras).reshape(
            [batch_size, num_views, 1, self.width * 2]
        )
        # Pose dropout (training only): with probability self.pose_dropout, zero
        # the camera embedding for all views of a sample. A zero embedding
        # makes the modulation below the identity. In eval the threshold is
        # 0.0, so the mask is always True.
        pose_dropout = self.pose_dropout if self.training else 0.0
        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
        # Per-view affine modulation of every patch token: the first half of
        # the camera embedding is a (residual) scale, the second half a shift.
        scale, shift = camera_proj.chunk(2, dim=3)
        views_proj = views_proj * (scale + 1.0) + shift
        return views_proj
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



shap_e/models/transmitter/pc_encoder.py [321:352]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return gen()

    def encode_views(self, batch: AttrDict) -> torch.Tensor:
        """
        Embed each view as a set of patch tokens modulated by its camera embedding.

        :param batch: must provide ``views`` and ``cameras``; ``depths`` is also
            read when ``self.use_depth`` is set. Each is converted to a tensor by
            the matching ``*_to_tensor`` helper and moved to ``self.device``.
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            # Move depths to the same device as the views: torch.cat requires all
            # inputs on one device, and views/cameras are explicitly moved too.
            all_depths = self.depths_to_tensor(batch.depths).to(self.device)
            all_views = torch.cat([all_views, all_depths], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

        batch_size, num_views, _, _, _ = all_views.shape

        # Fold the view axis into the batch axis so patch_emb sees a flat batch
        # of images.
        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        # Assumes patch_emb returns [B*V, width, *spatial] — TODO confirm.
        # Flatten the spatial dims into a patch axis and put width last.
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # [batch_size, num_views, 1, 2 * width]
        camera_proj = self.camera_emb(all_cameras).reshape(
            [batch_size, num_views, 1, self.width * 2]
        )
        # Pose dropout (training only): with probability self.pose_dropout, zero
        # the camera embedding for all views of a sample. A zero embedding
        # makes the modulation below the identity. In eval the threshold is
        # 0.0, so the mask is always True.
        pose_dropout = self.pose_dropout if self.training else 0.0
        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
        # Per-view affine modulation of every patch token: the first half of
        # the camera embedding is a (residual) scale, the second half a shift.
        scale, shift = camera_proj.chunk(2, dim=3)
        views_proj = views_proj * (scale + 1.0) + shift
        return views_proj
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



