in scripts/models.py [0:0]
def mask_step(self, losses, parameters, tau=0.9, wd=0.1, lr=1e-3):
    """Apply one sign-agreement-masked gradient step to ``parameters``.

    For every parameter, the gradient of each loss is computed and the
    gradients are compared element-wise across losses (the original
    comments call these "environment-wise gradients"). An element is
    updated only if the absolute sum of the gradient signs across losses
    is strictly greater than ``len(losses) * tau``; elsewhere the gradient
    is masked to zero. The masked mean gradient and a weight-decay term
    are then subtracted from the parameter.

    Args:
        losses: Iterable of scalar loss tensors, one per environment.
            Each must still own its autograd graph, since
            ``torch.autograd.grad`` is called once per loss without
            ``retain_graph``.
        parameters: Iterable of tensors to differentiate with respect to
            and to update. A generator (e.g. ``model.parameters()``) is
            accepted; it is materialised once up front.
        tau: Agreement threshold as a fraction of the number of losses.
        wd: Weight-decay coefficient (scaled by ``lr`` in the update).
        lr: Learning rate.

    Returns:
        None. Each parameter's ``.data`` is replaced by the updated
        tensor, which has the same shape as the parameter.
    """
    # Both iterables are walked more than once (once per loss and again
    # in the update loop), so one-shot iterators must be materialised.
    losses = list(losses)
    parameters = list(parameters)
    n_envs = len(losses)

    with torch.no_grad():
        gradients = []
        for loss in losses:
            grads = list(torch.autograd.grad(loss, parameters))
            # Only the first parameter's gradient is rescaled to unit
            # norm. An all-zero gradient is left as zeros instead of
            # being divided by a zero norm, which would write NaN into
            # the parameter below.
            norm = grads[0].norm()
            grads[0] = torch.where(norm > 0, grads[0] / norm, grads[0])
            gradients.append(grads)

        for ge_all, parameter in zip(zip(*gradients), parameters):
            # environment-wise gradients: (num_environments, *param.shape).
            # Stacking on a new leading axis keeps environments separate
            # for parameters of any rank, including 0-dim scalars and
            # weights with more than one row.
            ge = torch.stack(ge_all)

            # mask with zeros where the gradient signs disagree too much
            agreement = torch.sign(ge).sum(0).abs()
            mask = (agreement > n_envs * tau).to(ge.dtype)

            # mean gradient over environments (same shape as the parameter)
            g_mean = ge.mean(0)

            # apply the mask
            g_masked = mask * g_mean

            # update: gradient step plus weight decay, shape preserved
            parameter.data = parameter.data - lr * g_masked \
                - lr * wd * parameter.data