in domainbed_measures/measures/classical.py [0:0]
def _calculate_measure(self, sigma_max=2.0, sigma_min=0.0):
    """
    Compute the sharpness magnitude 1/alpha'^2 described in [1].

    A clone of the current trainer (with ``self._measure_criterion`` as its
    criterion) is evaluated on ``self._train_loader``. Then
    ``self.get_sharp_mag_interval`` is called repeatedly to narrow the
    interval ``[sigma_min, sigma_max]`` of perturbation standard deviations.
    The search stops when the interval becomes inverted, when its endpoints
    agree to within a 1% relative tolerance, or after
    ``self._max_binary_search`` iterations.

    Notes
    -----
    - This is slightly different than [1] because the target deviation is
      on cross-entropy instead of accuracy.
    - NOTE(review): [1] is not cited in this block; presumably Jiang et al.,
      "Fantastic Generalization Measures and Where to Find Them" (2019).
      Confirm and add the full reference.

    Args:
        sigma_max: float, optional
            Initial upper bound of the search interval for the standard
            deviation of the perturbation.
        sigma_min: float, optional
            Initial lower bound of the search interval for the standard
            deviation of the perturbation.

    Returns:
        tuple: ``(1 / sigma_max**2, {})``, where ``sigma_max`` is the upper
        bound left after the search. The second element is always an empty
        dict (presumably a slot for auxiliary outputs; confirm against
        callers).

    Raises:
        ZeroDivisionError: if the search ends with ``sigma_max == 0``.
    """
    trainer = clone_trainer(self._trainer_current)
    trainer.criterion = self._measure_criterion
    trainer.initialize()
    acc = self.accuracy_trainer(trainer, self._train_loader)
    logging.info(f"Accuracy of original model: {acc}")

    for _ in range(self._max_binary_search):
        sigma_min, sigma_max = self.get_sharp_mag_interval(
            trainer,
            acc,
            sigma_min,
            sigma_max,
        )
        if sigma_min > sigma_max or math.isclose(
                sigma_min, sigma_max, rel_tol=1e-2):
            # The search interval is empty or already very narrow: stop.
            break
    else:
        # Runs only when the loop finished without ``break``, i.e. the
        # search did not converge within the iteration budget.
        # Implicit concatenation keeps indentation out of the message.
        logging.info(
            f"Stopped early because reached "
            f"max_binary_search={self._max_binary_search}. "
            f"[sigma_min,sigma_max]=[{sigma_min},{sigma_max}]")

    return 1 / (sigma_max**2), {}