in occant_baselines/models/occant.py [0:0]
def __init__(self, cfg):
    """Build the mapper sub-network selected by ``cfg.type``.

    Args:
        cfg: model config; ``cfg.type`` selects the implementation
            (``"ans_rgb"``, ``"ans_depth"``, ``"occant_rgb"``,
            ``"occant_depth"``, ``"occant_rgbd"`` or
            ``"occant_ground_truth"``). Assumed to be a yacs-style
            ``CfgNode`` exposing ``defrost()`` / ``freeze()`` — TODO confirm.

    Raises:
        ValueError: if ``cfg.type`` is not one of the supported model types.
    """
    super().__init__()
    self.config = cfg
    model_type = cfg.type
    self._model_type = model_type
    # The config is unfrozen while the sub-model is built — presumably
    # because the sub-model constructors may write into cfg (TODO confirm).
    cfg.defrost()
    try:
        if model_type == "ans_rgb":
            self.main = ANSRGB(cfg)
        elif model_type == "ans_depth":
            self.main = ANSDepth(cfg)
        elif model_type == "occant_rgb":
            self.main = OccAntRGB(cfg)
        elif model_type == "occant_depth":
            self.main = OccAntDepth(cfg)
        elif model_type == "occant_rgbd":
            self.main = OccAntRGBD(cfg)
        elif model_type == "occant_ground_truth":
            self.main = OccAntGroundTruth(cfg)
        else:
            raise ValueError(f"Invalid model_type {model_type}")
    finally:
        # Re-freeze on every exit path: previously an invalid model_type or a
        # failing sub-model constructor left the caller's config mutable.
        cfg.freeze()