in tinynn/graph/modifier.py [0:0]
def update_weight_metric(importance, metric_func, module, name):
    """Score the weights of ``module`` with ``metric_func`` and record the result.

    Scores are written into ``importance`` in place; the function returns ``None``.

    Args:
        importance: Dict-like container that receives the scores.
        metric_func: Callable invoked as ``metric_func(weights, module)``. For
            linear/conv layers ``weights`` is ``module.weight``; for RNN layers
            it is a list of tensors (described below).
        module: Layer to score. Its type is compared exactly via ``type(module)``,
            so subclasses of the supported ``nn`` classes are rejected.
        name: Key under which the score is stored.

    RNN layers (``nn.GRU`` / ``nn.LSTM`` / ``nn.RNN``):
        * If the module has ``proj_size > 0``, the ``weight_hr_l{i}`` tensors of
          each layer (forward, plus ``_reverse`` when bidirectional, concatenated
          along dim 0) are scored under ``name``. The input/hidden weights below
          are then scored under ``f'{name}:h'`` instead of ``name``.
        * For every layer, ``weight_ih_l{i}`` / ``weight_hh_l{i}`` of all
          directions are gathered. When ``rnn_gate_size(module) == 1`` they are
          concatenated along dim 0, giving one ih entry and one hh entry per
          layer. Otherwise each tensor is viewed as
          ``(gate_size, hidden_size, -1)`` and the directions are concatenated
          gate by gate. Each layer then contributes all ih gates followed by all
          hh gates.

    Raises:
        AttributeError: If ``type(module)`` is not one of the supported types.
    """
    module_type = type(module)
    if module_type in (nn.Linear, nn.Conv2d, nn.Conv1d, nn.ConvTranspose2d, nn.ConvTranspose1d):
        importance[name] = metric_func(module.weight, module)
        return

    if module_type not in (nn.GRU, nn.LSTM, nn.RNN):
        raise AttributeError(f'{type(module).__name__}({name}) is not supported for importance calculation')

    # Parameter-name suffixes for the forward and (if bidirectional) reverse direction.
    dir_suffixes = ('', '_reverse')[: 2 if module.bidirectional else 1]
    has_projection = getattr(module, 'proj_size', 0) > 0
    # Gate count per cell as reported by rnn_gate_size (defined elsewhere in this file).
    gate_size = rnn_gate_size(module)

    layer_weights = []
    key = name

    if has_projection:
        for i in range(module.num_layers):
            hr_per_dir = [getattr(module, f'weight_hr_l{i}{suffix}') for suffix in dir_suffixes]
            layer_weights.append(torch.cat(hr_per_dir, dim=0))
        importance[key] = metric_func(layer_weights, module)
        # The same list object is emptied and refilled below, so metric_func
        # must not keep a reference to it.
        layer_weights.clear()
        key = f'{name}:h'

    for i in range(module.num_layers):
        ih_per_dir = []
        hh_per_dir = []
        for suffix in dir_suffixes:
            ih_per_dir.append(getattr(module, f'weight_ih_l{i}{suffix}'))
            hh_per_dir.append(getattr(module, f'weight_hh_l{i}{suffix}'))

        if gate_size == 1:
            layer_weights += [torch.cat(ih_per_dir, dim=0), torch.cat(hh_per_dir, dim=0)]
            continue

        # The view assumes each weight's leading dim is gate_size * hidden_size,
        # so unbinding yields one (hidden_size, -1) slice per gate; zip then
        # groups the same gate across directions.
        hidden_size = module.hidden_size
        ih_by_gate = zip(*[torch.unbind(w.view(gate_size, hidden_size, -1)) for w in ih_per_dir])
        hh_by_gate = zip(*[torch.unbind(w.view(gate_size, hidden_size, -1)) for w in hh_per_dir])
        ih_cat = [torch.cat(per_dir) for per_dir in ih_by_gate]
        hh_cat = [torch.cat(per_dir) for per_dir in hh_by_gate]
        layer_weights += ih_cat + hh_cat

    importance[key] = metric_func(layer_weights, module)