def mask_step()

in scripts/models.py [0:0]


    def mask_step(self, losses, parameters, tau=0.9, wd=0.1, lr=1e-3):
        """Apply one AND-masked gradient-descent step to ``parameters``.

        One gradient is computed per loss (presumably one loss per training
        environment) with respect to every tensor in ``parameters``.

        For each parameter element, the per-loss gradient signs are summed.
        The element receives a gradient update only if the absolute sign sum
        exceeds ``len(losses) * tau``, i.e. enough losses agree on the
        direction. Agreeing elements move by ``lr`` times the mean gradient
        across losses. Every element also gets a weight-decay term of
        ``lr * wd * parameter``.

        Args:
            losses: sequence of scalar loss tensors; each one's graph must
                reach every tensor in ``parameters``.
            parameters: list of tensors to update in place.
            tau: agreement threshold, as a fraction of ``len(losses)``.
            wd: weight-decay coefficient.
            lr: learning rate.
        """
        with torch.no_grad():
            gradients = []
            for loss in losses:
                grads = list(torch.autograd.grad(loss, parameters))
                # NOTE(review): only the first parameter's gradient is
                # rescaled to unit norm; the rest are used raw. Behaviour
                # kept from the original — confirm this is intended.
                if grads:
                    norm = grads[0].norm()
                    # An all-zero gradient would otherwise become 0/0 = NaN
                    # and poison the parameter; leaving it at zero is
                    # equivalent to normalising it.
                    if norm > 0:
                        grads[0] = grads[0] / norm
                gradients.append(grads)

            for ge_all, parameter in zip(zip(*gradients), parameters):
                # Per-loss gradients with shape (num_losses, *parameter.shape).
                # Stacking (rather than concatenating) keeps each loss in its
                # own leading slice, so the vote and the mean below are taken
                # per parameter element and preserve the parameter's shape.
                ge_stack = torch.stack(ge_all)

                # True where the per-loss gradient signs agree strongly enough
                mask = (torch.abs(torch.sign(ge_stack).sum(0))
                        > len(losses) * tau)

                # Mean gradient over losses, same shape as ``parameter``
                g_mean = ge_stack.mean(0)

                # Zero out the update on elements without sign agreement
                g_masked = mask.to(g_mean.dtype) * g_mean

                # Update in place, so the tensor keeps its shape and storage
                parameter.copy_(parameter - lr * g_masked
                                - lr * wd * parameter)