in domainbed_measures/measures/held_out_measures.py [0:0]
def _calculate_measure_base(self,
                            num_head_batches=8,
                            max_lr=0.025,
                            lr_sweep_factor=0.5,
                            lr_decay_gamma=0.1,
                            stochastic_fraction_data=1.0,
                            train_env_to_use=None,
                            train_val_split=0.7):
    if train_env_to_use is None:
        train_env_to_use = self._train_loader

    # Freeze the featurizer so that only the heads trained downstream
    # receive gradient updates.
    featurizer = self._trainer_current.get_featurizer()
    _freeze_params(featurizer)

    # Precompute frozen features for the training environment.
    logging.info("Precomputing features for train")
    all_train_feats, all_train_labels, _, _ = self.compute_features(
        train_env_to_use, featurizer, self._device)

    # Precompute frozen features for the union of held-out environments.
    logging.info("Precomputing features for held out")
    all_held_out_feats, all_held_out_labels, _, _ = self.compute_features(
        self._union_held_out_loader, featurizer, self._device)

    # Shuffle each split and keep a stochastic_fraction_data fraction of it.
    all_train_feats, all_train_labels = permute_dataset(
        all_train_feats, all_train_labels)
    data_idx = int(stochastic_fraction_data * all_train_feats.shape[0])
    all_train_feats = all_train_feats[:data_idx, :]
    all_train_labels = all_train_labels[:data_idx]

    all_held_out_feats, all_held_out_labels = permute_dataset(
        all_held_out_feats, all_held_out_labels)
    data_idx = int(stochastic_fraction_data * all_held_out_feats.shape[0])
    all_held_out_feats = all_held_out_feats[:data_idx, :]
    all_held_out_labels = all_held_out_labels[:data_idx]

    # Domain labels (1 = train, 0 = held out) are created after subsampling
    # so their lengths match the retained features.
    train_domain_labels = torch.ones_like(all_train_labels)
    heldout_domain_labels = torch.zeros_like(all_held_out_labels)

    # Divergence between the train and held-out feature distributions,
    # estimated from the domain labels.
    heldout_measure, train_measure = self._calculate_divergence_measure(
        all_train_feats, train_domain_labels, all_held_out_feats,
        heldout_domain_labels, lr_decay_gamma, num_head_batches, max_lr,
        lr_sweep_factor, train_env_to_use, train_val_split)

    # Compute lambda closeness using the class labels from both splits.
    lambda_closeness = self._calculate_lambda_closeness(
        all_train_feats, all_train_labels, all_held_out_feats,
        all_held_out_labels, lr_decay_gamma, num_head_batches, max_lr,
        lr_sweep_factor, train_env_to_use, train_val_split)
    return heldout_measure, train_measure, lambda_closeness
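
The method calls two module-level helpers, `_freeze_params` and `permute_dataset`, whose definitions are not shown in this excerpt. The following is a minimal sketch of what they are assumed to do, inferred from how they are used above (disable gradients on the featurizer; shuffle features and labels with a shared permutation); it is illustrative only and not the repository's actual implementation.

import torch
import torch.nn as nn


def _freeze_params(module: nn.Module) -> None:
    # Assumed behavior: stop gradient flow through the featurizer so that
    # only the probe heads trained later are updated.
    for param in module.parameters():
        param.requires_grad = False


def permute_dataset(feats: torch.Tensor, labels: torch.Tensor):
    # Assumed behavior: apply one shared random permutation to features and
    # labels so that corresponding rows stay aligned after shuffling.
    perm = torch.randperm(feats.shape[0])
    return feats[perm], labels[perm]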