in training/flax/run_pseudo_labelling_pt.py [0:0]
def main():
    """Pseudo-label a speech dataset with a Whisper checkpoint.

    The function parses the model/data/training arguments (from the command line or a
    single JSON file), loads the requested dataset splits and Whisper model, converts the
    audio into model inputs, generates a transcription for every sample of every split and
    writes the results to ``<output_dir>/<split>-transcription.csv``. For splits whose name
    contains "validation" or "test" the WER against the reference text is also logged.
    Optionally the CSV files and the labelled dataset are pushed to the Hugging Face Hub.

    Raises:
        ValueError: if the audio/text column is missing from the dataset, if `attn_type`
            is not a supported value, if the model has no `decoder_start_token_id`, or if
            a language is set for an English-only checkpoint.
    """
    # 1. Parse input arguments
    # We keep distinct sets of args, for cleaner separation of model/data/training related args
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # 2. Initialize the accelerator
    # We will let the accelerator handle device placement for us in this example
    # We simply have to specify the training precision and any trackers being used
    # We'll use the same dtype arguments as our JAX/Flax training script and convert
    # it to accelerate format
    if model_args.dtype == "float16":
        mixed_precision = "fp16"
        torch_dtype = torch.float16
    elif model_args.dtype == "bfloat16":
        mixed_precision = "bf16"
        torch_dtype = torch.bfloat16
    else:
        mixed_precision = "no"
        torch_dtype = torch.float32

    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=training_args.report_to,
        project_dir=training_args.output_dir,
        kwargs_handlers=[kwargs],
    )

    accelerator.init_trackers(project_name=data_args.wandb_project)

    # 3. Set-up basic logging
    # Create one log on every process with the configuration for debugging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Log a small summary on each process
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )

    # Set the verbosity to info of the Transformers logger (on main process only)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    logger.info("Training/evaluation parameters %s", training_args)

    # 4. Load dataset
    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
    token = model_args.token if model_args.token is not None else HfFolder().get_token()

    # several splits can be requested at once, joined by "+"
    data_splits = data_args.data_split_name.split("+")
    for split in data_splits:
        if data_args.streaming:
            raw_datasets[split] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=split,
                cache_dir=data_args.dataset_cache_dir,
                token=token,
                streaming=True,
            )
        else:
            raw_datasets[split] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=split,
                cache_dir=data_args.dataset_cache_dir,
                token=token,
                streaming=False,
                num_proc=data_args.preprocessing_num_workers,
            )

    # the column checks are made against the first loaded split only
    if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
        raise ValueError(
            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
            " the correct audio column - one of"
            f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
        )

    if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
        raise ValueError(
            f"--text_column_name {data_args.text_column_name} not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--text_column_name` to the"
            " correct text column - one of"
            f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
        )

    # 5. Load pretrained model, tokenizer, and feature extractor
    config = WhisperConfig.from_pretrained(
        (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=token,
    )
    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=token,
    )
    tokenizer = WhisperTokenizerFast.from_pretrained(
        (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=token,
    )
    processor = WhisperProcessor.from_pretrained(
        (model_args.processor_name if model_args.processor_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=token,
    )
    model = WhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        subfolder=model_args.subfolder,
        token=token,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        use_flash_attention_2=model_args.attn_type == "flash_attn_2",
    )

    if model_args.attn_type == "flash_attn":
        model = model.to_bettertransformer()
    elif model_args.attn_type not in [None, "flash_attn", "flash_attn_2"]:
        raise ValueError(
            f"Argument `attn_type` is set to {model_args.attn_type}. Should be one of:"
            "1. `None`: default Transformers attention implementation."
            "2. `flash_attn`: Flash Attention through PyTorch SDPA. Requires `torch>=2.0` and `optimum` to be installed. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080)."
            "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
        )

    if model_args.compile_encoder:
        model.model.encoder.forward = torch.compile(
            model.model.encoder.forward, mode="reduce-overhead", fullgraph=True
        )

    model.eval()

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    return_timestamps = data_args.return_timestamps
    if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
        # We need to set the language and task ids for multilingual checkpoints
        tokenizer.set_prefix_tokens(
            language=data_args.language, task=data_args.task, predict_timestamps=return_timestamps
        )
    elif data_args.language is not None:
        raise ValueError(
            "Setting language token for an English-only checkpoint is not permitted. The language argument should "
            "only be set for multilingual checkpoints."
        )

    # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
    # so we just need to set the correct target sampling rate.
    raw_datasets = raw_datasets.cast_column(
        data_args.audio_column_name,
        datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
    )

    # 7. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    max_label_length = (
        data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
    )
    audio_column_name = data_args.audio_column_name
    num_workers = data_args.preprocessing_num_workers
    dataloader_num_workers = training_args.dataloader_num_workers
    text_column_name = data_args.text_column_name
    model_input_name = feature_extractor.model_input_names[0]
    id_column_name = data_args.id_column_name
    normalizer = EnglishTextNormalizer(tokenizer.english_spelling_normalizer)

    if data_args.max_samples_per_split is not None:
        for split in data_splits:
            raw_datasets[split] = (
                raw_datasets[split].take(data_args.max_samples_per_split)
                if data_args.streaming
                else raw_datasets[split].select(range(data_args.max_samples_per_split))
            )

    def prepare_dataset(batch):
        """Turn one raw sample into model inputs, label token ids and a tokenized file id."""
        # process audio
        sample = batch[audio_column_name]
        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
        # the feature extractor returns a batch of size one: keep its single element
        batch[model_input_name] = inputs.get(model_input_name)[0]

        # process targets
        input_str = batch[text_column_name]
        batch["labels"] = tokenizer(input_str, max_length=max_label_length, truncation=True).input_ids

        # record the id of the sample as token ids
        batch["file_id"] = tokenizer(batch[id_column_name], add_special_tokens=False).input_ids
        return batch

    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
    if data_args.streaming:
        vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=raw_datasets_features)
    else:
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=raw_datasets_features,
            num_proc=num_workers,
            desc="preprocess dataset",
        )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with `args.preprocessing_only` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step `args.preprocessing_only` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
        logger.info(f"Data preprocessing finished. Files cached at {cache}.")
        return

    if data_args.streaming and dataloader_num_workers > 0:
        logger.warning(
            "Using multiple dataloader num workers with streaming mode will result in different shards of "
            "data being transcribed in parallel. This is not advised if you want to preserve the order of the "
            "audio-text data."
        )

    # 8. Handle the repository creation
    output_dir = training_args.output_dir
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(output_dir).absolute().name,
                token=token,
            )
        else:
            repo_name = training_args.hub_model_id
        create_repo(repo_name, exist_ok=True, token=token, repo_type="dataset", private=data_args.private_dataset)
        repo = Repository(
            output_dir,
            clone_from=repo_name,
            token=token,
            repo_type="dataset",
        )

        # Ensure large csv files can be pushed to the Hub with git-lfs
        with open(os.path.join(output_dir, ".gitattributes"), "r+") as f:
            git_lfs_extensions = f.read()
            if "*.csv" not in git_lfs_extensions:
                # keep the new rule on a line of its own, even when the existing
                # file does not end with a newline
                if git_lfs_extensions and not git_lfs_extensions.endswith("\n"):
                    f.write("\n")
                f.write("*.csv filter=lfs diff=lfs merge=lfs -text\n")
    else:
        # this is where we'll save our transcriptions
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

    # 9. Load Metric
    metric = evaluate.load("wer")
    # convention is that we space all punctuation *except* apostrophes
    all_punctuation = list(string.punctuation.replace("'", ""))

    def space_punctuation(text):
        """Surround every punctuation mark of `text` (apostrophes excepted) with spaces."""
        for punctuation in all_punctuation:
            text = text.replace(punctuation, f" {punctuation} ")
        return text

    def compute_metrics(preds, labels, file_ids):
        """Compute the orthographic and normalized WER of `preds` against `labels`.

        `labels` is modified in place: entries equal to -100 are replaced by the pad token id.

        Returns:
            A tuple `(metrics, pred_str, label_str, norm_pred_str, norm_label_str, file_ids)`
            where `metrics` holds the "wer" and "wer_ortho" percentages and the five lists
            only contain the samples whose normalized reference is non-empty.
        """
        # replace padded labels by the padding token
        for idx in range(len(labels)):
            labels[idx][labels[idx] == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # space punctuation for orthographic WER (c.f. ESB paper https://arxiv.org/abs/2210.13352)
        # one spaced string per sample, with *all* punctuation marks spaced
        spaced_pred_str = [space_punctuation(pred) for pred in pred_str]
        spaced_label_str = [space_punctuation(label) for label in label_str]
        wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)

        # normalize everything and re-compute the WER
        norm_pred_str = [normalizer(pred) for pred in pred_str]
        norm_label_str = [normalizer(label) for label in label_str]

        # filtering step to only evaluate the samples that correspond to non-zero normalized
        # references; the raw strings and file ids are filtered with the same mask so that
        # all returned lists stay aligned for logging
        keep = [len(norm_label) > 0 for norm_label in norm_label_str]
        pred_str = [pred for pred, kept in zip(pred_str, keep) if kept]
        label_str = [label for label, kept in zip(label_str, keep) if kept]
        file_ids = [file_id for file_id, kept in zip(file_ids, keep) if kept]
        norm_pred_str = [norm_pred for norm_pred, kept in zip(norm_pred_str, keep) if kept]
        norm_label_str = [norm_label for norm_label, kept in zip(norm_label_str, keep) if kept]

        wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)

        return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str, file_ids

    # 10. Define the batch size and data collator
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)

    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,  # <|startoftranscript|>
        decoder_prev_token_id=tokenizer.all_special_ids[-3],  # <|startofprev|>
        input_padding="longest",
        target_padding="max_length",
        max_target_length=max_label_length,
    )

    # 11. Define generation arguments - we need to do this before we wrap the models in DDP
    # so that we can still access the configs
    num_beams = (
        training_args.generation_num_beams
        if training_args.generation_num_beams is not None
        else getattr(model.generation_config, "num_beams", 1)
    )

    gen_kwargs = {
        "max_length": max_label_length,
        "num_beams": num_beams,
        "return_timestamps": return_timestamps,
    }
    if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
        # forcing the language and task tokens helps multilingual models in their generations
        gen_kwargs.update(
            {
                "language": data_args.language,
                "task": data_args.task,
            }
        )

    # 12. Prepare everything with accelerate
    model = accelerator.prepare(model)

    def eval_step_with_save(split="eval"):
        """Transcribe one split, save the transcriptions to CSV and log WER for eval splits.

        The CSV is rewritten every `training_args.logging_steps` steps and once more at the
        end. In non-streaming mode the transcriptions are also added to `raw_datasets[split]`
        as the "whisper_transcript" column.
        """
        # ======================== Evaluating ==============================
        eval_preds = []
        eval_labels = []
        eval_ids = []
        eval_start = time.time()

        eval_loader = DataLoader(
            vectorized_datasets[split],
            batch_size=per_device_eval_batch_size,
            collate_fn=data_collator,
            num_workers=dataloader_num_workers,
            pin_memory=True,
        )

        eval_loader = accelerator.prepare(eval_loader)
        batches = tqdm(eval_loader, desc=f"Evaluating {split}...", disable=not accelerator.is_local_main_process)

        # make the split name pretty for librispeech etc. `split` itself is left untouched
        # because it is still needed as the key into `raw_datasets`
        split_name = split.replace(".", "-").split("/")[-1]
        output_csv = os.path.join(output_dir, f"{split_name}-transcription.csv")

        # `accelerator.prepare` only wraps the model when running distributed, so unwrap it
        # rather than assuming a `.module` attribute exists
        generate_fn = accelerator.unwrap_model(model).generate

        def decode_if_requested(pred_ids):
            """Return `pred_ids` decoded to text when `decode_token_ids` is set, else unchanged."""
            if not data_args.decode_token_ids:
                return pred_ids
            return tokenizer.batch_decode(
                pred_ids, skip_special_tokens=True, decode_with_timestamps=return_timestamps
            )

        def save_transcriptions(transcriptions):
            """Overwrite the split's CSV with one (file_id, transcription) row per sample."""
            with open(output_csv, "w", encoding="UTF8", newline="") as f:
                writer = csv.writer(f)
                # write multiple rows
                writer.writerow(["file_id", "whisper_transcript"])
                writer.writerows([file_id, pred] for file_id, pred in zip(eval_ids, transcriptions))

        for step, batch in enumerate(batches):
            file_ids = batch.pop("file_ids")
            # Generate predictions and pad to max generated length
            generated_ids = generate_fn(batch["input_features"].to(dtype=torch_dtype), **gen_kwargs)
            generated_ids = accelerator.pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.pad_token_id)
            # Gather all predictions and targets
            file_ids, generated_ids, labels = accelerator.gather_for_metrics(
                (file_ids, generated_ids, batch["labels"])
            )
            eval_preds.extend(generated_ids.cpu().numpy())
            eval_labels.extend(labels.cpu().numpy())
            file_ids = tokenizer.batch_decode(file_ids, skip_special_tokens=True)
            eval_ids.extend(file_ids)

            if step % training_args.logging_steps == 0 and step > 0:
                batches.write(f"Saving transcriptions for split {split_name} step {step}")
                accelerator.wait_for_everyone()
                # decode into a separate list: `eval_preds` must keep holding token ids
                # because later steps extend it and the metrics decode it again
                save_transcriptions(decode_if_requested(eval_preds))
                if training_args.push_to_hub and accelerator.is_main_process:
                    repo.push_to_hub(
                        commit_message=f"Saving transcriptions for split {split_name} step {step}.",
                        blocking=False,
                    )

        accelerator.wait_for_everyone()
        eval_time = time.time() - eval_start

        # compute WER metric for eval sets
        wer_desc = ""
        if "validation" in split_name or "test" in split_name:
            # the lists returned here are filtered to non-empty references and are only used
            # for logging; the saved transcriptions below cover every sample
            wer_metric, pred_str, label_str, norm_pred_str, norm_label_str, _ = compute_metrics(
                eval_preds, eval_labels, eval_ids
            )
            wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
            # Save metrics + predictions
            log_metric(
                accelerator,
                metrics=wer_metric,
                train_time=eval_time,
                prefix=split_name,
            )
            log_pred(
                accelerator,
                pred_str,
                label_str,
                norm_pred_str,
                norm_label_str,
                prefix=split_name,
            )

        transcriptions = decode_if_requested(eval_preds)

        batches.write(f"Saving final transcriptions for split {split_name}.")
        save_transcriptions(transcriptions)

        # Print metrics
        logger.info(wer_desc)

        if not data_args.streaming:
            # one transcription per sample, so the new column matches the dataset length
            raw_datasets[split] = raw_datasets[split].add_column("whisper_transcript", transcriptions)

    logger.info("***** Running Labelling *****")
    logger.info(f" Instantaneous batch size per device = {training_args.per_device_eval_batch_size}")
    logger.info(
        f" Total eval batch size (w. parallel & distributed) = {training_args.per_device_eval_batch_size * accelerator.num_processes}"
    )
    logger.info(f" Predict labels with timestamps = {return_timestamps}")
    logger.info(f" Decode labels to transcriptions = {data_args.decode_token_ids}")

    for split in data_splits:
        eval_step_with_save(split=split)
        accelerator.wait_for_everyone()
        if training_args.push_to_hub and accelerator.is_main_process:
            repo.push_to_hub(
                commit_message=f"Saving final transcriptions for split {split.replace('.', '-').split('/')[-1]}",
                blocking=False,
            )

    if not data_args.streaming and accelerator.is_main_process:
        raw_datasets.save_to_disk(output_dir, num_proc=num_workers)
        if training_args.push_to_hub:
            raw_datasets.push_to_hub(repo_name, config_name=data_args.dataset_config_name)

    accelerator.end_training()