in src/diarizers/models/model.py [0:0]
def from_pyannote_model(cls, pretrained):
    """Build an instance of ``cls`` from a pre-trained pyannote segmentation model.

    The configuration is derived from ``pretrained.specifications``. Then
    ``pretrained.hparams`` and the ``sincnet``, ``lstm``, ``linear``,
    ``classifier`` and ``activation`` sub-modules are deep-copied onto
    ``model.model``. The returned model therefore shares no objects (and no
    tensors) with ``pretrained``.

    Args:
        pretrained (pyannote.audio.Model): pretrained pyannote segmentation
            model. It must expose ``specifications``, ``hparams`` and the five
            sub-modules listed above.

    Returns:
        An instance of ``cls`` carrying a copy of ``pretrained``'s
        architecture and weights.
    """
    # Work on a private copy so nothing below can touch the pretrained
    # model's specifications object.
    specifications = copy.deepcopy(pretrained.specifications)

    config = SegmentationModelConfig(
        chunk_duration=specifications.duration,
        max_speakers_per_frame=specifications.powerset_max_classes,
        # Not derived from the pretrained model; always disabled here.
        weigh_by_cardinality=False,
        min_duration=specifications.min_duration,
        warm_up=specifications.warm_up,
        max_speakers_per_chunk=len(specifications.classes),
    )
    model = cls(config)

    # Copy pretrained hyperparameters and weights. ``copy.deepcopy`` already
    # duplicates every parameter and buffer of a module, so a follow-up
    # ``load_state_dict`` with the same weights would only copy them twice.
    model.model.hparams = copy.deepcopy(pretrained.hparams)
    for name in ("sincnet", "lstm", "linear", "classifier", "activation"):
        setattr(model.model, name, copy.deepcopy(getattr(pretrained, name)))

    return model