in training/utils/checkpoint_utils.py
from typing import List

import torch.nn as nn


def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]):
"""
Verifies that all the parameters matching the provided patterns
are frozen - this acts as a safeguard when ignoring parameter
when saving checkpoints - if the parameters are in fact trainable
"""
    if not patterns:
        return

    # Collect the subset of the state dict whose keys match any of the
    # provided Unix-style patterns.
    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    # Any matched parameter that still requires gradients would be trained
    # but never saved, so flag it.
    non_frozen_keys = {
        n
        for n, p in model.named_parameters()
        if n in frozen_state_dict and p.requires_grad
    }
    if non_frozen_keys:
        raise ValueError(
            f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}"
        )
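

# The helper called above is defined elsewhere in this module. The sketch
# below is a minimal reconstruction inferred from the call site alone
# (fnmatch-style matching over state-dict keys is an assumption, hence the
# `_sketch` suffix to avoid shadowing the real helper):
from fnmatch import fnmatch
from typing import Dict

import torch


def filter_params_matching_unix_pattern_sketch(
    patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    # Keep entries whose name matches at least one Unix-style pattern,
    # e.g. "backbone.*" matches "backbone.layer1.weight".
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if any(fnmatch(name, pattern) for pattern in patterns)
    }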
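

# Usage sketch (the model and the "0.*" pattern are illustrative
# assumptions, not taken from this file): freeze the parameters you plan
# to exclude from checkpoints, then run the assertion before saving.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    skip_saving_parameters = ["0.*"]  # skip the first layer when saving

    # Freeze the matching parameters; otherwise the check below raises.
    for name, param in model.named_parameters():
        if name.startswith("0."):
            param.requires_grad = False

    assert_skipped_parameters_are_frozen(model, skip_saving_parameters)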