in automl21/accel/neural_aa.py [0:0]
def update(self, fx, x, hidden):
    """Produce the next iterate from the current one and its map value.

    The residual ``g = x - fx`` is fed, together with ``x`` and ``fx``,
    through ``self.enc`` -> ``self.cell`` -> ``self.dec`` to obtain a
    mixing weight ``alpha`` in (0, 1).  Once enough history is stored in
    ``hidden`` the new iterate is

        x - alpha * g - (1 - alpha) * Binv @ g,
        Binv = I + (S - Y) (S^T Y)^{-1} S^T,

    where the columns of ``S`` / ``Y`` are the most recent differences of
    stored iterates / residuals (at most ``self.memory_size`` of them).
    Before that, the plain step ``x = fx`` is taken.
    NOTE(review): this looks like an Anderson/Broyden-style secant
    update -- confirm the intended variant against the paper/caller.

    Args:
        fx: value of the map at ``x``; 1-D, or 2-D ``(batch, n)``.
        x: current iterate, same shape as ``fx``.
        hidden: state object carrying ``lstm_h``, ``lstm_c`` (recurrent
            state tensors) and the lists ``xs``, ``gs``, ``ys``.  It is
            mutated in place.  Presumably ``xs`` already holds the
            initial iterate when the first call is made -- verify
            against the code that creates ``hidden``.

    Returns:
        Tuple ``(x, g, hidden)``: the new iterate, the residual computed
        at the *input* iterate, and the (mutated) hidden state.  ``x``
        and ``g`` are squeezed back to 1-D if the inputs were 1-D.
    """
    # Accept a single un-batched vector by temporarily adding a batch dim.
    single = x.dim() == 1
    if single:
        x = x.unsqueeze(0)
        fx = fx.unsqueeze(0)
    assert x.dim() == fx.dim() == 2
    assert x.size(0) == fx.size(0)

    g = x - fx

    # Learned mixing weight: encode (x, fx, g), advance the recurrent
    # cell by one step, decode.
    x_fx_g = torch.cat((x, fx, g), dim=1)
    z = self.enc(x_fx_g).unsqueeze(0)  # leading dim of size 1 for the cell
    o, (h, c) = self.cell(z, (hidden.lstm_h, hidden.lstm_c))
    o = o.squeeze(0)
    # Store the new recurrent state on the existing tensors via `.data`,
    # i.e. without autograd recording the assignment.
    hidden.lstm_h.data = h
    hidden.lstm_c.data = c
    alpha = self.dec(o)
    # The +5 shift puts alpha close to 1 when the decoder output is ~0.
    alpha = (alpha + 5.).sigmoid()

    hidden.gs.append(g)
    k = len(hidden.xs) - 1
    if k > 0:
        hidden.ys.append(g - hidden.gs[-2])
        m_k = min(k, self.memory_size)

        # Rows of ST are the last m_k differences of consecutive iterates.
        ST = torch.stack(hidden.xs[-m_k:], dim=1) - \
            torch.stack(hidden.xs[-m_k-1:-1], dim=1)
        S = ST.transpose(1, 2)
        Y = torch.stack(hidden.ys[-m_k:], dim=1).transpose(1, 2)
        STY = ST.bmm(Y)

        # Solve (S^T Y) X = S^T.  `Tensor.solve` / `torch.solve` were
        # deprecated in PyTorch 1.9 and later removed; prefer
        # `torch.linalg.solve` and only fall back on old installs.
        linalg_solve = getattr(getattr(torch, "linalg", None), "solve", None)
        if linalg_solve is not None:
            STYinv_ST = linalg_solve(STY, ST)
        else:
            STYinv_ST = ST.solve(STY).solution

        Binv = (S - Y).bmm(STYinv_ST)
        Binv.diagonal(dim1=1, dim2=2).add_(1.)  # Binv += I (in place)

        x = x - alpha * g - (1. - alpha)*Binv.bmm(g.unsqueeze(2)).squeeze(2)
    else:
        # Not enough history for a secant step yet: plain fixed-point step.
        x = fx

    hidden.xs.append(x)
    if single:
        x = x.squeeze(0)
        g = g.squeeze(0)
    return x, g, hidden