in training/flax/run_pt_long_form_transcription.py [0:0]
def main():
    """Evaluate a pretrained speech-recognition checkpoint on (long-form) audio datasets.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and ``Seq2SeqTrainingArguments``
    either from a single ``.json`` path given as the only CLI argument or from the command
    line. It then loads one or more evaluation splits and runs each through a Transformers
    ``automatic-speech-recognition`` pipeline. For every split it reports the normalised WER,
    the insertion/substitution/deletion rates and a repeated n-gram count.

    Metrics are always logged through ``logger``. When requested via ``--report_to`` and the
    corresponding package is available, they are also written to TensorBoard and/or
    Weights & Biases.

    Raises:
        ValueError: if a configured text or audio column is missing from a dataset, or if a
            generative model has no ``config.decoder_start_token_id``.
    """
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single argument ending in ".json" is treated as a config file holding all arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Enable TensorBoard if requested and importable. Note: no rank/master-node check is done here.
    has_tensorboard = is_tensorboard_available()
    if "tensorboard" in training_args.report_to:
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter

                summary_writer = SummaryWriter(log_dir=os.path.join(training_args.output_dir, "runs"))
            except ImportError as ie:
                has_tensorboard = False
                logger.warning(
                    "Unable to display metrics through TensorBoard because some" f" package are not installed: {ie}"
                )
        else:
            logger.warning(
                "Unable to display metrics through TensorBoard because the package is"
                " not installed: Please run `pip install tensorboard` to enable."
            )

    # Enable wandb if requested and installed. Note: no rank/master-node check is done here.
    has_wandb = is_wandb_available()
    if "wandb" in training_args.report_to:
        if has_wandb:
            import wandb as wandb_logger

            # Set up wandb run
            wandb_logger.init(
                project=data_args.wandb_project,
                name=data_args.wandb_name,
                job_type=data_args.wandb_job_type,
                dir=data_args.wandb_dir,
                save_code=data_args.save_code_to_wandb,
            )
        else:
            logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")

    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Set the verbosity to info of the script and Transformers loggers; keep datasets quieter.
    logger.setLevel(logging.INFO)
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
    logger.info("Evaluation parameters %s", training_args)

    # 3. Load dataset
    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()

    # Convert the "+"-separated dataset names/configs/splits into a list of dicts, e.g.
    # names: "librispeech_asr+gigaspeech", configs: "all+l", splits: "validation.clean+validation"
    # -> [{"name": "librispeech_asr", "config": "all", "split": "validation.clean"},
    #     {"name": "gigaspeech", "config": "l", "split": "validation"}]
    dataset_names_dict = convert_dataset_str_to_list(
        data_args.dataset_name,
        data_args.dataset_config_name,
        splits=data_args.dataset_split_name,
        text_column_names=data_args.text_column_name,
    )

    # Load every evaluation set and normalise its transcription column name to "text".
    for dataset_dict in dataset_names_dict:
        # Clean-up the dataset name for pretty logging
        # ("distil-whisper/librispeech_asr", "validation.clean") -> "librispeech_asr/validation-clean"
        pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
        raw_datasets[pretty_name] = load_dataset(
            dataset_dict["name"],
            dataset_dict["config"],
            split=dataset_dict["split"],
            cache_dir=data_args.dataset_cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
            streaming=data_args.streaming,
        )
        if dataset_dict["text_column_name"] not in list(raw_datasets[pretty_name].features.keys()):
            raise ValueError(
                f"--text column name {dataset_dict['text_column_name']} not found in the evaluation "
                f"dataset {dataset_dict['name']}. Ensure `text_column_name` is set to the correct column "
                f"for the target text. Should be one of {' '.join(list(raw_datasets[pretty_name].features.keys()))}"
            )
        if dataset_dict["text_column_name"] != "text":
            raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
                dataset_dict["text_column_name"], "text"
            )

    # Read the column names from the first loaded split (works for both streaming and map-style datasets).
    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())

    audio_column_name = data_args.audio_column_name
    if audio_column_name not in raw_datasets_features:
        raise ValueError(
            f"--audio_column_name '{audio_column_name}' not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
            " the correct audio column - one of"
            f" {', '.join(raw_datasets_features)}."
        )

    # Keep only the audio and (renamed) text columns.
    for split in raw_datasets:
        raw_datasets[split] = raw_datasets[split].remove_columns(
            set(raw_datasets[split].features.keys()) - {audio_column_name, "text"}
        )

    # Optionally truncate every split: `take` for streaming datasets, `select` for map-style ones.
    if data_args.max_eval_samples is not None:
        for split in raw_datasets:
            raw_datasets[split] = (
                raw_datasets[split].take(data_args.max_eval_samples)
                if data_args.streaming
                else raw_datasets[split].select(range(data_args.max_eval_samples))
            )

    # Store some constants
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    num_beams = training_args.generation_num_beams if training_args.generation_num_beams is not None else 1

    model_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_auth_token": True if model_args.use_auth_token else None,
        "subfolder": model_args.subfolder,
    }

    # 5. Load pretrained model, tokenizer, and feature extractor
    pipe = pipeline(
        "automatic-speech-recognition",
        model_args.model_name_or_path,
        torch_dtype=getattr(torch, model_args.dtype),
        model_kwargs=model_kwargs,
        max_new_tokens=training_args.generation_max_length,
        batch_size=per_device_eval_batch_size,
        chunk_length_s=model_args.chunk_length_s,
        return_timestamps=model_args.return_timestamps,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )

    # Decoding settings are only forwarded for generative (seq2seq) models.
    if pipe.model.can_generate():
        if pipe.model.config.decoder_start_token_id is None:
            raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
        generate_kwargs = {
            "num_beams": num_beams,
            "length_penalty": model_args.length_penalty,
            "do_sample": model_args.do_sample,
            "top_k": model_args.top_k,
            "temperature": model_args.temperature,
        }
        if getattr(pipe.model.generation_config, "is_multilingual", False):
            # Force English transcription for multilingual checkpoints. `dict.update` mutates in place
            # and returns None, so its result must NOT be assigned back to `generate_kwargs` (doing so
            # would silently drop every decoding setting above).
            generate_kwargs.update({"language": "english", "task": "transcribe"})
    else:
        generate_kwargs = None

    # 8. Load Metric
    # The English Whisper normaliser is applied to both predictions and references before scoring.
    whisper_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
    normalizer = EnglishTextNormalizer(whisper_tokenizer.english_spelling_normalizer)

    def compute_metrics(pred_str, label_str, ngram_degree=5):
        """Score one split's predictions against its references.

        Args:
            pred_str: list of raw predicted transcriptions.
            label_str: list of raw reference transcriptions, aligned index-by-index with ``pred_str``.
            ngram_degree: n-gram order used for the repeated n-gram count.

        Returns:
            ``(metrics, pred_str, label_str, norm_pred_str, norm_label_str)``.
            ``metrics`` maps "wer"/"ier"/"ser"/"der" (percentages) and "repeated_ngrams" (a count).
            All four string lists only contain samples whose normalised reference is non-empty.
        """
        # normalize everything and re-compute the WER
        norm_pred_str = [normalizer(pred) for pred in pred_str]
        norm_label_str = [normalizer(label) for label in label_str]

        # Only samples with a non-empty normalised reference are scored. The raw strings are filtered
        # with the same mask so that logged predictions/labels stay aligned with the scored ones.
        keep = [len(label) > 0 for label in norm_label_str]
        pred_str = [pred for pred, kept in zip(pred_str, keep) if kept]
        label_str = [label for label, kept in zip(label_str, keep) if kept]
        norm_pred_str = [pred for pred, kept in zip(norm_pred_str, keep) if kept]
        norm_label_str = [label for label, kept in zip(norm_label_str, keep) if kept]

        wer_output = process_words(norm_label_str, norm_pred_str, wer_default, wer_default)
        wer_norm = 100 * wer_output.wer

        # Insertion / substitution / deletion rates, each relative to the total number of reference words.
        num_ref_words = sum(len(ref) for ref in wer_output.references)
        ier_norm = 100 * wer_output.insertions / num_ref_words
        ser_norm = 100 * wer_output.substitutions / num_ref_words
        der_norm = 100 * wer_output.deletions / num_ref_words

        # Number of n-gram occurrences (over all normalised predictions concatenated) that duplicate an
        # earlier n-gram, i.e. total n-grams minus distinct n-grams.
        all_ngrams = list(ngrams(" ".join(norm_pred_str).split(), ngram_degree))
        repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

        return (
            {"wer": wer_norm, "ier": ier_norm, "ser": ser_norm, "der": der_norm, "repeated_ngrams": repeated_ngrams},
            pred_str,
            label_str,
            norm_pred_str,
            norm_label_str,
        )

    def eval_step(split="eval"):
        """Transcribe ``raw_datasets[split]``, compute its metrics and log them to the enabled backends."""
        # ======================== Evaluating ==============================
        eval_preds = []
        eval_labels = []
        eval_audios = []
        eval_start = time.time()

        # `data` is a helper defined elsewhere in this file that feeds the dataset to the pipeline.
        # NOTE(review): each output sample appears to carry the reference (and optionally the audio)
        # wrapped in a list, hence the `[0]` indexing below — confirm against `data`.
        for sample in tqdm(
            pipe(
                data(raw_datasets[split], log_audio=data_args.log_audio),
                generate_kwargs=generate_kwargs,
            ),
            desc=f"Evaluating {split}...",
        ):
            eval_preds.append(sample["text"])
            eval_labels.append(sample["reference"][0])
            if data_args.log_audio:
                eval_audios.append(sample["audio"][0])

        # The pipeline is consumed lazily in the loop above, so this measures the full inference time.
        eval_time = time.time() - eval_start

        wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
            eval_preds, eval_labels, data_args.ngram_degree
        )
        wer_desc = " ".join([f"{split} {key}: {value} |" for key, value in wer_metric.items()])

        # Print metrics to stdout
        logger.info(wer_desc)

        # Save metrics to tensorboard
        if has_tensorboard and "tensorboard" in training_args.report_to:
            write_metric(summary_writer, wer_metric, prefix=split)

        # Save metrics to wandb
        if has_wandb and "wandb" in training_args.report_to:
            write_wandb_metric(wandb_logger, wer_metric, eval_time, prefix=split)
            if data_args.log_predictions:
                write_wandb_pred(
                    wandb_logger, eval_audios, pred_str, label_str, norm_pred_str, norm_label_str, prefix=split
                )

    logger.info("***** Running Eval *****")
    logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
    logger.info(f" Total eval batch size (w. parallel & distributed) = {training_args.per_device_eval_batch_size}")
    if pipe.model.can_generate():
        logger.info(f" Beam size = {num_beams}")
        if num_beams > 1:
            logger.info(f" Length penalty size = {model_args.length_penalty}")
        logger.info(f" Do sample = {model_args.do_sample}")
        if model_args.do_sample:
            logger.info(f" Top k = {model_args.top_k}")
            logger.info(f" Temperature = {model_args.temperature}")

    for split in raw_datasets:
        eval_step(split=split)