def split_batchnorm_params()

in classy_vision/generic/util.py [0:0]


def split_batchnorm_params(model: nn.Module):
    """Split the trainable parameters of a model into BatchNorm and other params.

    Every module yielded by ``model.modules()`` (including ``model`` itself)
    is visited. Each module's *direct* parameters (``recurse=False``) are
    attributed to that module and classified by its type:

    * parameters owned by an ``nn.modules.batchnorm._BatchNorm`` instance go
      into the first list;
    * all other parameters go into the second list.

    Parameters with ``requires_grad=False`` are skipped. A parameter that is
    registered on several modules (e.g. tied weights) is reported only once,
    under the first owning module in ``model.modules()`` order. As a result,
    no parameter appears twice within or across the two lists.

    Args:
        model: The module whose parameters are to be split.

    Returns:
        A tuple ``(batchnorm_params, other_params)`` of lists of parameters,
        each ordered by ``model.modules()`` traversal order.
    """
    batchnorm_params = []
    other_params = []
    # ``modules()`` dedups modules but not parameters shared between distinct
    # modules, so track parameter identity explicitly. ``id`` is used instead
    # of the tensor itself because tensor ``==`` is elementwise.
    seen_ids = set()
    for module in model.modules():
        # Only direct parameters are considered here; parameters of nested
        # submodules are classified when ``modules()`` reaches those modules.
        target = (
            batchnorm_params
            if isinstance(module, nn.modules.batchnorm._BatchNorm)
            else other_params
        )
        for param in module.parameters(recurse=False):
            if not param.requires_grad or id(param) in seen_ids:
                continue
            seen_ids.add(id(param))
            target.append(param)
    return batchnorm_params, other_params