def get_heads()

in domainbed_measures/measures/held_out_measures.py


    def get_heads(self,
                  num_batch_heads,
                  feat_dim,
                  num_labels,
                  criterion,
                  batch_size=256,
                  max_lr=0.0025,
                  lr_sweep_factor=0.5,
                  weight_decay_max=1e-1,
                  weight_decay_min=1e-4,
                  callbacks=None,
                  train_split=None):
        """Get heads for optimization.

        Set the maximum learning rate to start from for each head, and multiple
        heads then have a learning rate that is lr_sweep_factor^(i-1) * max_lr
        for the i'th head that we consider.
        """
        if callbacks is None:
            callbacks = []

        heads = []
        for hidx in range(num_batch_heads):
            lr = max_lr * lr_sweep_factor**hidx
            # Draw a weight decay uniformly at random from a log-spaced grid
            # spanning [weight_decay_max, weight_decay_min]; note that
            # torch.multinomial would return an index into the grid rather
            # than a value from it. (Assumes "import math" at module level.)
            wd_grid = torch.logspace(math.log10(weight_decay_max),
                                     math.log10(weight_decay_min), 5)
            weight_decay = float(wd_grid[torch.randint(len(wd_grid), (1,))])

            logging.info(
                f"Creating head {hidx} with lr {lr}, weight decay {weight_decay}"
            )
            if self._v_plus:
                # "V+" variant: a small MLP head that progressively shrinks
                # the feature dimension before the final classification layer.
                this_head = nn.Sequential(
                    nn.Linear(feat_dim, feat_dim // 2), nn.ReLU(),
                    nn.Linear(feat_dim // 2, feat_dim // 4), nn.ReLU(),
                    nn.Linear(feat_dim // 4, feat_dim // 4), nn.ReLU(),
                    nn.Linear(feat_dim // 4, num_labels))
            else:
                # Default variant: a single linear probe on the features.
                this_head = nn.Linear(feat_dim, num_labels)

            this_head = TrainEvalNeuralNet(
                module=this_head,
                device=self._device,
                criterion=criterion,
                optimizer=torch.optim.SGD,
                optimizer__lr=lr,
                optimizer__weight_decay=weight_decay,
                batch_size=batch_size,
                max_epochs=self._train_epochs,
                train_split=train_split,  # None means no validation, only training
                iterator_train__shuffle=True,
                callbacks=callbacks)

            heads.append(this_head)

        return heads
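
For context, a minimal sketch of how the returned heads might be driven, assuming TrainEvalNeuralNet keeps skorch's NeuralNet.fit(X, y) interface (consistent with the optimizer__* and iterator_train__* keyword style above). The measure object, feature tensor, and label count below are hypothetical placeholders, not part of the repo:

    import torch
    import torch.nn as nn

    feat_dim, num_labels = 512, 7           # hypothetical dimensions
    feats = torch.randn(1024, feat_dim)     # stand-in for cached held-out features
    labels = torch.randint(0, num_labels, (1024,))

    # "measure" stands in for an instance of the class defining get_heads().
    heads = measure.get_heads(num_batch_heads=4,
                              feat_dim=feat_dim,
                              num_labels=num_labels,
                              criterion=nn.CrossEntropyLoss)
    for head in heads:
        head.fit(feats, labels)  # skorch-style training on the features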