models/compressive.py [200:211]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if self.args.pre_norm:
            h2 = h + attn_out  # B x M x H
            ff_out = self.ff(self.norm2(h2))
            out = h2 + ff_out  # B x M x H
        else:
            h2 = self.norm1(h + attn_out)  # B x M x H
            ff_out = self.ff(h2)
            out = self.norm2(h2 + ff_out)  # B x M x H

        return out, aux_loss

    def get_cache_size(self):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



models/transformer_seq.py [170:181]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if self.args.pre_norm:
            h2 = h + attn_out  # B x M x H
            ff_out = self.ff(self.norm2(h2))
            out = h2 + ff_out  # B x M x H
        else:
            h2 = self.norm1(h + attn_out)  # B x M x H
            ff_out = self.ff(h2)
            out = self.norm2(h2 + ff_out)  # B x M x H

        return out, aux_loss

    def get_cache_size(self):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



