def _calculate_measure_base()

in domainbed_measures/measures/held_out_measures.py [0:0]


    def _calculate_measure_base(self,
                                num_head_batches=8,
                                max_lr=0.025,
                                lr_sweep_factor=0.5,
                                lr_decay_gamma=0.1,
                                stochastic_fraction_data=1.0,
                                train_env_to_use=None,
                                train_val_split=0.7):
        """Compute held-out divergence measures and lambda closeness.

        Features are extracted with the featurizer of
        ``self._trainer_current`` (frozen via ``_freeze_params``) for a
        "train" loader and for ``self._union_held_out_loader``. Each feature
        set is shuffled with ``permute_dataset`` and truncated to
        ``stochastic_fraction_data`` of its rows. The results are then
        passed to ``self._calculate_divergence_measure``, with domain labels
        1 = train and 0 = held out, and to
        ``self._calculate_lambda_closeness``, with the class labels.

        Args:
            num_head_batches, max_lr, lr_sweep_factor, lr_decay_gamma,
            train_val_split: Forwarded unchanged to both helper measures.
            stochastic_fraction_data: Fraction of each shuffled feature set
                to keep. Must be > 0; values >= 1 keep every row.
            train_env_to_use: Loader featurized as the "train" side. Defaults
                to ``self._train_loader``. It is also forwarded to both
                helper measures.

        Returns:
            Tuple ``(heldout_measure, train_measure, lambda_closeness)``.
            The first two elements come from ``_calculate_divergence_measure``
            and the third from ``_calculate_lambda_closeness``.

        Raises:
            ValueError: If ``stochastic_fraction_data`` is not positive
                (including NaN).

        Note:
            The featurizer's parameters are frozen and are not unfrozen by
            this method.
        """
        # A non-positive fraction would yield empty (or, for negative values,
        # silently mis-sliced) data; `not x > 0` also rejects NaN.
        if not stochastic_fraction_data > 0:
            raise ValueError(
                "stochastic_fraction_data must be > 0, got {}".format(
                    stochastic_fraction_data))

        if train_env_to_use is None:
            train_env_to_use = self._train_loader

        featurizer = self._trainer_current.get_featurizer()
        _freeze_params(featurizer)

        def _shuffle_and_subsample(feats, labels):
            # Shuffle features/labels jointly, then keep the leading fraction
            # of rows so the subsample is a random subset.
            feats, labels = permute_dataset(feats, labels)
            n_keep = int(stochastic_fraction_data * feats.shape[0])
            return feats[:n_keep], labels[:n_keep]

        logging.info("Precomputing features for train")
        all_train_feats, all_train_labels, _, _ = self.compute_features(
            train_env_to_use, featurizer, self._device)

        logging.info("Precomputing features for held out")
        all_held_out_feats, all_held_out_labels, _, _ = self.compute_features(
            self._union_held_out_loader, featurizer, self._device)

        # Both feature sets are computed before either is permuted, which
        # preserves the original ordering of any RNG use.
        all_train_feats, all_train_labels = _shuffle_and_subsample(
            all_train_feats, all_train_labels)
        all_held_out_feats, all_held_out_labels = _shuffle_and_subsample(
            all_held_out_feats, all_held_out_labels)

        # Domain labels are built *after* subsampling so that they have the
        # same number of rows as the features they describe. Building them
        # from the untruncated labels caused a length mismatch whenever
        # stochastic_fraction_data < 1.
        train_domain_labels = torch.ones_like(all_train_labels)
        heldout_domain_labels = torch.zeros_like(all_held_out_labels)

        heldout_measure, train_measure = self._calculate_divergence_measure(
            all_train_feats, train_domain_labels, all_held_out_feats,
            heldout_domain_labels, lr_decay_gamma, num_head_batches, max_lr,
            lr_sweep_factor, train_env_to_use, train_val_split)

        # Compute lambda closeness from the class labels.
        lambda_closeness = self._calculate_lambda_closeness(
            all_train_feats, all_train_labels, all_held_out_feats,
            all_held_out_labels, lr_decay_gamma, num_head_batches, max_lr,
            lr_sweep_factor, train_env_to_use, train_val_split)

        return heldout_measure, train_measure, lambda_closeness