# point_e/models/sdf.py
from abc import abstractmethod
from typing import Dict, Optional
import torch
import torch.nn as nn
from .perceiver import SimplePerceiver
from .transformer import Transformer
class PointCloudSDFModel(nn.Module):
@property
@abstractmethod
def device(self) -> torch.device:
"""
Get the device that should be used for input tensors.
"""
@property
@abstractmethod
def default_batch_size(self) -> int:
"""
Get a reasonable default number of query points for the model.
In some cases, this might be the only supported size.
"""
@abstractmethod
def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Encode a batch of point clouds to cache part of the SDF calculation
done by forward().
:param point_clouds: a batch of [batch x 3 x N] points.
:return: a state representing the encoded point cloud batch.
"""
def forward(
self,
x: torch.Tensor,
point_clouds: Optional[torch.Tensor] = None,
encoded: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
"""
Predict the SDF at the coordinates x, given a batch of point clouds.
Either point_clouds or encoded should be passed. Only exactly one of
these arguments should be None.
:param x: a [batch x 3 x N'] tensor of query points.
:param point_clouds: a [batch x 3 x N] batch of point clouds.
:param encoded: the result of calling encode_point_clouds().
:return: a [batch x N'] tensor of SDF predictions.
"""
assert point_clouds is not None or encoded is not None
assert point_clouds is None or encoded is None
if point_clouds is not None:
encoded = self.encode_point_clouds(point_clouds)
return self.predict_sdf(x, encoded)
@abstractmethod
def predict_sdf(
self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]]
) -> torch.Tensor:
"""
Predict the SDF at the query points given the encoded point clouds.
Each query point should be treated independently, only conditioning on
the point clouds themselves.
"""
class CrossAttentionPointCloudSDFModel(PointCloudSDFModel):
    """
    Encode point clouds using a transformer, and query points using cross
    attention to the encoded latents.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 4096,
        width: int = 512,
        encoder_layers: int = 12,
        encoder_heads: int = 8,
        decoder_layers: int = 4,
        decoder_heads: int = 8,
        init_scale: float = 0.25,
    ):
        """
        :param device: device for all parameters and expected inputs.
        :param dtype: parameter dtype.
        :param n_ctx: number of points per input point cloud (encoder context).
        :param width: transformer hidden width shared by encoder and decoder.
        :param encoder_layers: self-attention layers in the point-cloud encoder.
        :param encoder_heads: attention heads per encoder layer.
        :param decoder_layers: cross-attention layers in the query decoder.
        :param decoder_heads: attention heads per decoder layer.
        :param init_scale: initialization scale passed through to both stacks.
        """
        super().__init__()
        self._device = device
        self.n_ctx = n_ctx
        # Lift raw (x, y, z) coordinates into the transformer width.
        self.encoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
        self.encoder = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            layers=encoder_layers,
            heads=encoder_heads,
            init_scale=init_scale,
        )
        self.decoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
        # Perceiver-style decoder: queries cross-attend to the n_ctx latents.
        self.decoder = SimplePerceiver(
            device=device,
            dtype=dtype,
            n_data=n_ctx,
            width=width,
            layers=decoder_layers,
            heads=decoder_heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        # Per-query scalar SDF head.
        self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def default_batch_size(self) -> int:
        # Fixed: previously returned self.n_query, which is never assigned
        # anywhere in this class and raised AttributeError on access. The
        # encoder context size is the natural default query count.
        return self.n_ctx

    def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the encoder over a batch of point clouds.

        :param point_clouds: a [batch x 3 x N] tensor; permuted to
            [batch x N x 3] so each point is a token.
        :return: dict with "latents", a [batch x N x width] tensor.
        """
        h = self.encoder_input_proj(point_clouds.permute(0, 2, 1))
        h = self.encoder(h)
        return dict(latents=h)

    def predict_sdf(
        self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]]
    ) -> torch.Tensor:
        """
        Predict SDF values for query points by cross-attending to the
        cached point-cloud latents.

        :param x: a [batch x 3 x N'] tensor of query coordinates.
        :param encoded: output of encode_point_clouds(); must contain "latents".
        :return: a [batch x N'] tensor of SDF predictions.
        """
        data = encoded["latents"]
        x = self.decoder_input_proj(x.permute(0, 2, 1))
        x = self.decoder(x, data)
        x = self.ln_post(x)
        x = self.output_proj(x)
        # Drop the trailing singleton channel: [batch x N' x 1] -> [batch x N'].
        return x[..., 0]