in training/flax/run_eval.py [0:0]
def main():
# 1. Parse input arguments
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
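# e.g. run as `python run_eval.py eval_config.json` (an illustrative JSON file holding every argument) or
# pass flags such as `--model_name_or_path` and `--dataset_name` directly on the command line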
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your JAX/Flax versions.
send_example_telemetry("run_flax_speech_recognition_seq2seq", model_args, data_args, framework="flax")
# 2. Setup logging
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
# Set the verbosity of the Transformers logger to INFO.
# We only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info("Evaluation parameters %s", training_args)
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if "tensorboard" in training_args.report_to:
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
"Unable to display metrics through TensorBoard because some" f" package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is"
" not installed: Please run `pip install tensorboard` to enable."
)
# Enable wandb only on the master node
has_wandb = is_wandb_available()
if "wandb" in training_args.report_to:
if has_wandb and jax.process_index() == 0:
import wandb as wandb_logger
# Set up wandb run
wandb_logger.init(
project=data_args.wandb_project,
name=data_args.wandb_name,
job_type=data_args.wandb_job_type,
dir=data_args.wandb_dir,
save_code=data_args.save_code_to_wandb,
)
else:
logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")
# 3. Load dataset
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
# Convert the '+'-separated strings of dataset names/configs/splits into a list of dicts, e.g.
# names: "librispeech_asr+gigaspeech", configs: "all+l", splits: "validation.clean+validation"
# -> [{"name": "librispeech_asr", "config": "all", "split": "validation.clean"}, {"name": "gigaspeech", "config": "l", "split": "validation"}]
dataset_names_dict = convert_dataset_str_to_list(
data_args.dataset_name,
data_args.dataset_config_name,
splits=data_args.dataset_split_name,
text_column_names=data_args.text_column_name,
)
if len(dataset_names_dict) == 1:
# load a single eval set
dataset_dict = dataset_names_dict[0]
raw_datasets["eval"] = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
)
if dataset_dict["text_column_name"] not in list(raw_datasets["eval"].features.keys()):
raise ValueError(
f"--text column name {dataset_dict['text_column_name']} not found in the evaluation "
f"dataset {dataset_dict['name']}. Ensure `text_column_name` is set to the correct column "
f"for the target text. Should be one of {' '.join(list(raw_datasets['eval'].features.keys()))}"
)
if dataset_dict["text_column_name"] != "text":
raw_datasets["eval"] = raw_datasets["eval"].rename_column(dataset_dict["text_column_name"], "text")
else:
# load multiple eval sets
for dataset_dict in tqdm(dataset_names_dict, desc="Loading datasets..."):
# Clean-up the dataset name for pretty logging
# ("distil-whisper/librispeech_asr", "validation.clean") -> "librispeech_asr/validation-clean"
pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
raw_datasets[pretty_name] = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
)
if dataset_dict["text_column_name"] not in list(raw_datasets[pretty_name].features.keys()):
raise ValueError(
f"`--text_column_name` {dataset_dict['text_column_name']} not found in the evaluation "
f"dataset {dataset_dict['name']}. Ensure `text_column_name` is set to the correct column "
f"for the target text. Should be one of {' '.join(list(raw_datasets[pretty_name].features.keys()))}"
)
if dataset_dict["text_column_name"] != "text":
raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
dataset_dict["text_column_name"], "text"
)
# 4. Load pretrained model, tokenizer, and feature extractor
config = WhisperConfig.from_pretrained(
(model_args.config_name if model_args.config_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(
(model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = WhisperTokenizerFast.from_pretrained(
(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
processor = WhisperProcessor.from_pretrained(
(model_args.processor_name if model_args.processor_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
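# with `_do_init=False`, `from_pretrained` returns the model and its parameters separately, so the params can
# be converted (scan -> unrolled) and replicated explicitly further below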
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
config=config,
dtype=getattr(jnp, model_args.dtype),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
_do_init=False,
subfolder=model_args.subfolder,
# use_scan=model_args.load_with_scan, # Model might have (erroneously) been saved with scan still enabled
)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
# disable scan if necessary (makes the inference step faster)
if model_args.load_with_scan:
model.disable_scan() # to disable scan in the nn.Module
params = model.convert_scan_to_unroll(params) # to convert the scan params to unrolled
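# roughly speaking: with scan enabled the Transformer layers are stacked into a single scanned layer, which
# compiles faster for training; unrolling restores per-layer parameters, which is typically faster at inference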
# 5. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
# so we just need to set the correct target sampling rate.
raw_datasets = raw_datasets.cast_column(
data_args.audio_column_name,
datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
)
# 6. Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
max_label_length = (
data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
)
audio_column_name = data_args.audio_column_name
num_workers = data_args.preprocessing_num_workers
dataloader_num_workers = training_args.dataloader_num_workers
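# name of the model's audio input as defined by the feature extractor ("input_features" for Whisper)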
model_input_name = feature_extractor.model_input_names[0]
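# Whisper's English text normalizer, applied to predictions and references before computing the normalized WER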
normalizer = EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
if data_args.max_eval_samples is not None:
for split in raw_datasets:
raw_datasets[split] = (
raw_datasets[split].take(data_args.max_eval_samples)
if data_args.streaming
else raw_datasets[split].select(range(data_args.max_eval_samples))
)
def prepare_dataset(batch):
# process audio
sample = batch[audio_column_name]
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
# store the extracted input features
batch[model_input_name] = inputs.get(model_input_name)[0]
# process targets
input_str = batch["text"]
batch["labels"] = tokenizer(input_str, max_length=max_label_length, truncation=True).input_ids
return batch
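# map `prepare_dataset` over every split, dropping the raw columns (keeping the audio column only when logging audio)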
vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
for split in raw_datasets:
raw_datasets_features = list(raw_datasets[split].features.keys())
if data_args.log_audio:
# if logging audio samples, preserve the audio column when mapping the dataset
raw_datasets_features.remove(audio_column_name)
map_fn = partial(
raw_datasets[split].map,
function=prepare_dataset,
remove_columns=raw_datasets_features,
)
vectorized_datasets[split] = (
map_fn(num_proc=num_workers, desc="preprocess eval dataset")
if not data_args.streaming
else map_fn() # In streaming mode we can't use multiprocessing - it errors out if we try
)
# for large datasets it is advised to run the preprocessing on a
# single machine first with `args.preprocessing_only` since there will most likely
# be a timeout when running the script in distributed mode.
# In a second step, `args.preprocessing_only` can then be set to `False` to load the
# cached dataset.
if data_args.preprocessing_only:
cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
logger.info(f"Data preprocessing finished. Files cached at {cache}.")
return
# 7. Load Metric
metric = evaluate.load("wer")
# convention is that we space all punctuation *except* apostrophes
all_punctuation = list(string.punctuation.replace("'", ""))
return_timestamps = model_args.return_timestamps
def compute_metrics(preds, labels):
# replace padded labels by the padding token
for idx in range(len(labels)):
labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
# we do not want to group tokens when computing the metrics
label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
# space punctuation for orthographic WER (c.f. ESB paper https://arxiv.org/abs/2210.13352)
spaced_pred_str = [
pred_str[i].replace(punctuation, f" {punctuation} ")
for punctuation in all_punctuation
for i in range(len(pred_str))
]
spaced_label_str = [
label_str[i].replace(punctuation, f" {punctuation} ")
for punctuation in all_punctuation
for i in range(len(label_str))
]
wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)
# normalize everything and re-compute the WER
norm_pred_str = [normalizer(pred) for pred in pred_str]
norm_label_str = [normalizer(label) for label in label_str]
# for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
# filtering step to only evaluate the samples that correspond to non-zero normalized references:
norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id,
input_padding="longest",
target_padding="max_length",
max_target_length=max_label_length,
log_audio=data_args.log_audio,
)
# Store some constants
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()  # total eval batch size across all devices on all hosts
# label smoothed cross entropy
def loss_fn(logits, labels, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
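# the normalizing constant is the entropy of the smoothed target distribution (the minimum attainable
# cross-entropy), so subtracting it below shifts the optimum of the loss to zero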
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens in the loss, i.e. positions where the labels are set to -100
padding_mask = labels >= 0
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
labels = batch.pop("labels")
logits = model(**batch, params=params, freeze_encoder=True, train=False)[0]
loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total number of non-padded label tokens
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
metrics = {"loss": loss}
return metrics
# Define generation function
num_beams = (
training_args.generation_num_beams
if training_args.generation_num_beams is not None
else model.config.num_beams
)
# forcing the language and task tokens helps the Flax teacher model in its generations
gen_kwargs = {
"max_length": max_label_length,
"num_beams": num_beams,
"language": "<|en|>",
"task": "transcribe",
"return_timestamps": return_timestamps,
}
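# these kwargs are forwarded as-is to `model.generate` in `generate_step` below, where `language` and `task`
# are presumably resolved to the matching forced decoder tokens, as in transformers' Whisper generation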
def generate_step(params, batch):
output_ids = model.generate(
batch[model_input_name],
attention_mask=batch.get("attention_mask"),
params=params,
freeze_encoder=True,
**gen_kwargs,
)
return output_ids.sequences
# Create parallel versions of the eval and generate steps
p_eval_step = jax.pmap(
partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor),
"batch",
)
p_generate_step = jax.pmap(generate_step, "batch")
# Replicate params on each device
params = jax_utils.replicate(params)
def evaluate_split(split="eval"):
# ======================== Evaluating ==============================
eval_metrics = []
eval_preds = []
eval_labels = []
eval_audios = []
eval_start = time.time()
eval_loader = get_data_loader(
vectorized_datasets[split],
batch_size=eval_batch_size,
data_collator=data_collator,
dataloader_num_workers=dataloader_num_workers,
)
for batch in tqdm(eval_loader, desc=f"Evaluating {split}..."):
# Model forward
labels = batch["labels"]
if data_args.log_audio:
eval_audios.extend(batch.pop("audio"))
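# `pad_shard_unpad` pads a (possibly incomplete) final batch so it splits evenly across devices, shards it for
# the pmapped call, and un-pads the outputs; `static_return=True` leaves the scalar metrics untouched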
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
params, batch.data, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# generation
if training_args.predict_with_generate:
generated_ids = pad_shard_unpad(p_generate_step)(
params, batch.data, min_device_batch=per_device_eval_batch_size
)
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
eval_labels.extend(labels)
eval_time = time.time() - eval_start
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# compute WER metric
wer_desc = ""
if training_args.predict_with_generate:
wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(eval_preds, eval_labels)
eval_metrics.update(wer_metric)
wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
# Print metrics
logger.info(f"Eval Loss: {eval_metrics['loss']} | {wer_desc})")
# Save metrics
if has_tensorboard and jax.process_index() == 0 and "tensorboard" in training_args.report_to:
write_metric(summary_writer, eval_metrics, model_args.step, prefix=split)
if has_wandb and jax.process_index() == 0 and "wandb" in training_args.report_to:
write_wandb_metric(wandb_logger, eval_metrics, eval_time, prefix=split)
if training_args.predict_with_generate:
write_wandb_pred(
wandb_logger, eval_audios, pred_str, label_str, norm_pred_str, norm_label_str, prefix=split
)
logger.info("***** Running Eval *****")
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
logger.info(f" Total eval batch size (w. parallel & distributed) = {eval_batch_size}")
for split in vectorized_datasets:
evaluate_split(split=split)