in occant_baselines/models/occant.py [0:0]
def _create_gp_models(self):
    """Instantiate all ``gp_*`` sub-modules from ``self.config.GP_ANTICIPATION``.

    Side effects on ``self``:
        * registers ``gp_rgb_encoder``, ``gp_rgb_projector``, ``gp_rgb_unet``,
          ``gp_depth_proj_estimator``, ``gp_depth_proj_encoder``,
          ``gp_merge_x5``, ``gp_merge_x4``, ``gp_merge_x3`` and ``gp_decoder``;
        * stores ``GP_ANTICIPATION.detach_depth_proj`` as
          ``self._detach_depth_proj``;
        * if ``pretrained_depth_proj_model`` is a non-empty string, calls
          ``self._load_pretrained_model`` with it;
        * if requested by the config, sets ``requires_grad = False`` on the
          parameters of the RGB encoder and/or the depth-projection estimator
          (after any pretrained weights have been loaded).

    NOTE(review): modules are created in a fixed order on purpose; reordering
    may change parameter registration order / random initialisation.
    """
    # Each MergeMultimodal receives this many modes -- presumably one for the
    # RGB branch and one for the depth branch built below; confirm in forward().
    num_modes = 2
    anticipation_cfg = self.config.GP_ANTICIPATION

    # ---- Derived sizes ------------------------------------------------------
    # Configs lacking ``resnet_type`` default to "resnet50".
    backbone = getattr(anticipation_cfg, "resnet_type", "resnet50")
    # 768 channels for resnet50; every other backbone name maps to 192.
    rgb_feat_channels = 192
    if backbone == "resnet50":
        rgb_feat_channels = 768
    base_filters = anticipation_cfg.unet_nsf
    bottleneck_channels = base_filters * 8

    # ---- RGB encoder branch -------------------------------------------------
    self.gp_rgb_encoder = ResNetRGBEncoder(backbone)
    self.gp_rgb_projector = LearnedRGBProjection(
        mtype="upsample", infeats=rgb_feat_channels
    )
    self.gp_rgb_unet = MiniUNetEncoder(rgb_feat_channels, bottleneck_channels)

    # ---- Depth projection estimator (given its own clone of the config) -----
    self.gp_depth_proj_estimator = ANSRGB(self.config.clone())

    # ---- Depth encoder branch -----------------------------------------------
    self.gp_depth_proj_encoder = UNetEncoder(2, nsf=base_filters)

    # ---- Merge modules ------------------------------------------------------
    # x3 uses half the channel width of x5/x4.
    self.gp_merge_x5 = MergeMultimodal(bottleneck_channels, nmodes=num_modes)
    self.gp_merge_x4 = MergeMultimodal(bottleneck_channels, nmodes=num_modes)
    self.gp_merge_x3 = MergeMultimodal(bottleneck_channels // 2, nmodes=num_modes)

    # ---- Decoder ------------------------------------------------------------
    self.gp_decoder = UNetDecoder(anticipation_cfg.nclasses, nsf=base_filters)
    self._detach_depth_proj = anticipation_cfg.detach_depth_proj

    # ---- Optional pretrained weights (empty string means "none") ------------
    pretrained_path = anticipation_cfg.pretrained_depth_proj_model
    if pretrained_path != "":
        self._load_pretrained_model(pretrained_path)

    # ---- Optional freezing --------------------------------------------------
    def _freeze(module):
        # Exclude every parameter of ``module`` from gradient computation.
        for param in module.parameters():
            param.requires_grad = False

    if anticipation_cfg.freeze_features:
        _freeze(self.gp_rgb_encoder)
    if anticipation_cfg.freeze_depth_proj_model:
        _freeze(self.gp_depth_proj_estimator)