in point_e/models/transformer.py [0:0]
def __init__(
    self,
    *,
    device: torch.device,
    dtype: torch.dtype,
    input_channels: int = 3,   # channels per input point (3 = XYZ; 6 when colors are included)
    output_channels: int = 3,  # channels per output point
    n_ctx: int = 1024,         # number of points, i.e. tokens in the transformer's context
    width: int = 512,          # transformer embedding width
    layers: int = 12,          # number of transformer blocks
    heads: int = 8,            # attention heads per block
    init_scale: float = 0.25,  # scale factor for weight initialization
    time_token_cond: bool = False,  # if True, condition on the timestep via an extra prepended token rather than adding it to every input token