in torchmoji/lstm.py [0:0]
def __setstate__(self, d):
    """Restore instance state from ``d`` and rebuild ``_all_weights`` if needed.

    The parent class's ``__setstate__`` is applied first, then a
    ``_data_ptrs`` entry is guaranteed to exist in ``__dict__``. If ``d``
    carries an ``'all_weights'`` key, its value replaces ``_all_weights``.

    When the first entry of ``_all_weights`` is already a string, nothing
    more is done. Otherwise ``_all_weights`` is rebuilt as one list of name
    strings per (layer, direction) pair, e.g. ``'weight_ih_l0'`` or
    ``'bias_hh_l1_reverse'``.

    Args:
        d: Mapping holding the state to restore.
    """
    super(LSTMHardSigmoid, self).__setstate__(d)
    self.__dict__.setdefault('_data_ptrs', [])
    if 'all_weights' in d:
        self._all_weights = d['all_weights']

    # NOTE(review): presumably a non-string first entry means the state came
    # from an older format that did not store names -- confirm.
    if isinstance(self._all_weights[0][0], str):
        return

    layer_count = self.num_layers
    direction_count = 2 if self.bidirectional else 1
    self._all_weights = []
    for layer_idx in range(layer_count):
        for direction_idx in range(direction_count):
            # Only the second direction (index 1) gets the reverse suffix.
            tag = '' if direction_idx != 1 else '_reverse'
            names = [
                template.format(layer_idx, tag)
                for template in (
                    'weight_ih_l{}{}',
                    'weight_hh_l{}{}',
                    'bias_ih_l{}{}',
                    'bias_hh_l{}{}',
                )
            ]
            if not self.bias:
                # Without bias, keep only the two weight names.
                names = names[:2]
            self._all_weights.append(names)