in classy_vision/generic/util.py [0:0]
def split_batchnorm_params(model: nn.Module):
    """Split the trainable parameters of ``model`` into BatchNorm and other params.

    Walks every submodule yielded by ``model.modules()``. Parameters owned by
    a leaf module that is an instance of ``nn.modules.batchnorm._BatchNorm``
    go in the first list. All other parameters go in the second list,
    including parameters registered directly on modules that have children.
    Parameters with ``requires_grad=False`` are skipped. A parameter shared
    by several modules (e.g. tied weights) is returned only once, in the list
    chosen for the first module that owns it. This keeps the two lists
    disjoint and duplicate-free, so they can be used as separate optimizer
    param groups.

    Args:
        model: The module whose parameters are split.

    Returns:
        A ``(batchnorm_params, other_params)`` tuple of lists. Each list is
        in the order the parameters are encountered while traversing
        ``model.modules()``.
    """
    batchnorm_params = []
    other_params = []
    # Identities of parameters already assigned to a list. Tensors are
    # compared by id because `in` on a list of tensors would compare values.
    seen = set()
    for module in model.modules():
        # Only parameters registered directly on this module are examined.
        # Children's parameters are handled when model.modules() reaches
        # those children, so recursing here would count them twice.
        is_leaf = next(module.children(), None) is None
        if is_leaf and isinstance(module, nn.modules.batchnorm._BatchNorm):
            target = batchnorm_params
        else:
            target = other_params
        for param in module.parameters(recurse=False):
            # Frozen params are excluded. A shared param is kept only once,
            # because torch optimizers reject a parameter that appears in
            # more than one param group.
            if not param.requires_grad or id(param) in seen:
                continue
            seen.add(id(param))
            target.append(param)
    return batchnorm_params, other_params