modules/SwissArmyTransformer/sat/model/official/distill_model.py (29 lines of code) (raw):

import torch.nn as nn


class DistillModel(nn.Module):
    """Wraps a frozen teacher and a trainable student for knowledge distillation."""

    def __init__(self, teacher, student):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def forward(self, teacher_kwargs, student_kwargs):
        # Run both models; each returns (logits, *mems) and the memory outputs are discarded.
        teacher_logits, *mem_t = self.teacher(**teacher_kwargs)
        student_logits, *mem_s = self.student(**student_kwargs)
        return teacher_logits, student_logits

    def disable_untrainable_params(self):
        # Freeze the teacher so that only the student is updated during training.
        for n, p in self.teacher.named_parameters():
            p.requires_grad_(False)

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('BERT-distill', 'BERT distill Configurations')
        group.add_argument('--teacher', type=str)
        group.add_argument('--tc-type', type=str)
        group.add_argument('--st-type', type=str)
        return parser

    @classmethod
    def from_pretrained(cls, args, teacher_cls, student_name, student_cls):
        # Load the student checkpoint; its weights live under the 'student.' prefix.
        student, args = student_cls.from_pretrained(student_name, args, prefix='student.')
        # teacher_cls is either a model class (load from args.teacher) or an already-built model.
        if isinstance(teacher_cls, type):
            teacher, t_args = teacher_cls.from_pretrained(args.teacher, args)
        else:
            teacher = teacher_cls
        model = DistillModel(teacher, student)
        return model, args
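For illustration, below is a minimal sketch of how this wrapper might be driven in a single training step. Everything other than DistillModel itself is an assumption made for the example: the TinyLM stand-in models, the input shapes, and the temperature-scaled KL distillation loss. The file above only returns the two logits tensors, so the loss shown here is one common choice by the caller, not something defined in this module.

# Usage sketch (hypothetical; only DistillModel comes from the file above).
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLM(nn.Module):
    # Stand-in for a SAT model: returns a (logits, *mems)-style tuple like the real ones.
    def __init__(self, hidden, vocab):
        super().__init__()
        self.proj = nn.Linear(hidden, vocab)

    def forward(self, hidden_states):
        return (self.proj(hidden_states),)  # 1-tuple: logits only, no mems


teacher, student = TinyLM(32, 100), TinyLM(32, 100)
model = DistillModel(teacher, student)
model.disable_untrainable_params()  # freeze the teacher

x = torch.randn(4, 16, 32)  # (batch, seq, hidden) dummy input
teacher_logits, student_logits = model(
    teacher_kwargs={'hidden_states': x},
    student_kwargs={'hidden_states': x},
)

# One common distillation objective (an assumption, not part of this file):
# temperature-scaled KL divergence between teacher and student distributions.
T = 2.0  # softening temperature
kd_loss = F.kl_div(
    F.log_softmax(student_logits / T, dim=-1),
    F.softmax(teacher_logits.detach() / T, dim=-1),
    reduction='batchmean',
) * (T * T)
kd_loss.backward()  # gradients flow into the student only; the teacher is frozen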