def _run_cell()

in src/model.py [0:0]


    def _run_cell(self, input, md_layer, hidden, w_ih, w_hh, b_ih, b_hh):
        """
        LSTM cell structure adapted from:
        github.com/pytorch/benchmark/blob/09eaadc1d05ad442b1f0beb82babf875bbafb24b/rnns/fastrnns/cells.py#L25-L40
        """

        hx, cx = hidden
        gates = torch.matmul(input, w_ih.t()) + torch.matmul(hx, w_hh.t()) +\
                md_layer + b_ih + b_hh

        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        if self.use_layernorm:
            ingate = self.layernorm(ingate)
            forgetgate = self.layernorm(forgetgate)
            cellgate = self.layernorm(cellgate)
            outgate = self.layernorm(outgate)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, cy