in training/flax/run_long_form_transcription.py [0:0]
def main():
# 1. Parse input arguments
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
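    # Illustrative invocation (the model and dataset identifiers below are placeholders, not defaults
    # of this script); the flag names mirror the dataclass fields parsed above:
    #   python run_long_form_transcription.py \
    #       --model_name_or_path openai/whisper-small \
    #       --dataset_name librispeech_asr \
    #       --dataset_config_name all \
    #       --dataset_split_name validation.clean \
    #       --per_device_eval_batch_size 8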
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your JAX/Flax versions.
send_example_telemetry("run_flax_speech_recognition_seq2seq", model_args, data_args, framework="flax")
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if "tensorboard" in training_args.report_to:
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(os.path.join(training_args.output_dir, "runs")))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is"
" not installed: Please run `pip install tensorboard` to enable."
)
# Enable wandb only on the master node
has_wandb = is_wandb_available()
if "wandb" in training_args.report_to:
if has_wandb and jax.process_index() == 0:
import wandb as wandb_logger
# Set up wandb run
wandb_logger.init(
project=data_args.wandb_project,
name=data_args.wandb_name,
job_type=data_args.wandb_job_type,
dir=data_args.wandb_dir,
save_code=data_args.save_code_to_wandb,
)
else:
logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")
# 2. Setup logging
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
    # Set the verbosity of the Transformers logger to INFO.
    # We only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info("Evaluation parameters %s", training_args)
if model_args.compilation_cache:
cc.initialize_cache(os.path.join(model_args.cache_dir, "jax_cache"))
# 3. Load dataset
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
    # Convert the '+'-separated lists of dataset names/configs/splits into a list of dicts, e.g.
    # names: "librispeech_asr+gigaspeech", configs: "all+l", splits: "validation.clean+validation"
    # -> [{"name": "librispeech_asr", "config": "all", "split": "validation.clean"},
    #     {"name": "gigaspeech", "config": "l", "split": "validation"}]
dataset_names_dict = convert_dataset_str_to_list(
data_args.dataset_name,
data_args.dataset_config_name,
splits=data_args.dataset_split_name,
text_column_names=data_args.text_column_name,
)
# load multiple eval sets
for dataset_dict in dataset_names_dict:
        # Clean up the dataset name for pretty logging
# ("distil-whisper/librispeech_asr", "validation.clean") -> "librispeech_asr/validation-clean"
pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
raw_datasets[pretty_name] = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
)
if dataset_dict["text_column_name"] not in list(raw_datasets[pretty_name].features.keys()):
raise ValueError(
f"--text column name {dataset_dict['text_column_name']} not found in the evaluation "
f"dataset {dataset_dict['name']}. Ensure `text_column_name` is set to the correct column "
f"for the target text. Should be one of {' '.join(list(raw_datasets[pretty_name].features.keys()))}"
)
if dataset_dict["text_column_name"] != "text":
raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
dataset_dict["text_column_name"], "text"
)
    # Obtain the dataset features in a way that is also robust to streaming mode
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
audio_column_name = data_args.audio_column_name
if audio_column_name not in raw_datasets_features:
raise ValueError(
f"--audio_column_name '{audio_column_name}' not found in dataset"
f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
" the correct audio column - one of"
f" {', '.join(raw_datasets_features)}."
)
for split in raw_datasets:
raw_datasets[split] = raw_datasets[split].remove_columns(
set(raw_datasets[split].features.keys()) - {audio_column_name, "text"}
)
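    # At this point every evaluation split exposes exactly two columns: the audio column and "text".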
    # 4. Load pretrained model, tokenizer, and feature extractor
pipeline = FlaxWhisperPipeline(
model_args.model_name_or_path,
dtype=getattr(jnp, model_args.dtype),
max_length=training_args.generation_max_length,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
subfolder=model_args.subfolder,
# use_scan=model_args.load_with_scan, # Model might have (erroneously) been saved with scan still enabled
)
if pipeline.model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
# disable scan if necessary (makes the inference step faster)
if model_args.load_with_scan:
pipeline.model.disable_scan() # to disable scan in the nn.Module
pipeline.params = pipeline.model.convert_scan_to_unroll(
pipeline.params
) # to convert the scan params to unrolled
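    # With scan enabled the transformer blocks run inside a jax.lax.scan, which keeps compilation cheap
    # during training; unrolling them (as above) is what makes the per-chunk inference step faster.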
    # 5. Possibly evaluate on a subset of the data
if data_args.max_eval_samples is not None:
for split in raw_datasets:
raw_datasets[split] = (
raw_datasets[split].take(data_args.max_eval_samples)
if data_args.streaming
else raw_datasets[split].select(range(data_args.max_eval_samples))
)
    # 6. Compute WER metrics
normalizer = EnglishTextNormalizer(pipeline.tokenizer.english_spelling_normalizer)
def compute_metrics(pred_str, label_str, ngram_degree=5):
# normalize everything and compute the WER
norm_pred_str = [normalizer(pred).replace(".", "") for pred in pred_str]
norm_label_str = [normalizer(label) for label in label_str]
        # keep only the samples whose normalized reference is non-empty, and apply the same filter to
        # the un-normalized strings so that predictions and labels stay aligned for logging
        pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
        label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
        norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
        norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
wer_output = process_words(norm_label_str, norm_pred_str, wer_default, wer_default)
wer_norm = 100 * wer_output.wer
ier_norm = 100 * wer_output.insertions / sum([len(ref) for ref in wer_output.references])
ser_norm = 100 * wer_output.substitutions / sum([len(ref) for ref in wer_output.references])
der_norm = 100 * wer_output.deletions / sum([len(ref) for ref in wer_output.references])
all_ngrams = list(ngrams(" ".join(norm_pred_str).split(), ngram_degree))
repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))
return (
{"wer": wer_norm, "ier": ier_norm, "ser": ser_norm, "der": der_norm, "repeated_ngrams": repeated_ngrams},
pred_str,
label_str,
norm_pred_str,
norm_label_str,
)
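    # Sanity check of the decomposition above (numbers are illustrative): with 100 reference words,
    # 3 insertions, 2 substitutions and 1 deletion, jiwer's `process_words` gives wer = (3 + 2 + 1) / 100,
    # so the metrics dict would hold wer = 6.0, ier = 3.0, ser = 2.0 and der = 1.0 (all in percent).
    # `repeated_ngrams` counts duplicate n-grams across the concatenated predictions, a rough indicator
    # of the repetition/looping failure mode of long-form generation.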
# Store some constants
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
num_beams = (
training_args.generation_num_beams
if training_args.generation_num_beams is not None
else pipeline.model.config.num_beams
)
generation_config = pipeline.model.generation_config
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
        # Multilingual checkpoints need the language and task ids set explicitly - for now we hardcode these to English transcription
language = "English"
task = "transcribe"
else:
language = None
task = None
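    # Multilingual Whisper checkpoints condition generation on language and task prompt tokens
    # (e.g. <|en|>, <|transcribe|>), hence the explicit values above; English-only checkpoints have no
    # such tokens and are left unset.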
    # pre-compile the generate call so that compilation time is not counted in the evaluation timings
logger.info("Pre-compiling the generate call...")
random_inputs = {"input_features": np.ones((eval_batch_size, 80, 2 * pipeline.model.config.max_source_positions))}
pipeline.forward(
random_inputs,
batch_size=eval_batch_size,
language=language,
task=task,
return_timestamps=model_args.return_timestamps,
num_beams=num_beams,
length_penalty=model_args.length_penalty,
do_sample=model_args.do_sample,
top_k=model_args.top_k,
temperature=model_args.temperature,
)
def eval_step(split="eval"):
# ======================== Evaluating ==============================
eval_preds = []
eval_labels = []
eval_audios = []
eval_start = time.time()
for sample in tqdm(raw_datasets[split], desc=f"Evaluating {split}..."):
# Model forward
label_str = sample["text"]
if data_args.log_audio:
eval_audios.append(sample["audio"])
pred_str = pipeline(
sample["audio"],
batch_size=eval_batch_size,
language=language,
task=task,
chunk_length_s=model_args.chunk_length_s,
return_timestamps=model_args.return_timestamps,
num_beams=num_beams,
length_penalty=model_args.length_penalty,
do_sample=model_args.do_sample,
top_k=model_args.top_k,
temperature=model_args.temperature,
)
eval_preds.append(pred_str["text"])
eval_labels.append(label_str)
eval_time = time.time() - eval_start
wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
eval_preds, eval_labels, ngram_degree=data_args.ngram_degree
)
wer_desc = " ".join([f"{split} {key}: {value} |" for key, value in wer_metric.items()])
# Print metrics to stdout
logger.info(wer_desc)
# Save metrics to tensorboard
if has_tensorboard and jax.process_index() == 0 and "tensorboard" in training_args.report_to:
write_metric(summary_writer, wer_metric, prefix=split)
# Save metrics to wandb
if has_wandb and jax.process_index() == 0 and "wandb" in training_args.report_to:
write_wandb_metric(wandb_logger, wer_metric, eval_time, prefix=split)
if data_args.log_predictions:
write_wandb_pred(
wandb_logger, eval_audios, pred_str, label_str, norm_pred_str, norm_label_str, prefix=split
)
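    # Log the generation settings once, then run eval_step on every loaded split; metrics are reported
    # per split under its pretty name (e.g. "librispeech_asr/validation-clean").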
logger.info("***** Running Eval *****")
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
logger.info(f" Total eval batch size (w. parallel & distributed) = {eval_batch_size}")
logger.info(f" Beam size = {num_beams}")
if num_beams > 1:
logger.info(f" Length penalty size = {model_args.length_penalty}")
logger.info(f" Do sample = {model_args.do_sample}")
if model_args.do_sample:
logger.info(f" Top k = {model_args.top_k}")
logger.info(f" Temperature = {model_args.temperature}")
for split in raw_datasets:
eval_step(split=split)