def __init__()

in archs/models.py


    def __init__(self,
                 dset,
                 args,
                 num_layers=2,
                 num_modules_per_layer=3,
                 stoch_sample=False,
                 use_full_model=False,
                 num_classes=[2],
                 gater_type='general'):
        """TODO: to be defined1.

        :dset: TODO
        :args: TODO

        """
        CompositionalModel.__init__(self, dset, args)
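        # the base class is expected to store dset and args; they are read
        # below as self.dset and self.args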

        self.train_forward = self.train_forward_softmax
        self.compose_type = args.compose_type

        # per-primitive embedding size fed to the gater: 300-d GloVe vectors,
        # 512-d pretrained classifier weights, or 128-d learned embeddings
        gating_in_dim = 128
        if args.glove_init:
            gating_in_dim = 300
        elif args.clf_init:
            gating_in_dim = 512

        if self.compose_type == 'nn':
            # the gater input is the concatenated (attr, obj) embedding of
            # size tdim; inter_tdim sets its intermediate size
            tdim = gating_in_dim * 2
            inter_tdim = self.args.embed_rank
            # TODO: allow obj-only or attr-only gating inputs
            self.attr_embedder = nn.Embedding(
                len(dset.attrs) + 1,
                gating_in_dim,
                padding_idx=len(dset.attrs),
            )
            self.obj_embedder = nn.Embedding(
                len(dset.objs) + 1,
                gating_in_dim,
                padding_idx=len(dset.objs),
            )
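            # both embedders reserve an extra row (padding_idx) at index
            # len(vocab); the pretrained-weight copies below leave it at zero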

            # initialize the embedders with GloVe vectors or pretrained SVM weights
            if args.glove_init:
                pretrained_weight = utils.load_word_embeddings(
                    'data/glove/glove.6B.300d.txt', dset.attrs)
                self.attr_embedder.weight.data[:-1, :].copy_(pretrained_weight)
                pretrained_weight = utils.load_word_embeddings(
                    'data/glove/glove.6B.300d.txt', dset.objs)
                self.obj_embedder.weight.data[:-1, :].copy_(pretrained_weight)
            elif args.clf_init:
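                # seed each embedding row with the weight vector of the
                # corresponding pretrained linear SVM classifier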
                for idx, attr in enumerate(dset.attrs):
                    at_id = self.dset.attr2idx[attr]
                    weight = torch.load(
                        '%s/svm/attr_%d' % (args.data_dir,
                                            at_id)).coef_.squeeze()
                    self.attr_embedder.weight.data[idx].copy_(
                        torch.from_numpy(weight))
                for idx, obj in enumerate(dset.objs):
                    obj_id = self.dset.obj2idx[obj]
                    weight = torch.load(
                        '%s/svm/obj_%d' % (args.data_dir,
                                           obj_id)).coef_.squeeze()
                    self.obj_embedder.weight.data[idx].copy_(
                        torch.from_numpy(weight))
            else:
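                # fallback: one-hot attribute embeddings and GloVe object
                # embeddings (replaces the embedders created above)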
                n_attr = len(dset.predicates)
                gating_in_dim = 300
                tdim = gating_in_dim * 2 + n_attr
                self.attr_embedder = nn.Embedding(
                    n_attr,
                    n_attr,
                )
                self.attr_embedder.weight.data.copy_(
                    torch.from_numpy(np.eye(n_attr)))
                self.obj_embedder = nn.Embedding(
                    len(dset.objs) + 1,
                    gating_in_dim,
                    padding_idx=len(dset.objs),
                )
                pretrained_weight = utils.load_word_embeddings(
                    'data/glove/glove.6B.300d.txt', dset.objs)
                self.obj_embedder.weight.data[:-1, :].copy_(pretrained_weight)
        else:
            raise NotImplementedError

        self.comp_network, self.gating_network, self.nummods, _ = modular_general(
            num_layers=num_layers,
            num_modules_per_layer=num_modules_per_layer,
            feat_dim=dset.feat_dim,
            inter_dim=args.emb_dim,
            stoch_sample=stoch_sample,
            use_full_model=use_full_model,
            tdim=tdim,
            inter_tdim=inter_tdim,
            gater_type=gater_type,
        )
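        # modular_general builds the modular feature network (comp_network)
        # and the gating network conditioned on a tdim-dimensional input;
        # nummods is the number of modules (the last return value is unused)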

        if args.static_inp:
            for param in self.attr_embedder.parameters():
                param.requires_grad = False
            for param in self.obj_embedder.parameters():
                param.requires_grad = False
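
A minimal construction sketch follows. It only exercises the arguments read in this __init__; the class name, dataset object and data path are assumptions, not values taken from the repository.

    from argparse import Namespace

    # fields consumed by this constructor (values here are illustrative)
    args = Namespace(
        compose_type='nn',          # only the 'nn' path is implemented above
        glove_init=True,            # 300-d GloVe vectors for both embedders
        clf_init=False,             # alternatively, seed from pretrained SVMs
        embed_rank=64,              # inter_tdim of the gating network
        emb_dim=512,                # inter_dim of the modular network
        data_dir='data/mitstates',  # hypothetical; holds svm/attr_* and svm/obj_* when clf_init
        static_inp=False,           # freeze both embedders when True
    )

    # dset must come from the repository's dataset loader and expose
    # attrs, objs, predicates, feat_dim, attr2idx and obj2idx.
    # model = ModularModel(dset, args, num_layers=2,       # class name is
    #                      num_modules_per_layer=3)        # hypothetical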