# training/utils.py
import os
import re
import shutil
from dataclasses import field
from pathlib import Path
from typing import Dict, List, Optional

import torch
from datasets import Dataset, concatenate_datasets, load_from_disk
from wandb import Audio
def list_field(default=None, metadata=None):
    """Dataclass helper: wrap `default` in a default_factory so mutable defaults are safe."""
    return field(default_factory=lambda: default, metadata=metadata)
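
# Example (illustrative sketch, not part of this file; assumes
# `from dataclasses import dataclass`). `list_field` sidesteps the
# mutable-default error dataclasses raise for plain list defaults:
#
#     @dataclass
#     class DataTrainingArguments:                      # hypothetical arguments class
#         audio_column_names: List[str] = list_field(default=["audio"])
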
# Training checkpoints are named "checkpoint-<step>-epoch-<epoch>"; codec
# (pre-encoded dataset) checkpoints are plain "checkpoint-<step>".
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
CHECKPOINT_CODEC_PREFIX = "checkpoint"
_RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$")
def get_last_checkpoint(folder):
    """Return the most recent "checkpoint-<step>-epoch-<epoch>" directory in `folder`, or None."""
    content = os.listdir(folder)
    checkpoints = [
        path
        for path in content
        if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return None
    return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
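
# Example (illustrative; assumes `./output` was populated by this training script):
#
#     last = get_last_checkpoint("./output")
#     # -> "./output/checkpoint-5000-epoch-2" (highest step), or None if no match
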
def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
"""Helper function to sort saved checkpoints from oldest to newest."""
ordering_and_checkpoint_path = []
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
for path in glob_checkpoints:
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
        if regex_match is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
return checkpoints_sorted
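
# Example (illustrative; same assumed directory layout as above):
#
#     sorted_checkpoints(output_dir="./output")
#     # -> ["./output/checkpoint-1000-epoch-0", "./output/checkpoint-2000-epoch-0", ...]
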
def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None) -> None:
"""Helper function to delete old checkpoints."""
if save_total_limit is None or save_total_limit <= 0:
return
# Check if we should delete older checkpoint(s)
checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
if len(checkpoints_sorted) <= save_total_limit:
return
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        if logger is not None:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
        shutil.rmtree(checkpoint, ignore_errors=True)
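
# Example (illustrative): keep only the three newest checkpoints after each save.
# `logger` is a hypothetical `logging.Logger`; pass None to delete silently.
#
#     rotate_checkpoints(save_total_limit=3, output_dir="./output", logger=logger)
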
def save_codec_checkpoint(output_dir, dataset, step):
    """Save a pre-encoded dataset shard to `<output_dir>/checkpoint-<step>`."""
    checkpoint_path = f"{CHECKPOINT_CODEC_PREFIX}-{step}"
    output_path = os.path.join(output_dir, checkpoint_path)
    dataset.save_to_disk(output_path)


def load_codec_checkpoint(checkpoint_path):
    """Load a dataset shard previously saved with `save_codec_checkpoint`."""
    dataset = load_from_disk(checkpoint_path)
    return dataset
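
# Example (illustrative; `encoded_dataset` is a hypothetical `datasets.Dataset`
# holding pre-computed codec labels):
#
#     save_codec_checkpoint("./codec_cache", encoded_dataset, step=500)
#     # writes ./codec_cache/checkpoint-500
#     shard = load_codec_checkpoint("./codec_cache/checkpoint-500")
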
def sorted_codec_checkpoints(output_dir=None) -> List[str]:
"""Helper function to sort saved checkpoints from oldest to newest."""
ordering_and_checkpoint_path = []
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_CODEC_PREFIX}-*")]
for path in glob_checkpoints:
regex_match = re.match(f".*{CHECKPOINT_CODEC_PREFIX}-([0-9]+)", path)
        if regex_match is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
return checkpoints_sorted
def load_all_codec_checkpoints(output_dir=None) -> Dataset:
    """Helper function to load and concatenate all codec checkpoints, oldest to newest."""
    checkpoints_sorted = sorted_codec_checkpoints(output_dir=output_dir)
    datasets = [load_from_disk(checkpoint) for checkpoint in checkpoints_sorted]
    dataset = concatenate_datasets(datasets, axis=0)
    return dataset
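
# Example (illustrative): rebuild the full pre-encoded dataset from every
# "checkpoint-<step>" shard saved so far, in step order:
#
#     full_dataset = load_all_codec_checkpoints("./codec_cache")
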
def get_last_codec_checkpoint_step(folder) -> int:
    """Return the step of the newest codec checkpoint in `folder`, or 0 if there is none."""
    if not os.path.exists(folder) or not os.path.isdir(folder):
        os.makedirs(folder, exist_ok=True)
        return 0
    content = os.listdir(folder)
    checkpoints = [path for path in content if _RE_CODEC_CHECKPOINT.search(path) is not None]
    if len(checkpoints) == 0:
        return 0
    # The step is the integer captured by the "checkpoint-(\d+)" pattern.
    return max(int(_RE_CODEC_CHECKPOINT.search(path).groups()[0]) for path in checkpoints)
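
# Example (illustrative resume pattern; `iterate_batches` is hypothetical):
#
#     start_step = get_last_codec_checkpoint_step("./codec_cache")  # 0 on a fresh run
#     for step, batch in iterate_batches(start=start_step):
#         ...
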
def log_metric(
accelerator,
metrics: Dict,
train_time: float,
step: int,
epoch: int,
    learning_rate: Optional[float] = None,
prefix: str = "train",
):
"""Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
log_metrics = {}
for k, v in metrics.items():
if "codebook" in k:
log_metrics[f"codebook_{prefix}/{k}"] = v
else:
log_metrics[f"{prefix}/{k}"] = v
log_metrics[f"{prefix}/time"] = train_time
log_metrics[f"{prefix}/epoch"] = epoch
if learning_rate is not None:
log_metrics[f"{prefix}/learning_rate"] = learning_rate
accelerator.log(log_metrics, step=step)
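
# Example (illustrative; `accelerator` is an `accelerate.Accelerator` with a wandb
# tracker, as this file assumes throughout):
#
#     log_metric(
#         accelerator,
#         metrics={"loss": 0.42, "codebook_0_loss": 0.13},
#         train_time=123.4,
#         step=5000,
#         epoch=1,
#         learning_rate=1e-4,
#     )
#     # logs train/loss, codebook_train/codebook_0_loss, train/time, train/epoch,
#     # and train/learning_rate at step 5000
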
def log_pred(
accelerator,
pred_descriptions: List[str],
pred_prompts: List[str],
transcriptions: List[str],
audios: List[torch.Tensor],
    si_sdr_measures: Optional[List[float]],
sampling_rate: int,
step: int,
prefix: str = "eval",
num_lines: int = 200000,
):
"""Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
if accelerator.is_main_process:
wandb_tracker = accelerator.get_tracker("wandb")
# pretty name for current step: step 50000 -> step 50k
cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
prefix_pretty = prefix.replace("/", "-")
        # convert str data to a wandb compatible format and pick the table headers
        if si_sdr_measures is None:
            columns = ["Target descriptions", "Target prompts", "Predicted transcriptions"]
            str_data = [
                [pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))
            ]
        else:
            columns = ["Target descriptions", "Target prompts", "Predicted transcriptions", "Noise estimation"]
            str_data = [
                [pred_descriptions[i], pred_prompts[i], transcriptions[i], si_sdr_measures[i]]
                for i in range(len(pred_descriptions))
            ]
        # log as a single table with the appropriate headers
        wandb_tracker.log_table(
            table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
            columns=columns,
            data=str_data[:num_lines],
            step=step,
            commit=False,
        )
        # wandb can only load 100 audio samples per step
wandb_tracker.log(
{
"Speech samples": [
Audio(
audio,
caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
sample_rate=sampling_rate,
)
for (i, audio) in enumerate(audios[: min(len(audios), 100)])
]
},
step=step,
)
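
# Example (illustrative; dummy one-sample batch of 16 kHz silence):
#
#     log_pred(
#         accelerator,
#         pred_descriptions=["A calm female voice."],
#         pred_prompts=["Hello world."],
#         transcriptions=["hello world"],
#         audios=[torch.zeros(16000)],
#         si_sdr_measures=None,          # omits the "Noise estimation" column
#         sampling_rate=16000,
#         step=5000,
#     )
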