def _create_gp_models()

in occant_baselines/models/occant.py [0:0]


    def _create_gp_models(self):
        """Instantiate the GP-anticipation sub-networks from ``config.GP_ANTICIPATION``.

        Builds, in this order:

        * the RGB branch: encoder, projector and mini U-Net encoder;
        * the depth-projection estimator (``ANSRGB`` built from a clone of
          ``self.config``) and its U-Net encoder;
        * three ``MergeMultimodal`` fusion modules;
        * the ``UNetDecoder``.

        Afterwards it optionally loads pretrained weights through
        ``self._load_pretrained_model``. It then turns off gradients for the
        RGB encoder and/or the depth-projection estimator according to the
        ``freeze_*`` flags.

        Config fields read: ``resnet_type`` (optional, defaults to
        ``"resnet50"``), ``unet_nsf``, ``nclasses``, ``detach_depth_proj``,
        ``pretrained_depth_proj_model`` (``""`` means "do not load"),
        ``freeze_features`` and ``freeze_depth_proj_model``.
        """
        cfg = self.config.GP_ANTICIPATION
        # Number of inputs fused by each MergeMultimodal -- presumably the RGB
        # branch and the depth-projection branch.
        num_modes = 2

        backbone = getattr(cfg, "resnet_type", "resnet50")
        # NOTE(review): RGB feature width is 768 for resnet50 and 192 for any
        # other backbone -- confirm this matches ResNetRGBEncoder's output.
        rgb_feats = 768 if backbone == "resnet50" else 192
        base_ch = cfg.unet_nsf
        deep_ch = base_ch * 8

        # NOTE: submodules are created in a fixed order. Reordering would change
        # their registration order (state_dict layout) and presumably the
        # random-initialisation draws, so keep the sequence below unchanged.

        # --- RGB branch ---
        self.gp_rgb_encoder = ResNetRGBEncoder(backbone)
        self.gp_rgb_projector = LearnedRGBProjection(
            mtype="upsample", infeats=rgb_feats
        )
        self.gp_rgb_unet = MiniUNetEncoder(rgb_feats, deep_ch)

        # --- Depth-projection branch ---
        # The estimator is given a cloned config, presumably so that anything
        # it changes does not leak back into self.config.
        self.gp_depth_proj_estimator = ANSRGB(self.config.clone())
        self.gp_depth_proj_encoder = UNetEncoder(2, nsf=base_ch)

        # --- Fusion at the x5 / x4 / x3 feature levels ---
        self.gp_merge_x5 = MergeMultimodal(deep_ch, nmodes=num_modes)
        self.gp_merge_x4 = MergeMultimodal(deep_ch, nmodes=num_modes)
        self.gp_merge_x3 = MergeMultimodal(deep_ch // 2, nmodes=num_modes)

        # --- Decoder ---
        self.gp_decoder = UNetDecoder(cfg.nclasses, nsf=base_ch)

        self._detach_depth_proj = cfg.detach_depth_proj

        ckpt_path = cfg.pretrained_depth_proj_model
        if ckpt_path != "":
            self._load_pretrained_model(ckpt_path)

        # Freezing is applied after the optional weight load, so whatever
        # weights were loaded are the ones kept fixed.
        frozen_modules = []
        if cfg.freeze_features:
            frozen_modules.append(self.gp_rgb_encoder)
        if cfg.freeze_depth_proj_model:
            frozen_modules.append(self.gp_depth_proj_estimator)
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False