def freeze_model_layers()

in distilvit/train.py [0:0]


def freeze_model_layers(model, freeze_encoder_layers=3, freeze_decoder_layers=3):
    """Disable gradients for the leading encoder and decoder layers of *model*.

    Layers are read from ``model.encoder.encoder.layer`` and
    ``model.decoder.transformer.h``. In each sequence, every parameter of a
    layer whose index is below the corresponding count has its
    ``requires_grad`` flag set to ``False``; later layers are left untouched.

    Args:
        model: Object exposing the two layer sequences named above, where
            each layer provides a ``parameters()`` method.
        freeze_encoder_layers: Number of leading encoder layers to freeze.
        freeze_decoder_layers: Number of leading decoder layers to freeze.

    Returns:
        None. The parameters are modified in place.
    """

    def _lock_leading(stack, limit):
        # Mark every parameter of the first `limit` entries as non-trainable.
        for position, block in enumerate(stack):
            if not (position < limit):
                continue
            for weight in block.parameters():
                weight.requires_grad = False

    # The encoder is processed completely before the decoder is looked up.
    _lock_leading(model.encoder.encoder.layer, freeze_encoder_layers)
    _lock_leading(model.decoder.transformer.h, freeze_decoder_layers)