in point_e/models/configs.py [0:0]
def model_from_config(
    config: Dict[str, Any],
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> nn.Module:
    """
    Instantiate a model from a config dictionary.

    :param config: a dict whose "name" entry selects the model class; every
                   other entry is forwarded to that class's constructor as a
                   keyword argument. The caller's dict is not modified.
    :param device: the device passed to the model constructor.
    :param dtype: the dtype passed to the model constructor. Defaults to
                  torch.float32, which was previously hard-coded.
    :return: the constructed model.
    :raises KeyError: if config has no "name" entry.
    :raises ValueError: if config["name"] is not a recognized model name.
    """
    # Shallow copy so popping "name" does not mutate the caller's dict.
    kwargs = config.copy()
    name = kwargs.pop("name")
    if name == "PointDiffusionTransformer":
        model_cls = PointDiffusionTransformer
    elif name == "CLIPImagePointDiffusionTransformer":
        model_cls = CLIPImagePointDiffusionTransformer
    elif name == "CLIPImageGridPointDiffusionTransformer":
        model_cls = CLIPImageGridPointDiffusionTransformer
    elif name == "UpsamplePointDiffusionTransformer":
        model_cls = UpsamplePointDiffusionTransformer
    elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
        model_cls = CLIPImageGridUpsamplePointDiffusionTransformer
    elif name == "CrossAttentionPointCloudSDFModel":
        model_cls = CrossAttentionPointCloudSDFModel
    else:
        raise ValueError(f"unknown model name: {name}")
    # Every supported class takes the same device/dtype keywords, so construct
    # once here rather than repeating the call in each branch.
    return model_cls(device=device, dtype=dtype, **kwargs)