in training/run_pseudo_labelling.py [0:0]
def main():
# Entry point of the pseudo-labelling script: parses the arguments, loads the dataset and the
# Whisper checkpoint, generates a transcription for every sample of the requested splits and
# saves the results.
# 1. Parse input arguments
# We keep distinct sets of args, for cleaner separation of model/data/training related args
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
# Otherwise fall back to the parser's regular argument parsing.
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# 2. Initialize the accelerator
# We will let the accelerator handle device placement for us in this example
# We simply have to specify the training precision and any trackers being used
# We'll use the same dtype arguments as our JAX/Flax training script and convert
# it to accelerate format
if model_args.dtype == "float16":
mixed_precision = "fp16"
torch_dtype = torch.float16
elif model_args.dtype == "bfloat16":
mixed_precision = "bf16"
torch_dtype = torch.bfloat16
else:
# Any other dtype value falls back to full precision
mixed_precision = "no"
torch_dtype = torch.float32
# Process-group timeout of 7200 seconds (2 hours)
# NOTE(review): presumably raised so that long single-process dataset preparation does not
# time out the waiting ranks - confirm
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))
accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
mixed_precision=mixed_precision,
log_with=training_args.report_to,
project_dir=training_args.output_dir,
kwargs_handlers=[kwargs],
)
accelerator.init_trackers(project_name=data_args.wandb_project)
# 3. Set-up basic logging
# Create one log on every process with the configuration for debugging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Log a small summary on each process
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity of the Transformers logger to info and of the Datasets logger to warning
# on the local main process; every other process only reports errors
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info("Training/evaluation parameters %s", training_args)
# 4. Load dataset
# Use the streaming (iterable) container when `--streaming` is set, the regular one otherwise
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
# Prefer the token passed on the command line, otherwise use the one returned by HfFolder
token = model_args.token if model_args.token is not None else HfFolder().get_token()
# Several splits can be requested at once by joining their names with "+"
data_splits = data_args.dataset_split_name.split("+")
for split in data_splits:
with accelerator.main_process_first():
raw_datasets[split] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=split,
cache_dir=data_args.dataset_cache_dir,
token=token,
streaming=data_args.streaming,
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
)
# The audio/text column checks below only inspect the first loaded split
if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
raise ValueError(
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset"
f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
" the correct audio column - one of"
f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
)
if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
raise ValueError(
f"--text_column_name {data_args.text_column_name} not found in dataset"
f" '{data_args.dataset_name}'. Make sure to set `--text_column_name` to the"
" correct text column - one of"
f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
)
# 5. Load pretrained model, tokenizer, and feature extractor
# Each component is loaded from its dedicated `*_name` argument when one is given and from
# `model_name_or_path` otherwise
config = WhisperConfig.from_pretrained(
(model_args.config_name if model_args.config_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=token,
)
feature_extractor = WhisperFeatureExtractor.from_pretrained(
(model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=token,
)
tokenizer = WhisperTokenizerFast.from_pretrained(
(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
token=token,
)
processor = WhisperProcessor.from_pretrained(
(model_args.processor_name if model_args.processor_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=token,
)
# The weights are loaded directly in the dtype selected from `model_args.dtype`
model = WhisperForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
subfolder=model_args.subfolder,
token=token,
low_cpu_mem_usage=True,
torch_dtype=torch_dtype,
attn_implementation=model_args.attn_implementation,
)
# The model is only used for generation in this script, so put it in eval mode
model.eval()
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
return_timestamps = data_args.return_timestamps
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
is_multilingual = True
# We need to set the language and task ids for multilingual checkpoints
tokenizer.set_prefix_tokens(
language=data_args.language, task=data_args.task, predict_timestamps=return_timestamps
)
elif data_args.language is not None:
raise ValueError(
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
"only be set for multilingual checkpoints."
)
else:
# English-only checkpoint without a language argument
is_multilingual = False
# 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
# so we just need to set the correct target sampling rate.
raw_datasets = raw_datasets.cast_column(
data_args.audio_column_name,
datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
)
# 7. Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
# Maximum audio length expressed in samples (seconds * sampling rate)
max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
max_label_length = (
data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
)
# Copy the arguments into local variables: they are read by the nested helper functions below
audio_column_name = data_args.audio_column_name
sampling_rate = feature_extractor.sampling_rate
preprocessing_batch_size = data_args.preprocessing_batch_size
num_workers = data_args.preprocessing_num_workers
dataloader_num_workers = training_args.dataloader_num_workers
text_column_name = data_args.text_column_name
model_input_name = feature_extractor.model_input_names[0]
id_column_name = data_args.id_column_name
speaker_id_column_name = data_args.speaker_id_column_name
# The basic normalizer is used whenever a language is set, the English one otherwise
normalizer = (
BasicTextNormalizer()
if data_args.language is not None
else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
)
# Number of leading tokens dropped from a previous prediction before it is turned into a
# `<|startofprev|>` prompt (see `add_concatenated_text`)
# NOTE(review): presumably the start/language/task prefix tokens - confirm
timestamp_position = 3 if is_multilingual else 1
decoder_prev_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
decoder_eot_token_id = tokenizer.eos_token_id
# Optionally truncate every split to its first `max_samples_per_split` samples
if data_args.max_samples_per_split is not None:
for split in data_splits:
raw_datasets[split] = (
raw_datasets[split].take(data_args.max_samples_per_split)
if data_args.streaming
else raw_datasets[split].select(range(data_args.max_samples_per_split))
)
# Sort by speaker so that samples of the same speaker are adjacent; `concatenate_dataset`
# only merges consecutive samples that share a speaker id
if speaker_id_column_name is not None:
raw_datasets = raw_datasets.sort(speaker_id_column_name)
def concatenate_dataset(batch):
    """Merge consecutive samples of a batch into longer audio/text segments.

    Samples are read one row at a time; a row whose access raises `LibsndfileError`
    is logged and dropped. Each remaining sample is appended to the current segment
    when it has the same speaker id and the combined audio stays within
    `max_input_length` samples; otherwise it starts a new segment.

    Writes back into `batch`:
        * the audio column: one `{"array", "sampling_rate"}` dict per segment,
        * the text column: the segment transcripts joined with single spaces,
        * the id column: the speaker id of each segment,
        * `condition_on_prev`: 1 when a segment continues the speaker of the
          previous segment, 0 otherwise (always 0 for the first segment).

    NOTE(review): if every sample of the batch is skipped, seeding the segments
    below raises `IndexError`.
    """
    waveforms, transcripts, speakers = [], [], []

    # skip corrupted samples
    for raw_row in table_iter(batch.pa_table, batch_size=1):
        row = batch.formatter.format_row(raw_row)
        try:
            waveform = row[audio_column_name]["array"]
            transcript = row[text_column_name]
            speaker = row[speaker_id_column_name] if speaker_id_column_name else None
        except LibsndfileError:
            logger.warning(f"{row[id_column_name]} is corrupted! Skipping sample.")
            continue
        waveforms.append(waveform)
        transcripts.append(transcript)
        speakers.append(speaker)

    # seed every output column with the first readable sample
    merged_audio = [waveforms[0]]
    merged_text = [transcripts[0]]
    merged_speakers = [speakers[0]]
    prev_flags = [0]

    for waveform, transcript, speaker in zip(waveforms[1:], transcripts[1:], speakers[1:]):
        same_speaker = speaker == merged_speakers[-1]
        fits = len(waveform) + len(merged_audio[-1]) <= max_input_length
        if not (same_speaker and fits):
            # start a new segment and remember whether it continues the previous speaker
            merged_audio.append(waveform)
            merged_text.append(transcript)
            merged_speakers.append(speaker)
            prev_flags.append(1 if same_speaker else 0)
            continue
        # extend the current segment in place
        merged_audio[-1] = np.append(merged_audio[-1], waveform)
        merged_text[-1] = merged_text[-1] + " " + transcript

    batch[audio_column_name] = [{"array": segment, "sampling_rate": sampling_rate} for segment in merged_audio]
    batch[text_column_name] = merged_text
    batch[id_column_name] = merged_speakers
    batch["condition_on_prev"] = prev_flags
    return batch
# Column names of the first split, captured before the concatenation step rewrites the dataset
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())

if data_args.concatenate_audio:
    if data_args.streaming:
        raise ValueError(
            "Streaming mode is not yet compatible with concatenating audios to `max_duration_in_seconds`."
            "Either set `--streaming=False` and download the audios locally, or open an issue on the Distil-Whisper repo to request this feature."
        )

    # Keep only the columns produced by `concatenate_dataset`
    columns_to_drop = set(raw_datasets_features) - {
        audio_column_name,
        text_column_name,
        id_column_name,
        "condition_on_prev",
    }
    with accelerator.main_process_first():
        raw_datasets = raw_datasets.map(
            concatenate_dataset,
            batched=True,
            batch_size=preprocessing_batch_size,
            num_proc=num_workers,
            remove_columns=columns_to_drop,
            desc="Concatenating dataset...",
        )

    # Re-declare the audio column as an Audio feature at the target sampling rate
    audio_feature = datasets.features.Audio(sampling_rate=sampling_rate)
    raw_datasets = raw_datasets.cast_column(audio_column_name, audio_feature)

    pretty_name = data_args.dataset_name.split("/")[-1]

    def postprocess_ids(speaker_ids, indices):
        """Format one id per sample from the dataset name, the speaker id (when set) and the row index."""
        return {
            id_column_name: [
                f"{pretty_name}-{speaker}-{idx}" if speaker is not None else f"{pretty_name}-{idx}"
                for speaker, idx in zip(speaker_ids, indices)
            ]
        }

    with accelerator.main_process_first():
        raw_datasets = raw_datasets.map(
            postprocess_ids,
            input_columns=[id_column_name],
            with_indices=True,
            desc="Setting sample idxs...",
            batched=True,
            batch_size=preprocessing_batch_size,
            num_proc=num_workers,
        )
def prepare_dataset(batch):
    """Turn one raw sample into model inputs and tokenized labels.

    The audio array is passed through the feature extractor and the first returned
    element is stored under `model_input_name`; the transcript is tokenized (truncated
    to `max_label_length`) and its token ids are stored under `"labels"`.
    """
    # audio -> input features
    audio = batch[audio_column_name]
    features = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
    batch[model_input_name] = features.get(model_input_name)[0]

    # transcript -> label ids
    tokenized = tokenizer(batch[text_column_name], max_length=max_label_length, truncation=True)
    batch["labels"] = tokenized.input_ids
    return batch
# Column names of the (possibly concatenated) dataset; all of them are dropped once the
# model inputs and labels have been computed
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
# Keep the sample ids aside, since the id column is removed from the vectorized datasets
file_ids_dataset = IterableDatasetDict() if data_args.streaming else DatasetDict()
for split in raw_datasets:
file_ids_dataset[split] = raw_datasets[split][id_column_name]
if data_args.streaming:
with accelerator.main_process_first():
vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=raw_datasets_features)
else:
with accelerator.main_process_first():
vectorized_datasets = raw_datasets.map(
prepare_dataset,
remove_columns=raw_datasets_features,
num_proc=num_workers,
desc="preprocess dataset",
)
# for large datasets it is advised to run the preprocessing on a
# single machine first with `args.preprocessing_only` since there will most likely
# be a timeout when running the script in distributed mode.
# In a second step `args.preprocessing_only` can then be set to `False` to load the
# cached dataset
if data_args.preprocessing_only:
cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
logger.info(f"Data preprocessing finished. Files cached at {cache}.")
return
if data_args.streaming and dataloader_num_workers > 0:
logger.warning(
"Using multiple dataloader num workers with streaming mode will result in different shards of "
"data being transcribed in parallel. This is not advised if you want to preserve the order of the "
"audio-text data."
)
# Handle the repository creation
output_dir = training_args.output_dir
if accelerator.is_main_process:
    if training_args.push_to_hub:
        # Use the explicit hub id when given, otherwise derive one from the output folder name
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(output_dir).absolute().name,
                token=training_args.hub_token,
            )
        else:
            repo_name = training_args.hub_model_id
        create_repo(repo_name, repo_type="dataset", exist_ok=True, token=training_args.hub_token)
        snapshot_download(repo_id=repo_name, repo_type="dataset", local_dir=output_dir, token=training_args.hub_token)

        # Ensure large csv files can be pushed to the Hub with git-lfs.
        # "a+" creates `.gitattributes` when the snapshot does not contain one (where "r+"
        # would raise FileNotFoundError) and always writes at the end of the file.
        with open(os.path.join(output_dir, ".gitattributes"), "a+") as f:
            f.seek(0)
            git_lfs_extensions = f.read()
            if "*.csv" not in git_lfs_extensions:
                # Start on a fresh line so the rule is never glued to an existing last line
                if git_lfs_extensions and not git_lfs_extensions.endswith("\n"):
                    f.write("\n")
                f.write("*.csv filter=lfs diff=lfs merge=lfs -text\n")
    elif output_dir is not None:
        # this is where we'll save our transcriptions
        os.makedirs(output_dir, exist_ok=True)

# Let the other processes wait until the main process has prepared the output folder
accelerator.wait_for_everyone()
# 8. Load Metric
metric = evaluate.load("wer")


def compute_metrics(preds, labels, file_ids):
    """Decode predictions and labels, normalise them and compute the word error rate.

    Samples whose normalised reference is empty are excluded from the WER and from
    every returned list, so that all returned lists stay aligned.

    Args:
        preds: predicted token id sequences.
        labels: label id arrays; entries equal to -100 are replaced in place by the pad token id.
        file_ids: one id per sample.

    Returns:
        A tuple `({"wer": wer}, pred_str, label_str, norm_pred_str, norm_label_str, file_ids)`
        where `wer` is expressed as a percentage.
    """
    # replace padded labels by the padding token
    pad_id = tokenizer.pad_token_id
    for label_ids in labels:
        label_ids[label_ids == -100] = pad_id

    pred_str = tokenizer.batch_decode(preds, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
    # we do not want to group tokens when computing the metrics
    label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # normalize everything and re-compute the WER
    norm_pred_str = [normalizer(prediction) for prediction in pred_str]
    norm_label_str = [normalizer(reference) for reference in label_str]

    # only samples with a non-empty normalised reference are kept
    has_reference = [len(reference) > 0 for reference in norm_label_str]

    def keep_with_reference(items):
        return [item for position, item in enumerate(items) if has_reference[position]]

    # the raw strings and ids are filtered too so that they match the normalised strings when logged
    pred_str = keep_with_reference(pred_str)
    label_str = keep_with_reference(label_str)
    file_ids = keep_with_reference(file_ids)
    norm_pred_str = keep_with_reference(norm_pred_str)
    norm_label_str = keep_with_reference(norm_label_str)

    wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
    return {"wer": wer}, pred_str, label_str, norm_pred_str, norm_label_str, file_ids
def filter_eot_tokens(preds):
    """Strip every end-of-text token from each sequence and append exactly one.

    `preds` is modified in place (each entry is replaced by a plain list) and also
    returned.
    """
    for position, sequence in enumerate(preds):
        # remove the EOT tokens to get the 'true' token length
        stripped = [token_id for token_id in sequence if token_id != decoder_eot_token_id]
        stripped.append(decoder_eot_token_id)
        preds[position] = stripped
    return preds
# 12. Define the evaluation batch size and the data collator
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id, # <|startoftranscript|>
input_padding="longest",
target_padding="max_length",
max_target_length=max_label_length,
)
# 14. Define generation arguments - we need to do this before we wrap the models in DDP
# so that we can still access the configs
# An explicit `generation_num_beams` argument takes precedence over the generation config
num_beams = (
training_args.generation_num_beams
if training_args.generation_num_beams is not None
else getattr(model.generation_config, "num_beams", 1)
)
gen_kwargs = {
"max_length": max_label_length,
"num_beams": num_beams,
"return_timestamps": return_timestamps,
}
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
# forcing the language and task tokens helps multilingual models in their generations
gen_kwargs.update(
{
"language": data_args.language,
"task": data_args.task,
}
)
# remove any preset forced decoder ids since these are deprecated
model.generation_config.forced_decoder_ids = None
model.config.forced_decoder_ids = None
# 15. Prepare everything with accelerate
model = accelerator.prepare(model)
def eval_step_with_save(split="eval"):
    """Transcribe one dataset split and save the transcriptions.

    Generates predictions for `vectorized_datasets[split]`. Every
    `training_args.logging_steps` steps the transcriptions decoded so far are written
    to `<output_dir>/<pretty split>-transcription.csv` (and the output folder is
    uploaded when `push_to_hub` is set). After the last batch the final CSV is
    written. For splits whose name contains "validation" or "test" the WER is
    computed and logged. In non-streaming mode the transcriptions are added to
    `raw_datasets[split]` as a `whisper_transcript` column and, when audio
    concatenation is enabled, `condition_on_prev` is replaced by the prompt token ids
    built from the previous segment's prediction.

    Args:
        split: key of the split in `vectorized_datasets`, `file_ids_dataset` and
            `raw_datasets`.
    """

    def save_transcriptions(file_ids, transcriptions):
        # one row per decoded sample: (file id, transcription)
        csv_data = [[file_ids[i], transcriptions[i]] for i in range(len(transcriptions))]
        with open(output_csv, "w", encoding="UTF8", newline="") as f:
            writer = csv.writer(f)
            # write multiple rows
            writer.writerow(["file_id", "whisper_transcript"])
            writer.writerows(csv_data)

    # ======================== Evaluating ==============================
    eval_preds = []
    eval_labels = []
    eval_ids = []
    pred_str = []
    eval_start = time.time()

    eval_loader = DataLoader(
        vectorized_datasets[split],
        batch_size=per_device_eval_batch_size,
        collate_fn=data_collator,
        num_workers=dataloader_num_workers,
        pin_memory=True,
    )
    # the ids are not sharded across processes, so one id batch covers all processes
    file_loader = DataLoader(
        file_ids_dataset[split],
        batch_size=per_device_eval_batch_size * accelerator.num_processes,
        num_workers=dataloader_num_workers,
    )

    eval_loader = accelerator.prepare(eval_loader)
    batches = tqdm(eval_loader, desc=f"Evaluating {split}...", disable=not accelerator.is_local_main_process)

    # make the split name pretty for librispeech etc. The pretty name is only used for
    # file names and logging: `split` must stay untouched because it is the key into
    # `raw_datasets` (e.g. "train.clean.100" is not stored under "train-clean-100").
    pretty_split = split.replace(".", "-").split("/")[-1]
    output_csv = os.path.join(output_dir, f"{pretty_split}-transcription.csv")

    for step, (batch, file_ids) in enumerate(zip(batches, file_loader)):
        # Generate predictions and pad to max generated length
        generate_fn = model.module.generate if accelerator.num_processes > 1 else model.generate
        generated_ids = generate_fn(batch["input_features"].to(dtype=torch_dtype), **gen_kwargs)
        generated_ids = accelerator.pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.pad_token_id)
        # Gather all predictions and targets
        generated_ids, labels = accelerator.gather_for_metrics((generated_ids, batch["labels"]))
        eval_preds.extend(generated_ids.cpu().numpy())
        eval_labels.extend(labels.cpu().numpy())
        eval_ids.extend(file_ids)

        if step % training_args.logging_steps == 0 and step > 0:
            batches.write(f"Saving transcriptions for split {pretty_split} step {step}")
            accelerator.wait_for_everyone()
            # decode only the predictions that have not been decoded yet
            pred_ids = eval_preds[len(pred_str) :]
            pred_ids = filter_eot_tokens(pred_ids)
            pred_str.extend(
                tokenizer.batch_decode(
                    pred_ids, skip_special_tokens=False, decode_with_timestamps=return_timestamps
                )
            )
            save_transcriptions(eval_ids, pred_str)
            if training_args.push_to_hub and accelerator.is_main_process:
                upload_folder(
                    folder_path=output_dir,
                    repo_id=repo_name,
                    repo_type="dataset",
                    token=training_args.hub_token,
                    commit_message=f"Saving transcriptions for split {pretty_split} step {step}.",
                )

    accelerator.wait_for_everyone()
    eval_time = time.time() - eval_start

    # compute WER metric for eval sets
    wer_desc = ""
    if "validation" in pretty_split or "test" in pretty_split:
        eval_preds = filter_eot_tokens(eval_preds)
        wer_metric, pred_str, label_str, norm_pred_str, norm_label_str, eval_ids = compute_metrics(
            eval_preds, eval_labels, eval_ids
        )
        wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
        # Save metrics + predictions
        log_metric(
            accelerator,
            metrics=wer_metric,
            train_time=eval_time,
            prefix=pretty_split,
        )
        log_pred(
            accelerator,
            pred_str,
            label_str,
            norm_pred_str,
            norm_label_str,
            prefix=pretty_split,
        )
        # NOTE(review): `compute_metrics` drops samples with an empty normalised reference,
        # so `pred_str`/`eval_ids` can be shorter than the split here; `add_column` below
        # would then fail on the length mismatch - confirm the intended behaviour.
    else:
        # Slice from the number of already decoded samples: the previous negative-offset
        # slice `eval_preds[-0:]` returned the WHOLE list when nothing was pending (i.e.
        # when the last step also triggered a save) and duplicated every transcription.
        pred_ids = eval_preds[len(pred_str) :]
        pred_ids = filter_eot_tokens(pred_ids)
        pred_str.extend(
            tokenizer.batch_decode(pred_ids, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
        )

    batches.write(f"Saving final transcriptions for split {pretty_split}.")
    # write the decoded transcriptions (not the raw token ids) next to their file ids
    save_transcriptions(eval_ids, pred_str)

    # Print metrics
    logger.info(wer_desc)

    if not data_args.streaming:
        raw_datasets[split] = raw_datasets[split].add_column("whisper_transcript", pred_str)
        raw_datasets[split] = raw_datasets[split].add_column("eval_preds", eval_preds)

        def add_concatenated_text(eval_preds, condition_on_prev):
            """Build, for every sample of a batch, the prompt ids taken from the previous prediction.

            The first sample of a batch never gets a prompt. A later sample gets one only
            when its `condition_on_prev` flag is set; the prompt is `<|startofprev|>`
            followed by the previous prediction without its EOT tokens and without its
            first `timestamp_position` tokens.
            """
            concatenated_prev = [None]
            for token_ids, condition in zip(eval_preds[:-1], condition_on_prev[1:]):
                # the flags are the integers 0/1, so test truthiness: `condition is False`
                # never matched and a prompt was attached even across speaker changes
                if not condition:
                    concatenated_prev.append(None)
                else:
                    prompt_ids = [token for token in token_ids if token != decoder_eot_token_id]
                    prompt_ids = [decoder_prev_token_id] + prompt_ids[timestamp_position:]
                    concatenated_prev.append(prompt_ids)
            return {"condition_on_prev": concatenated_prev}

        if data_args.concatenate_audio:
            with accelerator.main_process_first():
                raw_datasets[split] = raw_datasets[split].map(
                    add_concatenated_text,
                    input_columns=["eval_preds", "condition_on_prev"],
                    remove_columns=["eval_preds"],
                    desc="Setting condition on prev...",
                    batched=True,
                    batch_size=preprocessing_batch_size,
                    num_proc=num_workers,
                )
logger.info("***** Running Labelling *****")
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
logger.info(
f" Total eval batch size (w. parallel & distributed) = {training_args.per_device_eval_batch_size * accelerator.num_processes}"
)
logger.info(f" Predict labels with timestamps = {return_timestamps}")
# Transcribe the requested splits one after the other
for split in data_splits:
eval_step_with_save(split=split)
accelerator.wait_for_everyone()
# Only the main process uploads the output folder to the Hub
if training_args.push_to_hub and accelerator.is_main_process:
upload_folder(
folder_path=output_dir,
repo_id=repo_name,
repo_type="dataset",
token=training_args.hub_token,
commit_message=f"Saving final transcriptions for split {split.replace('.', '-').split('/')[-1]}",
)
# In non-streaming mode the main process also saves the datasets to disk and, when requested,
# pushes them to the Hub
if not data_args.streaming and accelerator.is_main_process:
raw_datasets.save_to_disk(output_dir, num_proc=num_workers)
if training_args.push_to_hub:
raw_datasets.push_to_hub(repo_name, token=training_args.hub_token, config_name=data_args.dataset_config_name)
accelerator.end_training()