in domainbed_measures/measures/held_out_measures.py [0:0]
def _calculate_measure(self,
                       num_head_batches=10,
                       max_lr=0.10,
                       lr_sweep_factor=0.5,
                       batch_size=256):
    """Calculate v-minimality.

    Sweeps `num_head_batches` batches of freshly re-initialized heads over a
    geometric learning-rate schedule (`max_lr * lr_sweep_factor**i`), keeps the
    conditional v-entropy from the best batch (lowest accumulated v-entropy),
    and returns `log(num_classes)` minus the mean of those best conditional
    v-entropies across the processed labels.

    Args:
        num_head_batches: Number of batches of heads to optimize to improve the
            estimate of v-minimality. Each batch uses a smaller learning rate
            than the previous one.
        max_lr: Largest learning rate in the sweep (used for batch 0).
        lr_sweep_factor: Multiplicative decay applied to the learning rate for
            each successive batch of heads.
        batch_size: Mini-batch size used when optimizing the heads.

    Returns:
        Scalar estimate: log(num_classes) - mean conditional v-entropy.
    """
    # Exponential LR decay over the head-training epochs; the gamma is chosen
    # so the LR decays by a fixed factor across `self._train_epochs` epochs.
    # NOTE(review): the first argument (100) is the decay horizon passed to
    # get_exponential_decay_gamma — confirm its intended semantics there.
    callbacks = [
        LRScheduler(
            torch.optim.lr_scheduler.ExponentialLR,
            gamma=get_exponential_decay_gamma(100, self._train_epochs),
        )
    ]
    if self._cond_min:
        # Conditional v-minimality: estimate per class, then average.
        labels_to_process = range(self._num_classes)
    else:
        # Unconditional: a single pass; -1 signals "all labels" to
        # get_v_min_data (presumably — verify against that helper).
        labels_to_process = [-1]

    v_entropy_x_z = []
    for label_idx in labels_to_process:
        # Prepare data per label.
        task_feats, task_idx = self.get_v_min_data(label_idx)
        feat_dim = task_feats.shape[-1]
        # Get the heads for optimization; head count depends on how many
        # distinct tasks appear in task_idx.
        num_heads = self.get_num_heads(int(torch.max(task_idx)))

        cond_v_entropy_across_batches = []
        accumulated_v_entropy_across_batches = []
        for batch_idx in range(num_head_batches):
            # Each batch of heads trains with a geometrically decayed LR,
            # sweeping from max_lr downward.
            this_heads = (self.get_reinit_heads(
                num_heads,
                trained_classifier=(
                    self._trainer_current.get_classifier()),
                feat_dim=feat_dim,
                lr=max_lr * lr_sweep_factor**batch_idx,
                batch_size=batch_size,
                callbacks=callbacks))
            logging.info(
                f"******** Batch of heads {batch_idx}/{num_head_batches}**********"
            )
            cond_v_entropy, accumulated_v_entropy = (
                self.get_cond_v_entropy(this_heads, task_feats, task_idx,
                                        label_idx))
            cond_v_entropy_across_batches.append(cond_v_entropy)
            accumulated_v_entropy_across_batches.append(
                accumulated_v_entropy)

        # Model selection across the LR sweep: keep the conditional v-entropy
        # from the batch whose accumulated v-entropy was lowest.
        best_head_idx = np.argmin(accumulated_v_entropy_across_batches)
        v_entropy_x_z.append(cond_v_entropy_across_batches[best_head_idx])

    v_entropy_x_z = np.mean(v_entropy_x_z)
    return np.log(self._num_classes) - v_entropy_x_z