def update()

in automl21/accel/neural_aa.py [0:0]


    def update(self, fx, x, hidden):
        """Take one learned, Anderson-style acceleration step.

        The step blends a plain fixed-point step with a multisecant
        quasi-Newton step. The blend weight ``alpha`` is produced by an
        LSTM that reads ``(x, fx, x - fx)``.

        Args:
            fx: The fixed-point map evaluated at ``x``; same shape as ``x``.
            x: Current iterate, either 1-D (a single problem) or 2-D
                ``(batch, n)``. ``x`` and ``fx`` must have the same rank and
                the same leading size.
            hidden: Mutable state object. This method reads and writes
                ``lstm_h`` / ``lstm_c`` (the LSTM state) and appends to the
                history lists ``xs``, ``gs`` and ``ys`` (stored as 2-D
                tensors). ``hidden.xs[-1]`` is expected to be the current
                iterate ``x``: this method appends every new iterate itself,
                so presumably whoever creates ``hidden`` seeds ``xs`` with
                the initial iterate. TODO: confirm against the initializer.

        Returns:
            ``(x_new, g, hidden)``:
            - ``g = x - fx`` is the fixed-point residual.
            - ``x_new`` and ``g`` are squeezed back to 1-D when ``x`` was 1-D.
            - ``hidden`` is the same object, mutated in place.
        """
        single = x.dim() == 1
        if single:
            x = x.unsqueeze(0)
            fx = fx.unsqueeze(0)

        assert x.dim() == fx.dim() == 2
        assert x.size(0) == fx.size(0)

        g = x - fx
        x_fx_g = torch.cat((x, fx, g), dim=1)
        # Add a length-1 leading sequence axis for the recurrent cell.
        z = self.enc(x_fx_g).unsqueeze(0)
        o, (h, c) = self.cell(z, (hidden.lstm_h, hidden.lstm_c))
        o = o.squeeze(0)
        # Writing through `.data` stores the new state values without
        # linking the stored tensors to this step's autograd graph.
        hidden.lstm_h.data = h
        hidden.lstm_c.data = c
        alpha = self.dec(o)
        # The +5 offset makes alpha ~= sigmoid(5) ~= 0.993 when the decoder
        # outputs ~0, i.e. the step starts close to a plain fixed-point step.
        alpha = (alpha + 5.).sigmoid()

        hidden.gs.append(g)

        k = len(hidden.xs) - 1
        if k > 0:
            hidden.ys.append(g - hidden.gs[-2])

            m_k = min(k, self.memory_size)
            # Columns of S: differences of consecutive stored iterates.
            # Columns of Y: the matching residual differences.
            ST = (torch.stack(hidden.xs[-m_k:], dim=1)
                  - torch.stack(hidden.xs[-m_k-1:-1], dim=1))
            S = ST.transpose(1, 2)
            Y = torch.stack(hidden.ys[-m_k:], dim=1).transpose(1, 2)
            STY = ST.bmm(Y)
            # Solve (S^T Y) X = S^T. `Tensor.solve`/`torch.solve` were
            # removed in PyTorch 2.0, so prefer `torch.linalg.solve`, whose
            # argument order is (A, B). Keep the legacy call only as a
            # fallback for PyTorch < 1.8.
            linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve',
                                   None)
            if linalg_solve is not None:
                STYinv_ST = linalg_solve(STY, ST)
            else:
                STYinv_ST = torch.solve(ST, STY).solution
            # Multisecant inverse-Jacobian estimate:
            # Binv = I + (S - Y) (S^T Y)^{-1} S^T.
            Binv = (S - Y).bmm(STYinv_ST)
            Binv.diagonal(dim1=1, dim2=2).add_(1.)

            # x - g equals fx (plain fixed-point step). Blend it with the
            # quasi-Newton step x - Binv g using the learned weight alpha.
            x = x - alpha * g - (1. - alpha)*Binv.bmm(g.unsqueeze(2)).squeeze(2)
        else:
            # No history yet: take a plain fixed-point step.
            x = fx

        hidden.xs.append(x)

        if single:
            x = x.squeeze(0)
            g = g.squeeze(0)

        return x, g, hidden