in models/src/ptflops/flops_counter.py [0:0]
def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
    """Return ``flops`` increased by the estimated cost of one recurrent step.

    The estimate has two parts:

    * For each of ``w_ih`` and ``w_hh`` (input-to-hidden first), the product
      of the first two entries of its ``shape``.
    * For recognised recurrent modules, multiples of
      ``rnn_module.hidden_size`` for the element-wise gate work:
      RNN/RNNCell 1x, GRU/GRUCell 1x + 3x + 3x, LSTM/LSTMCell 4x + 3x + 3x.
      Any other module gets no such term, and its ``hidden_size`` is never read.

    Args:
        flops: Running total to add to.
        rnn_module: The recurrent module being measured.
        w_ih: Input-to-hidden weight (anything exposing an indexable ``shape``).
        w_hh: Hidden-to-hidden weight (same requirement as ``w_ih``).
        input_size: Accepted for call-site compatibility; not used here.

    Returns:
        The updated running total.
    """
    # Weight products, input-to-hidden first, then hidden-to-hidden.
    for weight in (w_ih, w_hh):
        flops += weight.shape[0] * weight.shape[1]

    # Element-wise gate work, as successive multiples of hidden_size.
    if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
        # sum of the two pre-activations
        hidden_multiples = (1,)
    elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
        # hadamard with r, adds across both states, last two hadamards + add
        hidden_multiples = (1, 3, 3)
    elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
        # adds across both states, C-state hadamards + add, final hadamard term
        hidden_multiples = (4, 3, 3)
    else:
        hidden_multiples = ()

    for multiple in hidden_multiples:
        flops += multiple * rnn_module.hidden_size
    return flops