def _calculate_divergence_measure()

in domainbed_measures/measures/held_out_measures.py
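This method estimates an H∆H-style divergence between the training and held-out feature distributions: it balances and jointly shuffles the two feature sets, fits a sweep of binary domain-classifier heads on a train/validation chunk, and reports the best head's score on a held-back test chunk as the divergence. The excerpt assumes module-level imports of logging, numpy (as np), torch, and skorch, plus the repo's own helpers permute_dataset and NegHDelHCriterion.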


    def _calculate_divergence_measure(self,
                                      all_train_feats,
                                      train_domain_labels,
                                      all_held_out_feats,
                                      heldout_domain_labels,
                                      lr_decay_gamma,
                                      num_head_batches,
                                      max_lr,
                                      lr_sweep_factor,
                                      train_env_to_use,
                                      train_val_split,
                                      trainval_test_split=0.8):
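        """Estimate an H-delta-H style divergence between training and
        held-out features by fitting binary domain-classifier heads on a
        balanced, shuffled mix of the two splits.

        Returns the test-split divergence of the head with the best
        validation divergence, together with the best training divergence
        seen across all heads.
        """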

        # The domain classifier treats "training domain" as class 1 and
        # "held-out domain" as class 0; enforce that encoding up front.
        if (train_domain_labels != 1).any():
            raise ValueError(
                "Train domain labels must be encoded with label 1")

        if (heldout_domain_labels != 0).any():
            raise ValueError(
                "Held out domain labels must be encoded with label 0")

        feat_dim = all_train_feats.shape[-1]
        # Shuffle each split independently so the truncation below keeps a
        # random subsample rather than the first rows.
        all_train_feats, train_domain_labels = permute_dataset(
            all_train_feats, train_domain_labels)
        all_held_out_feats, heldout_domain_labels = permute_dataset(
            all_held_out_feats, heldout_domain_labels)

        # Truncate both splits to the smaller size so the domain classifier
        # sees a balanced 50/50 label mix.
        num_data = min(all_train_feats.shape[0], all_held_out_feats.shape[0])
        all_train_feats = all_train_feats[:num_data, :]
        train_domain_labels = train_domain_labels[:num_data]
        all_held_out_feats = all_held_out_feats[:num_data, :]
        heldout_domain_labels = heldout_domain_labels[:num_data]

        # Stack the balanced splits and shuffle them jointly before carving
        # out train/val/test chunks.
        all_feats = torch.vstack([all_train_feats, all_held_out_feats])
        all_labels = torch.hstack([train_domain_labels, heldout_domain_labels])
        all_feats, all_labels = permute_dataset(all_feats, all_labels)

        callbacks = [
            # Decay the learning rate by lr_decay_gamma every
            # _train_epochs // 2 epochs (StepLR expects an integer step size).
            skorch.callbacks.LRScheduler(
                torch.optim.lr_scheduler.StepLR,
                gamma=lr_decay_gamma,
                step_size=self._train_epochs // 2,
            ),
            # Log hdh_accuracy_fn (the divergence proxy) on the validation
            # and training splits after every epoch.
            skorch.callbacks.EpochScoring(
                self.hdh_accuracy_fn,
                lower_is_better=False,
                name='val_divergence',
            ),
            skorch.callbacks.EpochScoring(
                self.hdh_accuracy_fn,
                lower_is_better=False,
                name='train_divergence',
                on_train=True,
            ),
            # Stop training once validation divergence stops improving.
            skorch.callbacks.EarlyStopping(
                monitor='val_divergence',
                patience=15,
                threshold=0.0001,
                threshold_mode='rel',
                lower_is_better=False,
            )
        ]

        # Build the batch of domain-classifier heads; max_lr and
        # lr_sweep_factor configure the learning-rate sweep across heads, and
        # skorch splits off a validation set via CVSplit (renamed ValidSplit
        # in newer skorch releases).
        heads = self.get_hdh_heads(
            num_head_batches,
            feat_dim=feat_dim,
            criterion=NegHDelHCriterion,
            num_labels=self._num_classes,
            max_lr=max_lr,
            lr_sweep_factor=lr_sweep_factor,
            train_split=skorch.dataset.CVSplit(train_val_split),
            batch_size=self._algorithm.hparams['batch_size'],
            callbacks=callbacks)

        val_divergence = []
        train_divergence = []
        # Reserve the last (1 - trainval_test_split) fraction of the shuffled
        # data as a test chunk; the heads are fit on the remainder, with
        # skorch applying its own train/val split internally.
        train_val_idx = int(trainval_test_split * all_feats.shape[0])
        for hidx, h in enumerate(heads):
            logging.info("Fitting head %d/%d" % (hidx + 1, len(heads)))
            h.fit(all_feats[:train_val_idx, :], all_labels[:train_val_idx])
            # Record the final validation divergence and the best training
            # divergence reached by this head.
            val_divergence.append(h.history[-1]['val_divergence'])
            train_divergence.append(
                max(x['train_divergence'] for x in h.history))

        # Pick the head with the highest validation divergence and evaluate
        # it on the untouched test chunk.
        best_model_idx = np.argmax(val_divergence)
        h_del_h_divergence = self.hdh_accuracy_fn(
            heads[best_model_idx],
            torch.utils.data.TensorDataset(all_feats[train_val_idx:, :],
                                           all_labels[train_val_idx:]))

        return h_del_h_divergence, max(train_divergence)
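
The helpers permute_dataset and NegHDelHCriterion are defined elsewhere in the repository and are not shown in this excerpt. A minimal sketch of what permute_dataset is assumed to do (shuffle features and labels under one shared random permutation; the behavior is inferred from how the method uses it, not copied from the repo):

    import torch

    def permute_dataset(feats, labels):
        # Assumed behavior: draw a single random permutation and apply it to
        # both tensors so (feature, label) pairs stay aligned after shuffling.
        perm = torch.randperm(feats.shape[0])
        return feats[perm], labels[perm]

A hypothetical call, with illustrative hyperparameter values and the label encodings required by the checks at the top of the method (measure stands in for an instance of the enclosing class):

    import torch

    train_feats = torch.randn(512, 128)      # features from the training domains
    held_out_feats = torch.randn(512, 128)   # features from the held-out domain

    hdh_div, best_train_div = measure._calculate_divergence_measure(
        all_train_feats=train_feats,
        train_domain_labels=torch.ones(512, dtype=torch.long),     # must all be 1
        all_held_out_feats=held_out_feats,
        heldout_domain_labels=torch.zeros(512, dtype=torch.long),  # must all be 0
        lr_decay_gamma=0.1,       # illustrative values, not the repo's defaults
        num_head_batches=1,
        max_lr=1e-2,
        lr_sweep_factor=0.1,
        train_env_to_use=0,
        train_val_split=0.2,
    )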