in src/model.py [0:0]
def _run_cell(self, input, md_layer, hidden, w_ih, w_hh, b_ih, b_hh):
    """Advance the LSTM by one step, optionally layer-normalising each gate.

    The cell layout follows the PyTorch fastrnns reference cell:
    github.com/pytorch/benchmark/blob/09eaadc1d05ad442b1f0beb82babf875bbafb24b/rnns/fastrnns/cells.py#L25-L40
    This version also adds ``md_layer`` to the gate pre-activations.

    Args:
        input: Step input; multiplied by ``w_ih.t()``.
        md_layer: Extra term summed into the gate pre-activations. It must
            broadcast against them. Its meaning is defined by the caller and
            is not visible here.
        hidden: Pair ``(hx, cx)`` holding the previous hidden state and the
            previous cell state.
        w_ih: Input-to-hidden weight matrix.
        w_hh: Hidden-to-hidden weight matrix.
        b_ih: Bias added to the input projection.
        b_hh: Bias added to the hidden projection.

    Returns:
        Tuple ``(hy, cy)`` holding the new hidden state and the new cell state.
    """
    prev_h, prev_c = hidden

    # Summation order matches the reference cell, so floating-point results
    # are unchanged.
    pre_activations = (
        (input @ w_ih.t())
        + (prev_h @ w_hh.t())
        + md_layer
        + b_ih
        + b_hh
    )

    # Split dim 1 into four equal blocks: input, forget, candidate and output.
    gate_blocks = pre_activations.chunk(4, 1)
    if self.use_layernorm:
        # NOTE(review): one shared self.layernorm module normalises every
        # gate. Confirm that per-gate parameters are not intended.
        gate_blocks = [self.layernorm(block) for block in gate_blocks]
    in_pre, forget_pre, cand_pre, out_pre = gate_blocks

    in_gate = torch.sigmoid(in_pre)
    forget_gate = torch.sigmoid(forget_pre)
    candidate = torch.tanh(cand_pre)
    out_gate = torch.sigmoid(out_pre)

    new_c = (forget_gate * prev_c) + (in_gate * candidate)
    new_h = out_gate * torch.tanh(new_c)
    return new_h, new_c