# Excerpt: __init__ method from agents/nets.py


    def __init__(self,
                 action_size,
                 action_layers=1,
                 action_hidden_size=256,
                 fusion_place='last',
                 embedding_dim=None,
                 repr_merging_method=None,
                 n_regressor_outputs=None,
                 embeddor_type=None,
                 backbone='resnet18'):
        """Build a ResNet backbone with optional FiLM action conditioning.

        Args:
            action_size: Dimensionality of the action vector consumed by
                each FiLM conditioning network.
            action_layers: Number of layers in each FiLM network.
            action_hidden_size: Hidden size of each FiLM network.
            fusion_place: Where action conditioning is fused into the
                backbone; one of 'first', 'last', 'all', 'none',
                'last_single'.
            embedding_dim: Output size of the optional embedding head.
                Required for embeddor_type 'linear'/'mlp'/'mlp_2hidden';
                ignored (overwritten with the backbone feature size) for
                'none'.
            repr_merging_method: How two embeddings are merged before the
                distance regressor: 'mul', 'concat', 'outer', or None
                (no regressor).
            n_regressor_outputs: Output size of the distance regressor;
                required when repr_merging_method is set.
            embeddor_type: 'linear', 'mlp', 'mlp_2hidden', 'none', or None
                (None disables the embedding head entirely).
            backbone: 'resnet18' or 'resnet50'.

        Raises:
            ValueError: On an unknown backbone, fusion_place, or
                embeddor_type.
        """
        super().__init__()
        if backbone == "resnet18":
            net = torchvision.models.resnet18(pretrained=False)
            self._resnet_embedding_dim = 512
            # Channel counts of the tensors *entering* layer1..layer4 —
            # FiLM conditions the input of a stage, not its output.
            n_channels = (64, 64, 128, 256)
        elif backbone == 'resnet50':
            net = torchvision.models.resnet50(pretrained=False)
            self._resnet_embedding_dim = 2048
            n_channels = (64, 256, 512, 1024)
        else:
            raise ValueError("Invalid backbone: " + backbone)
        # Replace the stock 3-channel stem conv: inputs here have
        # phyre.NUM_COLORS channels (one channel per scene color).
        conv1 = nn.Conv2d(phyre.NUM_COLORS,
                          64,
                          kernel_size=7,
                          stride=2,
                          padding=3,
                          bias=False)
        # Identity matrix buffer; presumably used to one-hot encode integer
        # color indices in the forward pass — confirm against forward().
        self.register_buffer('embed_weights', torch.eye(phyre.NUM_COLORS))
        self.stem = nn.Sequential(conv1, net.bn1, net.relu, net.maxpool)
        self.stages = nn.ModuleList(
            [net.layer1, net.layer2, net.layer3, net.layer4])

        def build_film(output_size):
            # One FiLM network conditioning `output_size` channels on the
            # action vector.
            return FilmActionNetwork(action_size,
                                     output_size,
                                     hidden_size=action_hidden_size,
                                     num_layers=action_layers)

        self.last_network = None
        if fusion_place == 'all':
            self.action_networks = nn.ModuleList(
                [build_film(size) for size in n_channels])
        elif fusion_place == 'last':
            # Save module as attribute so its parameters are registered.
            self._action_network = build_film(n_channels[3])
            self.action_networks = nn.ModuleList(
                [None, None, None, self._action_network])
        elif fusion_place == 'first':
            # Save module as attribute so its parameters are registered.
            self._action_network = build_film(n_channels[0])
            self.action_networks = nn.ModuleList(
                [self._action_network, None, None, None])
        elif fusion_place == 'last_single':
            # A single FiLM network on the final backbone features instead
            # of per-stage conditioning.
            self.last_network = build_film(self._resnet_embedding_dim)
            self.action_networks = [None, None, None, None]
        elif fusion_place == 'none':
            self.action_networks = [None, None, None, None]
        else:
            # Previously guarded by an `assert` (stripped under -O) with an
            # unreachable generic Exception; raise ValueError to match the
            # backbone validation above.
            raise ValueError('Unknown fusion place: %s' % fusion_place)
        # Linear head mapping backbone features to a single logit.
        self.reason = nn.Linear(self._resnet_embedding_dim, 1)

        if embeddor_type == "linear":
            self.embeddor = nn.Linear(self._resnet_embedding_dim, embedding_dim)
        elif embeddor_type == "mlp":
            self.embeddor = nn.Sequential(
                nn.Linear(self._resnet_embedding_dim, embedding_dim), nn.ReLU(),
                nn.Linear(embedding_dim, embedding_dim))
        elif embeddor_type == "mlp_2hidden":
            self.embeddor = nn.Sequential(
                nn.Linear(self._resnet_embedding_dim, embedding_dim), nn.ReLU(),
                nn.Linear(embedding_dim, embedding_dim), nn.ReLU(),
                nn.Linear(embedding_dim, embedding_dim))
        elif embeddor_type == "none":
            # nn.Identity instead of `lambda x: x`: picklable (so the whole
            # module can be torch.save'd) and adds no parameters, leaving
            # the state dict unchanged.
            self.embeddor = nn.Identity()
            embedding_dim = self._resnet_embedding_dim
        else:
            if embeddor_type is not None:
                raise ValueError(f"unknown embeddor type {embeddor_type}")
            self.embeddor = None

        self.repr_merging_method = repr_merging_method
        self.n_regressor_outputs = n_regressor_outputs
        if repr_merging_method == "mul":
            # Element-wise product keeps the embedding dimension.
            self.distance_regressor = nn.Linear(embedding_dim,
                                                n_regressor_outputs)
        elif repr_merging_method == "concat":
            # Concatenation doubles the input size.
            self.distance_regressor = nn.Linear(embedding_dim * 2,
                                                n_regressor_outputs)
        elif repr_merging_method == "outer":
            # Flattened outer product: embedding_dim ** 2 inputs.
            self.distance_regressor = nn.Linear(embedding_dim**2,
                                                n_regressor_outputs)
        else:
            self.distance_regressor = None