in domainbed_measures/measures/held_out_measures.py [0:0]
def get_cond_v_entropy(self, heads, all_feats, all_idx, label_idx):
    """Train every head to predict one digit of the sample index; report losses.

    For head number ``head_idx`` the targets are
    ``self._base_rep_ith_digit(all_idx, head_idx)``. Each head is
    re-initialized, trained for ``self._train_epochs`` epochs, and then its
    ``mean_train_loss`` on the full data is recorded.

    If ``self._recompute_features_every_epoch`` is set, the features and
    indices are re-fetched via ``self.get_v_min_data(label_idx)`` at the start
    of every epoch (closer to the stochastic setting for things like dropout).
    The refreshed values replace ``all_feats`` / ``all_idx`` and carry over to
    the following heads.

    Args:
        heads: sequence of trainable heads exposing ``initialize``,
            ``notify``, ``get_split_datasets``, ``run_single_epoch``,
            ``train_step`` and ``mean_train_loss`` (presumably skorch
            ``NeuralNet`` instances -- confirm).
        all_feats: features; ``all_feats.shape[0]`` must equal the number of
            labels derived from ``all_idx``.
        all_idx: tensor of sample indices from which per-head labels are
            derived.
        label_idx: forwarded to ``self.get_v_min_data`` when recomputing.

    Returns:
        ``(mean over heads of the final train loss, mean over heads of the
        train loss summed across epochs)``. Both are NaN when ``heads`` is
        empty (``np.mean`` of an empty list).

    Raises:
        ValueError: if the number of feature rows and labels differ.
    """

    def _labels_for_head(feats, idx, digit_idx):
        # Labels for one head: the ``digit_idx``-th digit of each index.
        # Built directly as int64 on the target device (no float32 detour).
        digits = self._base_rep_ith_digit(idx.cpu().numpy(), digit_idx)
        labels = torch.as_tensor(np.asarray(digits),
                                 dtype=torch.long,
                                 device=self._device)
        if feats.shape[0] != labels.shape[0]:
            raise ValueError("Shapes must match")
        return labels

    head_losses = []
    accumulated_head_epoch_losses = []
    for head_idx, head in enumerate(heads):
        logging.info("Processing head %d/%d", head_idx, len(heads))
        head.initialize()
        head.notify('on_train_begin')
        # NOTE(review): there is no matching notify('on_train_end') for this
        # head; skorch-style training loops normally pair the two -- confirm
        # no registered callback depends on it.
        accumulated_loss = 0
        # None marks the labels as stale; they are rebuilt only when needed
        # (first epoch, or after the features/indices were recomputed).
        this_head_labels = None
        for _ in range(self._train_epochs):
            if self._recompute_features_every_epoch:
                # Only way this can work is if shuffling is off in the loader
                logging.info("Recomputing features..")
                all_feats, all_idx = self.get_v_min_data(label_idx)
                this_head_labels = None
            if this_head_labels is None:
                this_head_labels = _labels_for_head(all_feats, all_idx,
                                                    head_idx)
            dtrain, dval = head.get_split_datasets(all_feats,
                                                   this_head_labels)
            on_epoch_kwargs = {
                "dataset_train": dtrain,
                "dataset_valid": dval
            }
            head.notify("on_epoch_begin", **on_epoch_kwargs)
            head.run_single_epoch(dtrain,
                                  training=True,
                                  prefix="train",
                                  step_fn=head.train_step)
            head.notify("on_epoch_end", **on_epoch_kwargs)
            accumulated_loss += head.mean_train_loss(all_feats,
                                                     this_head_labels)
        if this_head_labels is None:
            # _train_epochs == 0: the loop never ran, so labels were never
            # built; previously this raised UnboundLocalError.
            this_head_labels = _labels_for_head(all_feats, all_idx, head_idx)
        # Run evaluation on the training set to find out the
        # loss on the training overall
        head_losses.append(head.mean_train_loss(all_feats, this_head_labels))
        accumulated_head_epoch_losses.append(accumulated_loss)
    return np.mean(head_losses), np.mean(accumulated_head_epoch_losses)