in distilvit/train.py [0:0]
def freeze_model_layers(model, freeze_encoder_layers=3, freeze_decoder_layers=3):
    """Turn off gradient tracking for the leading layers of both model stacks.

    Every parameter of a layer whose position in its stack is less than the
    matching threshold gets ``requires_grad = False``. Layers at or beyond the
    threshold, and any modules outside the two stacks, are left unchanged.
    The encoder stack is processed before the decoder stack. The function
    mutates ``model`` in place and returns ``None``.

    NOTE(review): the attribute paths (``encoder.encoder.layer`` and
    ``decoder.transformer.h``) suggest a ViT-style encoder paired with a
    GPT-2-style decoder. This is not established by this function; confirm
    against the model construction code.

    Args:
        model: Object exposing ``encoder.encoder.layer`` and
            ``decoder.transformer.h``. Both are iterables of layers that
            provide ``parameters()``.
        freeze_encoder_layers: Number of leading encoder layers to freeze.
        freeze_decoder_layers: Number of leading decoder layers to freeze.
    """

    def _freeze_leading(blocks, limit):
        # Walk the whole stack and disable gradients only for blocks that
        # sit before `limit`; later blocks keep their current setting.
        for position, block in enumerate(blocks):
            if position < limit:
                for weight in block.parameters():
                    weight.requires_grad = False

    # Each stack is looked up only when its turn comes, so the encoder is
    # processed before the decoder attribute path is touched.
    _freeze_leading(model.encoder.encoder.layer, freeze_encoder_layers)
    _freeze_leading(model.decoder.transformer.h, freeze_decoder_layers)