def _create_gp_models()

in occant_baselines/models/occant.py [0:0]


    def _create_gp_models(self):
        """Build ``self.main``: a ResNet-18 encoder followed by a conv decoder.

        The network is assembled as a single ``nn.Sequential`` consisting of:

        1. The convolutional trunk of a torchvision ResNet-18 created with
           ``pretrained=True`` (``conv1`` through ``layer4``; the average
           pool and classifier of the ResNet are not used).
        2. Two 1x1 conv + batch-norm + ReLU stages that stand in for fully
           connected layers.
        3. Four 3x3 conv + batch-norm + ReLU stages, each followed by a 2x
           bilinear upsampling, halving the channel count every stage
           (512 -> 256 -> 128 -> 64 -> 32).
        4. A final 3x3 conv to 2 channels and one more 2x bilinear
           upsampling.

        The shape annotations below assume a (3, 128, 128) input, for which
        the output is (2, 128, 128).
        """
        backbone = tmodels.resnet18(pretrained=True)

        # Feature extraction: (3, 128, 128) -> (512, 4, 4)
        layers = [
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        ]

        # FC layers equivalent: 1x1 convolutions keep the shape at (512, 4, 4)
        for _ in range(2):
            layers.append(nn.Conv2d(512, 512, 1))
            layers.append(nn.BatchNorm2d(512))
            layers.append(nn.ReLU())

        # Upsampling: every stage reduces the channels and doubles the
        # spatial resolution.
        # (512, 4, 4) -> (256, 8, 8) -> (128, 16, 16) -> (64, 32, 32)
        # -> (32, 64, 64)
        decoder_channels = ((512, 256), (256, 128), (128, 64), (64, 32))
        for in_channels, out_channels in decoder_channels:
            layers.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            layers.append(
                nn.Upsample(
                    scale_factor=2, mode="bilinear", align_corners=True
                )
            )

        # Output head, no normalisation or activation:
        # (32, 64, 64) -> (2, 64, 64) -> (2, 128, 128)
        layers.append(nn.Conv2d(32, 2, 3, padding=1))
        layers.append(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

        self.main = nn.Sequential(*layers)