def update()

in automl21/accel/aa.py [0:0]


    def update(self, fx, x, hidden):
        """Take one type-I Anderson acceleration step for the fixed point x = f(x).

        With ``S`` / ``Y`` holding the last ``m`` iterate differences and
        residual differences as columns, the step is
        ``x_next = x - H g`` with ``H = I + (S - Y) (S^T Y)^{-1} S^T`` and
        ``g = x - f(x)``.  When fewer than two past iterates are stored, a
        plain fixed-point step ``x_next = f(x)`` is taken instead.

        Args:
            fx: ``f(x)`` evaluated at the current iterate; same shape as ``x``.
            x: current iterate, either 1-D ``(n,)`` or batched ``(batch, n)``.
            hidden: state object with lists ``xs`` (past iterates), ``gs``
                (past residuals ``x - f(x)``) and ``ys`` (residual
                differences); all three are appended to in place.

        Returns:
            ``(x_next, g, hidden)`` where ``g = x - fx`` is the residual at the
            current iterate.  ``x_next`` and ``g`` have the same rank as ``x``.
        """
        single = x.dim() == 1
        if single:
            # Promote to a batch of one so the batched matrix code applies.
            x = x.unsqueeze(0)
            fx = fx.unsqueeze(0)

        assert x.dim() == fx.dim() == 2
        assert x.size(0) == fx.size(0)

        g = x - fx
        hidden.gs.append(g)

        # NOTE(review): xs / gs / ys grow on every call although only the last
        # memory_size + 1 entries are read below; callers may want to trim them.
        k = len(hidden.xs) - 1
        if k > 0:
            hidden.ys.append(g - hidden.gs[-2])

            m_k = min(k, self.memory_size)
            # Columns are the last m_k iterate / residual differences:
            # S and Y both have shape (batch, n, m_k).
            S = (torch.stack(hidden.xs[-m_k:], dim=2)
                 - torch.stack(hidden.xs[-m_k - 1:-1], dim=2))
            Y = torch.stack(hidden.ys[-m_k:], dim=2)
            ST = S.transpose(1, 2)
            STY = ST.bmm(Y)                  # (batch, m_k, m_k)
            STg = ST.bmm(g.unsqueeze(2))     # (batch, m_k, 1)

            # gamma = (S^T Y)^{-1} S^T g.  Tensor.solve / torch.solve were
            # deprecated and then removed from PyTorch; prefer torch.linalg.solve
            # (note its (A, B) argument order) and fall back only when missing.
            linalg = getattr(torch, 'linalg', None)
            if linalg is not None and hasattr(linalg, 'solve'):
                gamma = linalg.solve(STY, STg)
            else:
                gamma = torch.solve(STg, STY).solution

            # x - H g = (x - g) - (S - Y) gamma = fx - (S - Y) gamma.  Applying H
            # in this low-rank form avoids materialising the dense (n, n) matrix.
            x = fx - (S - Y).bmm(gamma).squeeze(2)
        else:
            # Not enough history for a secant update: plain fixed-point step.
            x = fx

        hidden.xs.append(x)

        if single:
            x = x.squeeze(0)
            g = g.squeeze(0)
        return x, g, hidden