point_e/models/configs.py

from typing import Any, Dict

import torch
import torch.nn as nn

from .sdf import CrossAttentionPointCloudSDFModel
from .transformer import (
    CLIPImageGridPointDiffusionTransformer,
    CLIPImageGridUpsamplePointDiffusionTransformer,
    CLIPImagePointDiffusionTransformer,
    PointDiffusionTransformer,
    UpsamplePointDiffusionTransformer,
)

MODEL_CONFIGS = {
    "base40M-imagevec": {
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImagePointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "token_cond": True,
        "width": 512,
    },
    "base40M-textvec": {
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImagePointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "token_cond": True,
        "width": 512,
    },
    "base40M-uncond": {
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 1024,
        "name": "PointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 512,
    },
    "base40M": {
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 512,
    },
    "base300M": {
        "cond_drop_prob": 0.1,
        "heads": 16,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 24,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 1024,
    },
    "base1B": {
        "cond_drop_prob": 0.1,
        "heads": 32,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 24,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 2048,
    },
    "upsample": {
        "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
        "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
        "cond_ctx": 1024,
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 3072,
        "name": "CLIPImageGridUpsamplePointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 512,
    },
    "sdf": {
        "decoder_heads": 4,
        "decoder_layers": 4,
        "encoder_heads": 4,
        "encoder_layers": 8,
        "init_scale": 0.25,
        "n_ctx": 4096,
        "name": "CrossAttentionPointCloudSDFModel",
        "width": 256,
    },
}


def model_from_config(config: Dict[str, Any], device: torch.device) -> nn.Module:
    config = config.copy()
    name = config.pop("name")
    if name == "PointDiffusionTransformer":
        return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImagePointDiffusionTransformer":
        return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridPointDiffusionTransformer":
        return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "UpsamplePointDiffusionTransformer":
        return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
        return CLIPImageGridUpsamplePointDiffusionTransformer(
            device=device, dtype=torch.float32, **config
        )
    elif name == "CrossAttentionPointCloudSDFModel":
        return CrossAttentionPointCloudSDFModel(device=device, dtype=torch.float32, **config)
    raise ValueError(f"unknown model name: {name}")
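
# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): `model_from_config` pops the
# "name" key to pick the model class and forwards the remaining entries as
# constructor kwargs, so any MODEL_CONFIGS key above can be built the same
# way. "base40M" below is just one example; CLIP-conditioned models will
# typically load CLIP weights on first construction.
if __name__ == "__main__":
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_from_config(MODEL_CONFIGS["base40M"], dev)
    # Expected: CLIPImageGridPointDiffusionTransformer, per the config above.
    print(type(model).__name__)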