in shap_e/models/transmitter/multiview_encoder.py [0:0]
def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
    """Encode a batch of posed multiview images into one flat latent vector per example.

    :param batch: must provide ``views`` and ``cameras`` (and ``depths`` when
        ``self.use_depth`` is set), each convertible by the corresponding
        ``*_to_tensor`` helper.
    :param options: unused; accepted for interface compatibility.
    :return: a 2-D tensor ``[batch_size, d_latent]`` obtained by projecting the
        transformer's trailing output tokens and flattening them per example.
    """
    _ = options

    views = self.views_to_tensor(batch.views).to(self.device)
    if self.use_depth:
        # Depth maps are concatenated along the per-view channel axis (dim=2).
        # NOTE(review): unlike views/cameras there is no explicit .to(self.device)
        # here — presumably depths_to_tensor already returns a device tensor; verify.
        views = torch.cat([views, self.depths_to_tensor(batch.depths)], dim=2)
    cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

    # Unpack asserts the expected 5-D layout [batch, views, channels, H, W].
    batch_size, num_views, _, _, _ = views.shape

    # Patchify each view independently by folding the view axis into the batch axis.
    patches = self.patch_emb(views.flatten(0, 1))
    patches = (
        patches.reshape(batch_size, num_views, self.width, -1)
        .permute(0, 1, 3, 2)
        .contiguous()
    )  # [batch_size x num_views x n_patches x width]

    # One extra token per view carrying its camera-pose embedding.
    camera_tokens = self.camera_emb(cameras).reshape(batch_size, num_views, 1, self.width)

    # Append the camera token after each view's patches, then flatten all views
    # into a single sequence and add the learned positional embedding.
    tokens = torch.cat([patches, camera_tokens], dim=2).reshape(batch_size, -1, self.width)
    tokens = tokens + self.pos_emb

    # Tack learned output tokens onto the end of the sequence; they are read
    # back off after the backbone pass.
    out_queries = self.output_tokens[None].repeat(len(tokens), 1, 1)
    tokens = torch.cat([tokens, out_queries], dim=1)

    tokens = self.ln_pre(tokens)
    tokens = self.backbone(tokens)
    tokens = self.ln_post(tokens)

    # Keep only the appended output tokens (everything past the context length),
    # project them, and flatten to one vector per batch element.
    readout = tokens[:, self.n_ctx :]
    return self.output_proj(readout).flatten(1)