# tinynn/graph/quantization/quantizable/lstm.py

from distutils.version import LooseVersion import torch if LooseVersion(torch.__version__) >= '1.13.0': from torch.ao.nn.quantizable.modules.rnn import _LSTMLayer @classmethod def from_float(cls, other, qconfig=None): assert isinstance(other, cls._FLOAT_MODULE) assert hasattr(other, 'qconfig') or qconfig observed = cls( other.input_size, other.hidden_size, other.num_layers, other.bias, other.batch_first, other.dropout, other.bidirectional, ) observed.qconfig = getattr(other, 'qconfig', qconfig) for idx in range(other.num_layers): observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig, batch_first=False) observed.train() observed = torch.ao.quantization.prepare_qat(observed, inplace=True) return observed