# get_optimizer_params() — excerpt from models/base_ssl3d_model.py

    def get_optimizer_params(self):
        """Split model parameters into regularized (weight-decayed) and
        unregularized groups for the optimizer.

        Grouping rules (driven by ``self.optimizer_config``):
          * Linear/Conv weights are always regularized; their biases follow
            the ``"regularize_bias"`` flag.
          * BatchNorm weights follow ``"regularize_bn"``; BatchNorm biases are
            regularized only when both ``"regularize_bn"`` and
            ``"regularize_bias"`` are set.
          * Direct parameters of any other module are regularized.

        Parameters named in ``cfg.MODEL.NON_TRAINABLE_PARAMS`` are frozen
        (``requires_grad = False``) and dropped from both groups.

        Returns:
            dict: ``{"regularized_params": [...], "unregularized_params": [...]}``
            containing only trainable parameters.
        """
        regularized_params, unregularized_params = [], []
        conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
        bn_types = (
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.SyncBatchNorm,
            apex.parallel.SyncBatchNorm,
        )
        for module in self.modules():
            if isinstance(module, (nn.Linear,) + conv_types):
                regularized_params.append(module.weight)
                if module.bias is not None:
                    if self.optimizer_config["regularize_bias"]:
                        regularized_params.append(module.bias)
                    else:
                        unregularized_params.append(module.bias)
            elif isinstance(module, bn_types):
                if module.weight is not None:
                    if self.optimizer_config["regularize_bn"]:
                        regularized_params.append(module.weight)
                    else:
                        unregularized_params.append(module.weight)
                if module.bias is not None:
                    # BN bias is regularized only when BOTH flags are on.
                    if (
                        self.optimizer_config["regularize_bn"]
                        and self.optimizer_config["regularize_bias"]
                    ):
                        regularized_params.append(module.bias)
                    else:
                        unregularized_params.append(module.bias)
            else:
                # Any other module: regularize its *direct* parameters only
                # (recurse=False — children are visited by their own turn of
                # self.modules()). This covers both leaf modules and non-leaf
                # modules that own parameters directly. The original guard
                # `len(list(module.children())) >= 0` was always true, so a
                # plain `else` is equivalent and honest.
                regularized_params.extend(module.parameters(recurse=False))

        # Freeze parameters explicitly listed as non-trainable in the config.
        non_trainable_params = []
        for name, param in self.named_parameters():
            if name in cfg.MODEL.NON_TRAINABLE_PARAMS:
                param.requires_grad = False
                non_trainable_params.append(param)

        trainable_params = [
            params for params in self.parameters() if params.requires_grad
        ]
        # Drop any frozen parameters from the optimizer groups.
        regularized_params = [
            params for params in regularized_params if params.requires_grad
        ]
        unregularized_params = [
            params for params in unregularized_params if params.requires_grad
        ]
        # Lazy %-style args avoid formatting cost when INFO is disabled;
        # also fixes the "Traininable" typo in the original messages.
        logging.info("Trainable params: %d", len(trainable_params))
        logging.info("Non-Trainable params: %d", len(non_trainable_params))
        logging.info(
            "Regularized Parameters: %d. Unregularized Parameters %d",
            len(regularized_params),
            len(unregularized_params),
        )
        return {
            "regularized_params": regularized_params,
            "unregularized_params": unregularized_params,
        }