def checkpoint_filter_fn()

in timm/models/coat.py [0:0]


def checkpoint_filter_fn(state_dict, model):
    """Remove pretrained-checkpoint entries that ``model`` has no place for.

    The original model had unused norm layers, so loading its pretrained
    weights requires filtering the checkpoint first.

    Args:
        state_dict: Checkpoint mapping. If it contains a top-level ``'model'``
            key, the mapping stored there is filtered instead.
        model: Target model. Its ``norm2``/``norm3``/``norm4``/``aggregate``/
            ``head`` attributes are inspected; a missing or ``None`` attribute
            causes every checkpoint key with that prefix to be dropped.

    Returns:
        A new dict holding the surviving entries in their original order.
        Keys starting with ``'norm1'`` are always removed.
    """
    weights = state_dict.get('model', state_dict)

    # Sub-modules the model may or may not define; when absent (or None),
    # their checkpoint weights have nowhere to go and are discarded.
    optional_prefixes = ('norm2', 'norm3', 'norm4', 'aggregate', 'head')
    absent = tuple(name for name in optional_prefixes if getattr(model, name, None) is None)

    # 'norm1' is never used by the model, so it is filtered unconditionally.
    dropped_prefixes = ('norm1',) + absent

    return {
        key: tensor
        for key, tensor in weights.items()
        if not key.startswith(dropped_prefixes)
    }