train_segmentation.py:
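"""Fine-tune a pretrained pyannote segmentation model on a speaker-diarization dataset with the Hugging Face Trainer.

Example invocation (the dataset, checkpoint and hyper-parameters below are illustrative only;
substitute values that match your own setup):

    python train_segmentation.py \
        --dataset_name=diarizers-community/callhome \
        --dataset_config_name=jpn \
        --split_on_subset=data \
        --model_name_or_path=pyannote/segmentation-3.0 \
        --output_dir=./segmentation-fine-tuned \
        --do_train \
        --do_eval \
        --num_train_epochs=5 \
        --per_device_train_batch_size=32 \
        --push_to_hub
"""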
import os
from dataclasses import dataclass, field
from typing import Optional

from datasets import DatasetDict, load_dataset
from pyannote.audio import Model
from transformers import HfArgumentParser, Trainer, TrainingArguments

from diarizers import DataCollator, Metrics, Preprocess, SegmentationModel


@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
dataset_config_name: str = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_split_name: str = field(
default="train", metadata={"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"}
)
    eval_split_name: str = field(
        default="validation", metadata={"help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"}
    )
    split_on_subset: Optional[str] = field(
        default=None,
        metadata={"help": "If set, automatically split this subset of the dataset into train-validation-test splits. Defaults to None."},
    )
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)


@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)


if __name__ == "__main__":
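    # Restrict training to the first visible GPU; edit or remove this for a different device setup.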
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
parser = HfArgumentParser((DataTrainingArguments, ModelArguments, TrainingArguments))
data_args, model_args, training_args = parser.parse_args_into_dataclasses()
# Load the Dataset:
    if data_args.dataset_config_name:
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            num_proc=data_args.preprocessing_num_workers,
        )
    else:
        dataset = load_dataset(
            data_args.dataset_name,
            num_proc=data_args.preprocessing_num_workers,
        )
train_split_name = data_args.train_split_name
val_split_name = data_args.eval_split_name
# Split in Train-Val-Test:
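    # 80% train / 10% validation / 10% test, split with a fixed seed so the result is reproducible.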
if data_args.split_on_subset:
train_testvalid = dataset[str(data_args.split_on_subset)].train_test_split(test_size=0.2, seed=0)
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=0)
dataset = DatasetDict({
'train': train_testvalid['train'],
'validation': test_valid['test'],
'test': test_valid['train']}
)
train_split_name = 'train'
val_split_name = 'validation'
# Load the Pretrained Pyannote Segmentation Model:
pretrained = Model.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_auth_token=True
)
model = SegmentationModel.from_pyannote_model(pretrained)
# Load the preprocessor:
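    # Preprocess turns each annotated audio file into fixed-length chunks with frame-level
    # speaker-activity labels matching the segmentation model's configuration.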
preprocessor = Preprocess(model.config)
# Preprocess:
if training_args.do_train:
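        # overlap=0.5 produces overlapping training chunks (more examples per file);
        # the evaluation set below uses overlap=0.0, i.e. non-overlapping chunks.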
train_set = dataset[str(train_split_name)].map(
lambda file: preprocessor(file, random=False, overlap=0.5),
num_proc=data_args.preprocessing_num_workers,
remove_columns=next(iter(dataset.values())).column_names,
batched=True,
batch_size=1
).shuffle().with_format("torch")
if training_args.do_eval:
val_set = dataset[str(val_split_name)].map(
lambda file: preprocessor(file, random=False, overlap=0.0),
num_proc=data_args.preprocessing_num_workers,
remove_columns=next(iter(dataset.values())).column_names,
batched=True,
keep_in_memory=True,
batch_size=1
).with_format('torch')
# Load metrics:
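    # Computes segmentation quality metrics (diarization error rate components) from model predictions at evaluation time.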
metrics = Metrics(model.specifications)
# Define the Trainer:
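    # The data collator batches the preprocessed chunks and pads the speaker-label dimension up to max_speakers_per_chunk.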
trainer = Trainer(
model=model,
args=training_args,
        train_dataset=train_set if training_args.do_train else None,
data_collator=DataCollator(max_speakers_per_chunk=model.config.max_speakers_per_chunk),
        eval_dataset=val_set if training_args.do_eval else None,
compute_metrics=metrics,
)
    # Evaluate the pretrained checkpoint first (if requested), then fine-tune:
if training_args.do_eval:
first_eval = trainer.evaluate()
print("Initial metric values: ", first_eval)
if training_args.do_train:
trainer.train()
    # Build the model card metadata:
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "speaker diarization"}
if data_args.dataset_name is not None:
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
kwargs["dataset"] = data_args.dataset_name
kwargs['tags'] = ['speaker-diarization', 'speaker-segmentation']
# Push to Hub
if training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(**kwargs)