def update_weight_metric()

in tinynn/graph/modifier.py


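import torch
import torch.nn as nn

# `rnn_gate_size` is a helper available in this module's scope: it returns the
# number of gates per recurrent cell (1 for nn.RNN, 3 for nn.GRU, 4 for nn.LSTM).
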
def update_weight_metric(importance, metric_func, module, name):
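    """Compute an importance score for ``module``'s weights via ``metric_func``
    and store it in ``importance[name]``.

    Plain layers pass their weight tensor straight through. Recurrent layers
    are regrouped first: weights are concatenated across directions (and split
    per gate for GRU/LSTM). A projected LSTM stores the score of its
    ``weight_hr`` matrices under ``name`` and the remaining ih/hh weights
    under ``f'{name}:h'``.
    """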
    if type(module) in [nn.Linear, nn.Conv2d, nn.Conv1d, nn.ConvTranspose2d, nn.ConvTranspose1d]:
        importance[name] = metric_func(module.weight, module)
    elif type(module) in [nn.GRU, nn.LSTM, nn.RNN]:
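        # PyTorch flattens recurrent weights per layer/direction:
        #   weight_ih_l{k}[_reverse]: (num_gates * hidden_size, input_size)
        #   weight_hh_l{k}[_reverse]: (num_gates * hidden_size, hidden_size or proj_size)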
        num_directions = 2 if module.bidirectional else 1
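        # `proj_size` only exists on newer PyTorch versions (and is nonzero
        # only for nn.LSTM); hasattr() keeps this backward compatible.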
        has_proj = hasattr(module, 'proj_size') and module.proj_size > 0

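        # Gates per cell: 1 for nn.RNN, 3 for nn.GRU, 4 for nn.LSTM.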
        gs = rnn_gate_size(module)

        weights = []

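        # Projected LSTM: score the per-layer projection matrices (weight_hr)
        # under the original name, then collect the ih/hh weights under '{name}:h'.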
        if has_proj:
            for i in range(module.num_layers):
                weight_hrs = []

                for j in range(num_directions):
                    suffix = '_reverse' if j > 0 else ''
                    weight_hr = getattr(module, f'weight_hr_l{i}{suffix}')
                    weight_hrs.append(weight_hr)

                weights.append(torch.cat(weight_hrs, dim=0))

            importance[name] = metric_func(weights, module)

            weights.clear()
            name = f'{name}:h'

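        # Collect input-hidden and hidden-hidden weights across all layers.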
        for i in range(module.num_layers):
            weight_ihs = []
            weight_hhs = []

            for j in range(num_directions):
                suffix = '_reverse' if j > 0 else ''
                weight_ih = getattr(module, f'weight_ih_l{i}{suffix}')
                weight_hh = getattr(module, f'weight_hh_l{i}{suffix}')

                weight_ihs.append(weight_ih)
                weight_hhs.append(weight_hh)

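            # A single-gate RNN concatenates directions directly; gated cells
            # are split into per-gate blocks first, so that matching gates from
            # both directions are concatenated together.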
            if gs == 1:
                weights.append(torch.cat(weight_ihs, dim=0))
                weights.append(torch.cat(weight_hhs, dim=0))
            else:
                w_ih_splits = zip(*[torch.unbind(x.view(gs, module.hidden_size, -1)) for x in weight_ihs])
                w_hh_splits = zip(*[torch.unbind(x.view(gs, module.hidden_size, -1)) for x in weight_hhs])

                ih_gate_weights = [torch.cat(x) for x in w_ih_splits]
                hh_gate_weights = [torch.cat(x) for x in w_hh_splits]

                weights.extend(ih_gate_weights)
                weights.extend(hh_gate_weights)

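        # Score all collected weight groups with a single metric call.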
            importance[name] = metric_func(weights, module)
    else:
        raise AttributeError(f'{type(module).__name__}({name}) is not supported for importance calculation')
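
Minimal usage sketch (l2_metric is hypothetical, standing in for whatever
scoring function the caller actually supplies as metric_func):

def l2_metric(weights, module):
    # Mean L2 norm over the given weight tensor(s); `weights` is a single
    # tensor for linear/conv layers and a list of tensors for RNNs.
    if isinstance(weights, torch.Tensor):
        weights = [weights]
    return sum(w.norm(p=2) for w in weights) / len(weights)

importance = {}
update_weight_metric(importance, l2_metric, nn.Linear(8, 4), 'fc_0')
update_weight_metric(importance, l2_metric, nn.LSTM(8, 4, num_layers=2, bidirectional=True), 'lstm_0')
print(importance)  # {'fc_0': tensor(...), 'lstm_0': tensor(...)}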