in models/base_ssl3d_model.py [0:0]
def get_optimizer_params(self):
    """Split the model's parameters into regularized / unregularized groups.

    Every parameter owned directly by a module in ``self.modules()`` is
    assigned to at most one group:

    * ``nn.Linear`` / ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d``: the
      weight is always regularized; the bias is regularized only when
      ``self.optimizer_config["regularize_bias"]`` is truthy.
    * Batch-norm layers (``nn.BatchNorm{1,2,3}d``, ``nn.SyncBatchNorm``,
      ``apex.parallel.SyncBatchNorm``): the weight is regularized only when
      ``regularize_bn`` is truthy; the bias only when both ``regularize_bn``
      and ``regularize_bias`` are truthy.
    * Any other module: the parameters it owns directly (``recurse=False``)
      are regularized.

    Parameters whose names appear in ``cfg.MODEL.NON_TRAINABLE_PARAMS`` get
    ``requires_grad = False`` set on the model as a side effect. Every
    parameter with ``requires_grad == False`` is then dropped from both groups.

    Returns:
        dict: ``{"regularized_params": list, "unregularized_params": list}``
        of parameters, each parameter appearing at most once overall.
    """
    regularized_params, unregularized_params = [], []
    # ids of parameters already assigned to a group. A tied/shared parameter
    # reachable from several modules is added only once, so the optimizer is
    # never handed the same tensor twice.
    seen_param_ids = set()

    def _assign(param, regularize):
        """Append ``param`` to one group unless it is None or already placed."""
        if param is None or id(param) in seen_param_ids:
            return
        seen_param_ids.add(id(param))
        if regularize:
            regularized_params.append(param)
        else:
            unregularized_params.append(param)

    linear_conv_types = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
    bn_types = (
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.SyncBatchNorm,
        apex.parallel.SyncBatchNorm,
    )
    for module in self.modules():
        if isinstance(module, linear_conv_types):
            # NOTE(review): only ``weight``/``bias`` are assigned for these
            # layer types; any extra parameter a subclass registers ends up
            # in neither group.
            _assign(module.weight, True)
            if module.bias is not None:
                _assign(module.bias, self.optimizer_config["regularize_bias"])
        elif isinstance(module, bn_types):
            if module.weight is not None:
                _assign(module.weight, self.optimizer_config["regularize_bn"])
            if module.bias is not None:
                _assign(
                    module.bias,
                    self.optimizer_config["regularize_bn"]
                    and self.optimizer_config["regularize_bias"],
                )
        else:
            # Any other module, leaf or container: regularize only the
            # parameters it owns directly. recurse=False leaves each child's
            # parameters to that child's own iteration of this loop.
            for param in module.parameters(recurse=False):
                _assign(param, True)

    # NOTE(review): ``cfg`` is not defined in this method (the switches above
    # come from ``self.optimizer_config``); confirm it is in scope here.
    non_trainable_names = cfg.MODEL.NON_TRAINABLE_PARAMS
    non_trainable_params = []
    for name, param in self.named_parameters():
        if name in non_trainable_names:
            # Side effect: freezes the parameter on the model itself.
            param.requires_grad = False
            non_trainable_params.append(param)

    # Drop every frozen parameter, including any frozen before this call.
    trainable_params = [p for p in self.parameters() if p.requires_grad]
    regularized_params = [p for p in regularized_params if p.requires_grad]
    unregularized_params = [p for p in unregularized_params if p.requires_grad]
    logging.info("Trainable params: %d", len(trainable_params))
    logging.info("Non-Trainable params: %d", len(non_trainable_params))
    logging.info(
        "Regularized Parameters: %d. Unregularized Parameters: %d",
        len(regularized_params),
        len(unregularized_params),
    )
    return {
        "regularized_params": regularized_params,
        "unregularized_params": unregularized_params,
    }