in domainbed_measures/measures/held_out_measures.py [0:0]
# Module-level imports assumed by this method:
import logging

import numpy as np
import skorch
import torch
import torch.nn as nn


def _calculate_divergence_measure(self,
                                  all_train_feats,
                                  train_domain_labels,
                                  all_held_out_feats,
                                  heldout_domain_labels,
                                  lr_decay_gamma,
                                  num_head_batches,
                                  max_lr,
                                  lr_sweep_factor,
                                  train_env_to_use,
                                  train_val_split,
                                  trainval_test_split=0.8):
    # Pool train and held-out features into one dataset whose labels mark
    # which split each feature came from (a classifier two-sample test).
    all_feats, all_labels = self.prepare_c2st_datasets(
        all_train_feats, train_domain_labels, all_held_out_feats,
        heldout_domain_labels)
    feat_dim = all_feats.shape[-1]
logging.info("Obtaining heads")
callbacks = [
skorch.callbacks.LRScheduler(
torch.optim.lr_scheduler.StepLR,
gamma=lr_decay_gamma,
step_size=self._train_epochs / 2,
),
skorch.callbacks.EpochScoring(
self.accuracy_fn,
lower_is_better=False,
name='val_accuracy',
),
skorch.callbacks.EpochScoring(
self.accuracy_fn,
lower_is_better=False,
name='train_accuracy',
on_train=True,
),
skorch.callbacks.EarlyStopping(
monitor='val_accuracy',
patience=15,
threshold=0.0001,
threshold_mode='rel',
lower_is_better=False,
)
]
    heads = self.get_heads(
        num_head_batches,
        feat_dim=feat_dim,
        criterion=nn.CrossEntropyLoss,
        num_labels=2,
        max_lr=max_lr,
        lr_sweep_factor=lr_sweep_factor,
        train_split=skorch.dataset.CVSplit(train_val_split),
        batch_size=self._algorithm.hparams['batch_size'],
        callbacks=callbacks)
    # The first `trainval_test_split` fraction of the pooled dataset is
    # used to fit the heads; the remainder is held out to score the best
    # head. The split index is loop-invariant, so compute it once up front.
    train_val_idx = int(trainval_test_split * all_feats.shape[0])
    val_accuracies = []
    train_accuracies = []
    for hidx, h in enumerate(heads):
        logging.info("Fitting head %d/%d" % (hidx + 1, len(heads)))
        h.fit(all_feats[:train_val_idx, :], all_labels[:train_val_idx])
        val_accuracies.append(h.history[-1, 'val_accuracy'])
        train_accuracies.append(
            max(x['train_accuracy'] for x in h.history))
    # Pick the head with the best validation accuracy, evaluate it on the
    # held-out test portion, and convert both accuracies to divergences.
    best_model_idx = np.argmax(val_accuracies)
    best_gen_accuracy = heads[best_model_idx].accuracy(
        all_feats[train_val_idx:, :], all_labels[train_val_idx:])
    return (self.convert_domain_classifier_accuracy_to_divergence(
        best_gen_accuracy),
            self.convert_domain_classifier_accuracy_to_divergence(
                max(train_accuracies)))
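
For context, `convert_domain_classifier_accuracy_to_divergence` is defined elsewhere in this class and is not shown here. A common convention for turning domain-classifier accuracy into a divergence in C2ST-style measures is the proxy A-distance of Ben-David et al. (2007); the minimal sketch below assumes that formula for illustration and is not the repository's actual implementation:

def convert_domain_classifier_accuracy_to_divergence(accuracy):
    # Assumed formula: with domain-classifier error eps = 1 - accuracy, the
    # proxy A-distance is d_A = 2 * (1 - 2 * eps) = 2 * (2 * accuracy - 1),
    # so a chance-level classifier (accuracy 0.5) maps to divergence 0.
    return 2.0 * (2.0 * accuracy - 1.0)

Under such a convention, the tuple returned above pairs a generalization divergence (best head scored on the held-out test portion) with an optimistic training divergence (best training accuracy seen across all heads).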