in training/flax/run_distillation.py [0:0]
def main():
# 1. Parse input arguments
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is limited to the arguments passed to the script, along with your JAX/Flax versions.
send_example_telemetry("run_flax_speech_recognition_seq2seq", model_args, data_args, framework="flax")
# 2. Define remote logging - do this early so that we get the full traceback on our remote logs
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard:
if jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=os.path.join(Path(training_args.output_dir), "runs"))
except ImportError as ie:
has_tensorboard = False
logger.warning(
"Unable to display metrics through TensorBoard because some package" f" are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not"
" installed: Please run `pip install tensorboard` to enable."
)
# Enable wandb only on the master node
has_wandb = is_wandb_available()
if has_wandb:
import wandb as wandb_logger
# Set up wandb run
if jax.process_index() == 0:
wandb_logger.init(
project=data_args.wandb_project,
name=data_args.wandb_name,
job_type=data_args.wandb_job_type,
dir=data_args.wandb_dir,
save_code=data_args.save_code_to_wandb,
)
else:
logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")
# 3. Setup local logging
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
# Set the verbosity of the Transformers logger to info.
# We only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info("Training/evaluation parameters %s", training_args)
# Check the output dir is valid
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not"
" empty. Use `--overwrite_output_dir` to overcome."
)
# 4. Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name,
token=training_args.hub_token,
)
else:
repo_name = training_args.hub_model_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(
training_args.output_dir,
clone_from=repo_name,
token=training_args.hub_token,
)
if training_args.compilation_cache:
cc.initialize_cache(os.path.join(model_args.cache_dir, "jax_cache"))
# 5. Load dataset
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
# set seed for determinism
set_seed(training_args.seed)
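# The training corpus can be assembled from several datasets at once: `load_multiple_datasets` loads each
# name/config/split below and combines them into a single (optionally streaming) dataset, with
# `train_dataset_samples` controlling how much each component dataset contributes.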
if training_args.do_train:
raw_datasets["train"] = load_multiple_datasets(
data_args.train_dataset_name,
data_args.train_dataset_config_name,
splits=data_args.train_split_name,
streaming=data_args.streaming,
dataset_samples=data_args.train_dataset_samples,
seed=training_args.seed,
cache_dir=data_args.dataset_cache_dir,
token=True if model_args.use_auth_token else None,
)
if training_args.do_eval:
dataset_names_dict = convert_dataset_str_to_list(
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
(
data_args.eval_dataset_config_name
if data_args.eval_dataset_config_name
else data_args.train_dataset_config_name
),
splits=data_args.eval_split_name,
text_column_names=data_args.eval_text_column_name,
)
all_eval_splits = []
if len(dataset_names_dict) == 1:
# load a single eval set
dataset_dict = dataset_names_dict[0]
all_eval_splits.append("eval")
raw_datasets["eval"] = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir,
token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
)
else:
# load multiple eval sets
for dataset_dict in dataset_names_dict:
if dataset_dict["name"] == "esb/diagnostic-dataset":
# for the ESB diagnostic dataset, the dataset name is effectively the config
pretty_name = f"{dataset_dict['config']}-diagnostic/{dataset_dict['split']}"
else:
pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
all_eval_splits.append(pretty_name)
raw_datasets[pretty_name] = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir,
token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
)
features = raw_datasets[pretty_name].features.keys()
if "text" not in features:
raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
dataset_dict["text_column_name"], "text"
)
raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
set(raw_datasets[pretty_name].features.keys()) - {"audio", "text"}
)
if not training_args.do_train and not training_args.do_eval:
raise ValueError(
"Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
)
raw_datasets_train_features = list(raw_datasets["train"].features.keys())
if data_args.audio_column_name not in raw_datasets_train_features:
raise ValueError(
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset"
f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
" the correct audio column - one of"
f" {', '.join(raw_datasets_train_features)}."
)
if data_args.train_text_column_name not in raw_datasets_train_features:
raise ValueError(
f"--train_text_column_name {data_args.train_text_column_name} not found in dataset"
f" '{data_args.dataset_name}'. Make sure to set `--train_text_column_name` to the"
" correct text column - one of"
f" {', '.join(raw_datasets_train_features)}."
)
# 6. Load pretrained model, tokenizer, and feature extractor
config = WhisperConfig.from_pretrained(
(model_args.config_name if model_args.config_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
)
feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(
(model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
)
tokenizer = WhisperTokenizerFast.from_pretrained(
(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
)
# override timestamp tokens until tokenizer issues are fixed in transformers
timestamps = [AddedToken("<|%.2f|>" % (i * 0.02), lstrip=False, rstrip=False) for i in range(1500 + 1)]
tokenizer.add_tokens(timestamps)
config.update(
{
"activation_dropout": model_args.activation_dropout,
"attention_dropout": model_args.attention_dropout,
"dropout": model_args.dropout,
}
)
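# Select the training precision: `dtype` is the dtype used for the forward pass, while `to_dtype` is the
# cast applied to gradients and optimiser states, so "half_mixed" runs the forward pass in bf16 but keeps
# the backward pass and optimiser states in fp32.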
if training_args.precision == "full_mixed":
# forward pass, backward pass and optimiser states in bf16
dtype = jnp.bfloat16
to_dtype = to_bf16
elif training_args.precision == "half_mixed" or model_args.dtype == "bfloat16":
# forward pass in bf16, backward pass and optimiser states in fp32
dtype = jnp.bfloat16
to_dtype = to_fp32
else:
if training_args.precision != "full":
raise ValueError(
f"`precision` should be one of: `full`, `half_mixed` or `full_mixed`, got {training_args.precision}"
)
# forward pass, backward pass and optimiser states in fp32
dtype = jnp.float32
to_dtype = to_fp32
student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
config=config,
dtype=dtype,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
subfolder=model_args.subfolder,
token=True if model_args.use_auth_token else None,
_do_init=False,
use_scan=model_args.load_with_scan_weights,
)
teacher_model, teacher_params = FlaxWhisperForConditionalGeneration.from_pretrained(
model_args.teacher_model_name_or_path,
# config=config,
dtype=dtype,
cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
_do_init=False,
)
if student_model.config.decoder_start_token_id is None or teacher_model.config.decoder_start_token_id is None:
raise ValueError(
f"Make sure that `config.decoder_start_token_id` is correctly defined for both the "
f"student and teacher model. Got {student_model.config.decoder_start_token_id} for the "
f"student and {teacher_model.config.decoder_start_token_id} for the teacher."
)
# enable scan / gradient checkpointing if necessary
if training_args.use_scan:
student_model.enable_scan() # to enable scan in the nn.Module
student_params = student_model.convert_unroll_to_scan(student_params) # to convert the unrolled params to scan
teacher_model.enable_scan() # faster compile time (even though we don't train the teacher)
teacher_params = teacher_model.convert_unroll_to_scan(teacher_params)
if training_args.gradient_checkpointing:
student_model.enable_gradient_checkpointing() # to enable checkpointing in the nn.Module, there is no change to the params structure
teacher_model.enable_gradient_checkpointing()
if hasattr(teacher_model.generation_config, "is_multilingual") and teacher_model.generation_config.is_multilingual:
# We need to set the language and task ids for multilingual checkpoints - for now we hardcode these to English
tokenizer.set_prefix_tokens(language="English", task="transcribe", predict_timestamps=False)
student_model.generation_config.update(
**{
"language": "<|en|>",
"task": "transcribe",
}
)
# 7. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
# so we just need to set the correct target sampling rate.
raw_datasets = raw_datasets.cast_column(
data_args.audio_column_name,
datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
)
# 8. Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
max_label_length = (
data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
)
audio_column_name = data_args.audio_column_name
num_workers = data_args.preprocessing_num_workers
dataloader_num_workers = training_args.dataloader_num_workers
dataloader_prefetch_size = data_args.prefetch_size
train_text_column_name = data_args.train_text_column_name
eval_text_column_name = "text"
model_input_name = feature_extractor.model_input_names[0]
normalizer = EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
wer_threshold = data_args.wer_threshold
round_timestamps = data_args.round_timestamps
if training_args.do_train and data_args.max_train_samples is not None:
raw_datasets["train"] = (
raw_datasets["train"].take(data_args.max_train_samples)
if data_args.streaming
else raw_datasets["train"].select(range(data_args.max_train_samples))
)
if training_args.do_eval and data_args.max_eval_samples is not None:
for eval_split in all_eval_splits:
raw_datasets[eval_split] = (
raw_datasets[eval_split].take(data_args.max_eval_samples)
if data_args.streaming
else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
)
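# Filter the pseudo-labelled training data by WER: any transcript whose WER against the ground-truth text
# is at or above `wer_threshold` (or whose reference/transcript is missing) is discarded, since such
# samples are likely to be low-quality or hallucinated and would hurt distillation.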
# load the WER metric up-front: the filter below needs it (eagerly, in the non-streaming case) before the
# evaluation metric is loaded further down
metric = evaluate.load("wer")
def is_wer_in_range(ground_truth, whisper_transcript):
norm_ground_truth = normalizer(ground_truth)
if len(norm_ground_truth) > 0 and whisper_transcript is not None:
norm_whisper_transcript = normalizer(whisper_transcript)
wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
return wer < wer_threshold
else:
# filter automatically since we can't know the WER
return False
filter_by_wer_threshold = partial(
raw_datasets["train"].filter,
function=is_wer_in_range,
input_columns=[eval_text_column_name, train_text_column_name],
)
if wer_threshold is not None:
raw_datasets["train"] = (
filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
if not data_args.streaming
else filter_by_wer_threshold()
)
def has_timestamp_tokens(input_str):
"""
Identify whether the input string contains timestamp tokens, of the form <|0.00|>, by searching for
pairs of left and right-angle brackets.
"""
return bool(re.search(r"<[^>]*>", input_str))
def round_timestamp_tokens(input_str: str, ndigits: int = 1):
timestamps = re.findall(r"<[^>]*>", input_str, re.DOTALL)
for token in timestamps:
# extract time digits from timestamp token, e.g. <|6.24|> to 6.24
time_digit = token[2:-2]
# round to specified number of digits, e.g. 6.24 to 6.2
time_digit = round(float(time_digit), ndigits=ndigits)
# replace in original string with the same precision, e.g. <|6.24|> to <|6.20|>
input_str = input_str.replace(token, "<|{:.2f}|>".format(time_digit))
return input_str
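# Train-set preprocessing: the raw audio array is converted to Whisper log-mel input features and the
# pseudo-labelled target text is tokenised to label ids. Prompted targets (starting with
# <|startoftranscript|> or <|startofprev|>) are tokenised verbatim; otherwise timestamp tokens are kept
# (and optionally rounded) with probability `timestamp_probability`, and stripped the rest of the time.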
def prepare_train_dataset(batch):
# process audio input
sample = batch[audio_column_name]
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
batch[model_input_name] = inputs.get(model_input_name)[0]
batch["input_length"] = len(sample["array"])
# process text targets
input_str = batch[train_text_column_name]
# prompt & timestamp processing: for now, we only do one or the other
if input_str.startswith("<|startoftranscript|>") or input_str.startswith("<|startofprev|>"):
# prompted target text already has special ids added, so don't add them here
batch["labels"] = tokenizer(input_str, add_special_tokens=False).input_ids
return batch
has_timestamps = has_timestamp_tokens(input_str)
if has_timestamps:
predict_timestamps = bool(np.random.binomial(1, data_args.timestamp_probability))
if not predict_timestamps:
# filter timestamp token ids if not part of the prediction task
input_str = tokenizer._filter_timestamp_ids(input_str)
elif round_timestamps:
input_str = round_timestamp_tokens(input_str)
else:
predict_timestamps = False
tokenizer.set_prefix_tokens(language="English", task="transcribe", predict_timestamps=predict_timestamps)
input_ids = tokenizer(input_str).input_ids
batch["labels"] = input_ids
return batch
def prepare_eval_dataset(batch):
# process audio
sample = batch[audio_column_name]
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
# process audio length
batch[model_input_name] = inputs.get(model_input_name)[0]
batch["input_length"] = len(sample["array"])
# process targets
input_str = batch[eval_text_column_name]
batch["labels"] = tokenizer(input_str).input_ids
return batch
vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
if training_args.do_train:
map_fn_train = partial(
raw_datasets["train"].map, function=prepare_train_dataset, remove_columns=raw_datasets_train_features
)
vectorized_datasets["train"] = (
map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
if not data_args.streaming
else map_fn_train()
)
if training_args.do_eval:
for eval_split in all_eval_splits:
raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
map_fn_eval = partial(
raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
)
vectorized_datasets[eval_split] = (
map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
if not data_args.streaming
else map_fn_eval()
)
# filter training data with inputs longer than max_input_length
def is_audio_in_length_range(length):
return min_input_length < length < max_input_length
filter_by_audio_fn = partial(
vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
)
vectorized_datasets = (
filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
if not data_args.streaming
else filter_by_audio_fn()
)
# filter training data with labels longer than max_label_length
def is_labels_in_length_range(labels):
return 0 < len(labels) < max_label_length
filter_by_labels_fn = partial(
vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
)
vectorized_datasets = (
filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
if not data_args.streaming
else filter_by_labels_fn()
)
# for large datasets it is advised to run the preprocessing on a
# single machine first with `args.preprocessing_only`, since there will most likely
# be a timeout when running the script in distributed mode.
# In a second step, `args.preprocessing_only` can then be set to `False` to load the
# cached dataset.
if data_args.preprocessing_only:
cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
logger.info(f"Data preprocessing finished. Files cached at {cache}.")
return
# 9. Load Metric
metric = evaluate.load("wer")
# convention is that we space all punctuation *except* apostrophes
all_punctuation = list(string.punctuation.replace("'", ""))
return_timestamps = data_args.return_timestamps if data_args.timestamp_probability > 0 else False
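# compute_metrics reports two WER variants: an orthographic WER over the raw transcriptions (with
# punctuation spaced out, per the ESB convention) and a WER over text normalised with the Whisper English
# normaliser; samples whose normalised reference is empty are dropped from the normalised WER.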
def compute_metrics(preds, labels):
# replace padded labels by the padding token
for idx in range(len(labels)):
labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
# we do not want to group tokens when computing the metrics
label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
# space punctuation for orthographic WER (c.f. ESB paper https://arxiv.org/abs/2210.13352)
# space out *every* punctuation mark in each string, keeping predictions and references aligned one-to-one
spaced_pred_str = [
"".join(f" {char} " if char in all_punctuation else char for char in pred) for pred in pred_str
]
spaced_label_str = [
"".join(f" {char} " if char in all_punctuation else char for char in label) for label in label_str
]
wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)
# normalize everything and re-compute the WER
norm_pred_str = [normalizer(pred) for pred in pred_str]
norm_label_str = [normalizer(label) for label in label_str]
# for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
# filtering step to only evaluate the samples that correspond to non-zero normalized references:
norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
# 10. Save feature extractor, tokenizer, config and generation config
feature_extractor.save_pretrained(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
student_model.generation_config.save_pretrained(
training_args.output_dir
) # generation config stays bound to model to make it easy to jit
processor = WhisperProcessor.from_pretrained(training_args.output_dir)
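# The data collator pads the log-mel input features to the longest example in the batch and the label ids
# to `max_target_length`; it is given the <|startoftranscript|> and <|startofprev|> token ids so it can
# build the decoder inputs from the (possibly prompted) labels.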
data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=student_model.config.decoder_start_token_id, # <|startoftranscript|>
decoder_prev_token_id=tokenizer.all_special_ids[-3], # <|startofprev|>
input_padding="longest",
target_padding="max_length",
max_target_length=max_label_length,
)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
rng, dropout_rng = jax.random.split(rng)
# Store some constants
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
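# Effective batch sizes: the per-device batch size is multiplied by the total device count, and with
# gradient accumulation each optimiser update spans train_batch_size * gradient_accumulation_steps examples.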
if not data_args.streaming and training_args.max_steps < 0:
num_epochs = int(training_args.num_train_epochs)
steps_per_epoch = len(vectorized_datasets["train"]) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
elif training_args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
total_train_steps = int(training_args.max_steps)
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_epochs = sys.maxsize
steps_per_epoch = total_train_steps
else:
raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
if training_args.eval_steps is None:
logger.info(
f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
)
eval_steps = steps_per_epoch
else:
eval_steps = training_args.eval_steps
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
total_train_steps * gradient_accumulation_steps,
training_args.lr_scheduler_type,
training_args.warmup_steps * gradient_accumulation_steps,
training_args.learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = [
"layer_norm",
"self_attn_layer_norm",
"final_layer_norm",
"encoder_attn_layer_norm",
]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
if gradient_accumulation_steps > 1:
# accumulate gradients and apply once every k steps
adamw = optax.MultiSteps(adamw, every_k_schedule=gradient_accumulation_steps)
share_hidden_states = training_args.freeze_encoder and student_model.config.d_model == teacher_model.config.d_model
encoder_layer_mapping = get_layers_to_supervise(
student_model.config.encoder_layers, teacher_model.config.encoder_layers
)
decoder_layer_mapping = get_layers_to_supervise(
student_model.config.decoder_layers, teacher_model.config.decoder_layers
)
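# When the student encoder is frozen and has the same width as the teacher's, the teacher encoder outputs
# can be shared with the student decoder (`share_hidden_states`). The layer mappings pair each student
# encoder/decoder layer with a teacher layer for the hidden-state MSE distillation loss.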
# Setup train state
student_state = TrainState.create(
apply_fn=student_model.decode if share_hidden_states else student_model.__call__,
params=student_params,
tx=adamw,
to_dtype=to_dtype,
dropout_rng=dropout_rng,
max_grad_norm=training_args.max_grad_norm,
)
if training_args.resume_from_checkpoint is not None:
if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
logger.info(
f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
"this behavior, omit the resume_from_checkpoint argument."
)
with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
student_state = from_bytes(student_state, f.read())
else:
logger.warning(
f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
f"you pass the path to a folder with a valid checkpoint for your model."
)
def cross_entropy_loss(logits, labels):
vocab_size = logits.shape[-1]
# the one-hot targets are always returned as a float32 array, so downcast if performing mixed-precision training
onehot_targets = to_dtype(onehot(labels, vocab_size))
loss = optax.softmax_cross_entropy(logits, onehot_targets)
# mask out padded tokens from the loss, i.e. positions where labels are set to -100
padding = labels >= 0
loss = loss * padding
loss = loss.sum()
num_labels = padding.sum()
return loss, num_labels
# temperature smoothed kl-divergence
def kl_divergence(target_distribution, log_predicted_distribution, labels, eps=1e-20):
divergence = -target_distribution * (log_predicted_distribution - jnp.log(target_distribution + eps))
# mask out padded tokens from the divergence, i.e. positions where labels are set to -100
padding_mask = labels >= 0
padding_mask = jnp.expand_dims(padding_mask, axis=-1)
divergence = (divergence * padding_mask).sum()
return to_dtype(divergence) # respect the dtype of the backprop
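# note: the divergence is summed over the vocabulary and the non-padded target positions; normalisation by
# the number of label tokens happens in the train/eval steps below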
def mean_square_error_loss(student_outputs, teacher_outputs):
mse = dtype(0.0)
# tie encoder embeddings
mse += jnp.mean(
jnp.square(teacher_outputs.encoder_hidden_states[0] - student_outputs.encoder_hidden_states[0])
)
for student_layer_id, teacher_layer_id in encoder_layer_mapping.items():
# offset the hidden-state layer ids by 1 to account for the extra embedding hidden-state
student_hidden_state = student_outputs.encoder_hidden_states[student_layer_id + 1]
teacher_hidden_state = teacher_outputs.encoder_hidden_states[teacher_layer_id + 1]
mse += jnp.mean(jnp.square(teacher_hidden_state - student_hidden_state))
# student_attention = student_outputs.encoder_attentions[student_layer_id]
# teacher_attention = teacher_outputs.encoder_attentions[teacher_layer_id]
# mse += jnp.mean(jnp.square(student_attention - teacher_attention))
# tie decoder embeddings
mse += jnp.mean(
jnp.square(teacher_outputs.decoder_hidden_states[0] - student_outputs.decoder_hidden_states[0])
)
for student_layer_id, teacher_layer_id in decoder_layer_mapping.items():
# offset the hidden-state layer ids by 1 to account for the extra embedding hidden-state
student_hidden_state = student_outputs.decoder_hidden_states[student_layer_id + 1]
teacher_hidden_state = teacher_outputs.decoder_hidden_states[teacher_layer_id + 1]
mse += jnp.mean(jnp.square(teacher_hidden_state - student_hidden_state))
# student_attention = student_outputs.decoder_attentions[student_layer_id]
# teacher_attention = teacher_outputs.decoder_attentions[teacher_layer_id]
# mse += jnp.mean(jnp.square(student_attention - teacher_attention))
# student_cross_attention = student_outputs.cross_attentions[student_layer_id]
# teacher_cross_attention = teacher_outputs.cross_attentions[teacher_layer_id]
# mse += jnp.mean(jnp.square(student_cross_attention - teacher_cross_attention))
return to_dtype(mse) # respect the dtype of the backprop
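# The overall distillation objective (see `compute_loss` below) is:
#   loss = ce_weight * CE(student, labels) + kl_weight * T**2 * KL(teacher || student) + mse_weight * MSE(hidden states)
# with ce_weight = 0.8 when a KL term is used (DistilBERT convention) and 1.0 otherwise, and T the softmax temperature.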
# Define gradient update step fn
def train_step(
student_state,
teacher_params,
batch,
freeze_encoder,
share_hidden_states,
temperature=2.0,
):
dropout_rng, new_dropout_rng = jax.random.split(student_state.dropout_rng)
def compute_loss(student_params):
labels = batch.pop("labels")
output_hidden_states = not share_hidden_states and training_args.mse_weight > 0.0
teacher_outputs = teacher_model(
**batch,
params=teacher_params,
freeze_encoder=True,
output_hidden_states=output_hidden_states,
train=False,
)
if share_hidden_states:
# if the student and teacher share the same frozen encoder then we don't have to recompute the
# encoder hidden-states for the student model, we can just re-use from the teacher
encoder_hidden_states = jax.lax.stop_gradient(teacher_outputs.encoder_last_hidden_state)
encoder_outputs = FlaxBaseModelOutput(last_hidden_state=encoder_hidden_states)
student_outputs = student_state.apply_fn(
decoder_input_ids=batch["decoder_input_ids"],
encoder_outputs=encoder_outputs,
params=student_params,
dropout_rng=dropout_rng,
train=True,
)
else:
# do the full forward pass for the student model (encoder + decoder)
student_outputs = student_state.apply_fn(
**batch,
params=student_params,
dropout_rng=dropout_rng,
freeze_encoder=freeze_encoder,
output_hidden_states=output_hidden_states,
train=True,
)
# CE (data) loss
ce_loss, num_labels = cross_entropy_loss(student_outputs.logits, labels)
# rescale by temperature to ensure gradients scale correctly
teacher_distribution = jax.nn.softmax(teacher_outputs.logits / temperature, axis=-1)
# ensure no information flow backwards through teacher
teacher_distribution = jax.lax.stop_gradient(teacher_distribution)
# log softmax of student predictions for numerical stability
student_distribution = jax.nn.log_softmax(student_outputs.logits / temperature, axis=-1)
# KL-divergence loss (scaled by temperature)
kl_loss = kl_divergence(teacher_distribution, student_distribution, labels) * temperature**2
# MSE loss between enc-dec hidden-states and attentions
mse_loss = (
mean_square_error_loss(student_outputs, teacher_outputs)
if output_hidden_states
else jnp.zeros_like(kl_loss)
)
# use DistilBart formulation - only tune the MSE weight and take remaining HPs from DistilBERT
ce_weight = 0.8 if training_args.kl_weight > 0 else 1.0
loss = ce_weight * ce_loss + training_args.kl_weight * kl_loss + training_args.mse_weight * mse_loss
return loss, (
ce_loss,
kl_loss,
mse_loss,
num_labels,
)
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, (ce_loss, kl_loss, mse_loss, num_labels)), grad = grad_fn(to_dtype(student_state.params))
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
num_labels = jax.lax.psum(num_labels, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = student_state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng, to_dtype=to_dtype)
# CE/KL/MSE losses for logging
ce_loss = jax.lax.psum(ce_loss, "batch")
ce_loss = jax.tree_util.tree_map(lambda x: x / num_labels, ce_loss)
kl_loss = jax.lax.psum(kl_loss, "batch")
kl_loss = jax.tree_util.tree_map(lambda x: x / num_labels, kl_loss)
mse_loss = jax.lax.psum(mse_loss, "batch")
mse_loss = jax.tree_util.tree_map(lambda x: x / num_labels, mse_loss)
metrics = {
"loss": loss,
"learning_rate": linear_decay_lr_schedule_fn(student_state.step),
"ce_loss": ce_loss,
"kl_loss": kl_loss,
"mse_loss": mse_loss,
}
return new_state, metrics
# Define eval fn
def eval_step(student_params, teacher_params, batch):
labels = batch.pop("labels")
output_hidden_states = not share_hidden_states and training_args.mse_weight > 0
student_outputs = student_model(
**batch,
params=student_params,
output_hidden_states=output_hidden_states,
train=False,
)
student_distribution = jax.nn.log_softmax(student_outputs.logits, axis=-1)
ce_loss, num_labels = cross_entropy_loss(student_outputs.logits, labels)
teacher_outputs = teacher_model(
**batch,
params=teacher_params,
output_hidden_states=output_hidden_states,
train=False,
)
teacher_distribution = jax.nn.softmax(teacher_outputs.logits, axis=-1)
# temperature is always 1 for eval
kl_loss = kl_divergence(teacher_distribution, student_distribution, labels)
mse_loss = (
mean_square_error_loss(student_outputs, teacher_outputs)
if output_hidden_states
else jnp.zeros_like(kl_loss)
)
ce_weight = 0.8 if training_args.kl_weight > 0 else 1.0
loss = ce_weight * ce_loss + training_args.kl_weight * kl_loss + training_args.mse_weight * mse_loss
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
num_labels = jax.lax.psum(num_labels, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# CE/KL/MSE losses for logging
ce_loss = jax.lax.psum(ce_loss, "batch")
ce_loss = jax.tree_util.tree_map(lambda x: x / num_labels, ce_loss)
kl_loss = jax.lax.psum(kl_loss, "batch")
kl_loss = jax.tree_util.tree_map(lambda x: x / num_labels, kl_loss)
mse_loss = jax.lax.psum(mse_loss, "batch")
mse_loss = jax.tree_util.tree_map(lambda x: x / num_labels, mse_loss)
metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss, "mse_loss": mse_loss}
return metrics
# Define generation function
num_beams = (
training_args.generation_num_beams
if training_args.generation_num_beams is not None
else student_model.config.num_beams
)
# forcing the language and task tokens helps the model in its generations
gen_kwargs = {
"max_length": max_label_length,
"num_beams": num_beams,
"language": "<|en|>",
"task": "transcribe",
"return_timestamps": return_timestamps,
}
def generate_step(student_params, batch):
output_ids = student_model.generate(
batch[model_input_name],
attention_mask=batch.get("attention_mask"),
params=student_params,
**gen_kwargs,
)
return output_ids.sequences
# Replicate the train state on each device
student_state = student_state.replicate()
# Replicate the teacher params on each device
teacher_params = jax_utils.replicate(teacher_params)
# Create parallel version of the train and eval step
p_train_step = jax.pmap(
train_step,
"batch",
in_axes=(0, 0, 0, None, None, None),
donate_argnums=(0,),
static_broadcasted_argnums=(
3,
4,
),
)
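# train_step is mapped over the batch axis: the train state, teacher params and batch are sharded across
# devices, the freeze_encoder / share_hidden_states flags are compile-time constants
# (static_broadcasted_argnums), the temperature is broadcast, and donate_argnums=(0,) lets XLA re-use the
# buffers of the previous train state when building the new one.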
p_eval_step = jax.pmap(eval_step, "batch")
p_generate_step = jax.pmap(generate_step, "batch")
logger.info("***** Running training *****")
logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
logger.info(
f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {total_train_steps}")
# ======================== Training ================================
train_time = 0
train_start = time.time()
train_metrics = []
batches_to_skip = jax.device_get(unreplicate(student_state.step))
cur_step = int(batches_to_skip) # will be zero if starting from scratch
epochs_trained = batches_to_skip // steps_per_epoch
steps_trained_progress_bar = tqdm(range(total_train_steps), desc="Train steps ... ", position=0)
steps_trained_progress_bar.update(batches_to_skip)
continue_training = True
minibatch_steps = 0
if batches_to_skip > 0:
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {batches_to_skip}")
# Generate a training data loader by shuffling sampling indices from the train dataset
train_loader = get_data_loader(
training_args.seed,
vectorized_datasets["train"],
batch_size=train_batch_size,
data_collator=data_collator,
dataloader_num_workers=dataloader_num_workers,
skip_batches=batches_to_skip,
prefetch_size=dataloader_prefetch_size,
)
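# skip_batches fast-forwards the loader past the batches already consumed when resuming from a saved train state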
for epoch in range(epochs_trained, num_epochs):
if hasattr(train_loader, "dataset") and isinstance(train_loader.dataset, IterableDataset):
train_loader.dataset.set_epoch(epoch)
for batch in train_loader:
minibatch_steps += 1
update_step = minibatch_steps == gradient_accumulation_steps
if update_step:
steps_trained_progress_bar.update(1)
cur_step += 1
minibatch_steps = 0
batch = shard(batch.data)
student_state, train_metric = p_train_step(
student_state,
teacher_params,
batch,
training_args.freeze_encoder,
share_hidden_states,
training_args.temperature,
)
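# p_train_step is invoked on every minibatch; with gradient accumulation, optax.MultiSteps accumulates the
# gradients internally and only applies the optimiser update on every `gradient_accumulation_steps`-th call,
# which is what `update_step` / `cur_step` track for logging, saving and evaluation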
if cur_step % training_args.logging_steps == 0 and update_step:
train_metrics.append(train_metric)
train_metric_to_write = unreplicate(train_metric)
steps_trained_progress_bar.write(
f"Step... ({cur_step} / {total_train_steps} | Loss:"
f" {train_metric_to_write['loss']}, Learning Rate:"
f" {train_metric_to_write['learning_rate']})"
)
if has_wandb and jax.process_index() == 0:
write_wandb_metric(
wandb_logger,
train_metric_to_write,
train_time + time.time() - train_start,
cur_step,
epoch,
prefix="train",
)
# save checkpoint and weights after each save_steps and at the end of training
if (cur_step % training_args.save_steps == 0 and update_step) or cur_step == total_train_steps:
if jax.process_index() == 0:
save_hf_weights(
student_state,
student_model,
processor,
training_args.output_dir,
cur_step,
total_train_steps,
use_scan=training_args.use_scan,
)
if training_args.save_train_state:
student_state.save_state(
training_args.output_dir, save_total_limit=training_args.save_total_limit
)
if training_args.push_to_hub:
repo.push_to_hub(
commit_message=f"Saving train state of step {cur_step}",
blocking=False,
)
if training_args.do_eval and (
(cur_step % eval_steps == 0 and update_step) or cur_step == total_train_steps
):
train_time += time.time() - train_start
# ======================== Evaluating ==============================
for eval_split in all_eval_splits:
eval_metrics = []
eval_preds = []
eval_labels = []
eval_start = time.time()
eval_loader = get_data_loader(
training_args.seed,
vectorized_datasets[eval_split],
batch_size=eval_batch_size,
data_collator=data_collator,
shuffle=False,
drop_last=False,
dataloader_num_workers=dataloader_num_workers,
)
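# the eval loader keeps the final, possibly smaller, batch (drop_last=False); pad_shard_unpad pads it up to
# a device-divisible size before the pmapped step and strips the padding from the outputs afterwards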
for batch in tqdm(eval_loader, desc=f"Evaluating {eval_split}...", position=2):
# Model forward
labels = batch["labels"]
metrics = pad_shard_unpad(
p_eval_step,
static_argnums=(
0,
1,
),
static_return=True,
)(
student_state.params,
teacher_params,
batch.data,
min_device_batch=per_device_eval_batch_size,
)
eval_metrics.append(metrics)
# generation
if training_args.predict_with_generate:
generated_ids = pad_shard_unpad(p_generate_step)(
student_state.params, batch.data, min_device_batch=per_device_eval_batch_size
)
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
eval_labels.extend(labels)
eval_time = time.time() - eval_start
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# compute WER metric
wer_desc = ""
if training_args.predict_with_generate:
wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
eval_preds, eval_labels
)
eval_metrics.update(wer_metric)
wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
# Print metrics and update progress bar
steps_trained_progress_bar.write(
f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
f" {wer_desc})"
)
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(
summary_writer,
eval_metrics,
cur_step,
prefix=eval_split,
)
if has_wandb and jax.process_index() == 0:
write_wandb_metric(wandb_logger, eval_metrics, eval_time, cur_step, epoch, prefix=eval_split)
if training_args.predict_with_generate:
write_wandb_pred(
wandb_logger,
pred_str,
label_str,
norm_pred_str,
norm_label_str,
cur_step,
prefix=eval_split,
)
if has_tensorboard and jax.process_index() == 0:
# we'll only log to tensorboard every eval steps
write_train_metric(
summary_writer,
train_metrics,
train_time,
cur_step,
training_args.logging_steps,
)
# flush the train metrics
train_start = time.time()
train_metrics = []
# break condition
if cur_step == total_train_steps:
continue_training = False
break
if not continue_training:
break