in benchmarks/rnnt/ootb/train/train.py [0:0]
def main():
args = parse_args()
if args.mlperf:
logging.configure_logger('RNNT')
logging.log_start(logging.constants.INIT_START)
if args.fb5logger is not None:
fb5logger = FB5Logger(args.fb5logger)
fb5logger.header("RNN-T", "OOTB", "train", args.fb5config, score_metric=loggerconstants.EXPS)
assert torch.cuda.is_available()
assert args.prediction_frequency is None or args.prediction_frequency % args.log_frequency == 0
torch.backends.cudnn.benchmark = args.cudnn_benchmark
# set up distributed training
multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
if multi_gpu:
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
world_size = dist.get_world_size()
print_once(f'Distributed training with {world_size} GPUs\n')
else:
world_size = 1
if args.seed is not None:
if args.mlperf:
logging.log_event(logging.constants.SEED, value=args.seed)
torch.manual_seed(args.seed + args.local_rank)
np.random.seed(args.seed + args.local_rank)
random.seed(args.seed + args.local_rank)
# np_rng is used for bucket generation and needs the same seed on every worker
np_rng = np.random.default_rng(seed=args.seed)
init_log(args)
cfg = config.load(args.model_config)
config.apply_duration_flags(cfg, args.max_duration)
assert args.grad_accumulation_steps >= 1
assert args.batch_size % args.grad_accumulation_steps == 0, \
f'{args.batch_size} % {args.grad_accumulation_steps} != 0'
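# Per-GPU micro-batch size: DALI yields batches of this size and one optimizer
# step accumulates grad_accumulation_steps of them, so the global batch logged
# below is batch_size * world_size * grad_accumulation_steps.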
batch_size = args.batch_size // args.grad_accumulation_steps
if args.mlperf:
logging.log_event(logging.constants.GRADIENT_ACCUMULATION_STEPS, value=args.grad_accumulation_steps)
logging.log_event(logging.constants.SUBMISSION_BENCHMARK, value=logging.constants.RNNT)
logging.log_event(logging.constants.SUBMISSION_ORG, value='my-organization')
logging.log_event(logging.constants.SUBMISSION_DIVISION, value=logging.constants.CLOSED) # closed or open
logging.log_event(logging.constants.SUBMISSION_STATUS, value=logging.constants.ONPREM) # on-prem/cloud/research
logging.log_event(logging.constants.SUBMISSION_PLATFORM, value='my platform')
logging.log_end(logging.constants.INIT_STOP)
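# Barriers keep the INIT_STOP / RUN_START timestamps roughly aligned across ranks.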
if multi_gpu:
torch.distributed.barrier()
if args.mlperf:
logging.log_start(logging.constants.RUN_START)
if multi_gpu:
torch.distributed.barrier()
print_once('Setting up datasets...')
(
train_dataset_kw,
train_features_kw,
train_splicing_kw,
train_specaugm_kw,
) = config.input(cfg, 'train')
(
val_dataset_kw,
val_features_kw,
val_splicing_kw,
val_specaugm_kw,
) = config.input(cfg, 'val')
if args.mlperf:
logging.log_event(logging.constants.DATA_TRAIN_MAX_DURATION,
value=train_dataset_kw['max_duration'])
logging.log_event(logging.constants.DATA_SPEED_PERTURBATON_MAX,
value=train_dataset_kw['speed_perturbation']['max_rate'])
logging.log_event(logging.constants.DATA_SPEED_PERTURBATON_MIN,
value=train_dataset_kw['speed_perturbation']['min_rate'])
logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_N,
value=train_specaugm_kw['freq_masks'])
logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_MIN,
value=train_specaugm_kw['min_freq'])
logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_MAX,
value=train_specaugm_kw['max_freq'])
logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_N,
value=train_specaugm_kw['time_masks'])
logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_MIN,
value=train_specaugm_kw['min_time'])
logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_MAX,
value=train_specaugm_kw['max_time'])
logging.log_event(logging.constants.GLOBAL_BATCH_SIZE,
value=batch_size * world_size * args.grad_accumulation_steps)
tokenizer_kw = config.tokenizer(cfg)
tokenizer = Tokenizer(**tokenizer_kw)
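# Helper appended to the augmentation pipelines: permutes the feature tensor with
# dims (2, 0, 1) so the last (time) axis leads, and passes the remaining tensors
# (e.g. sequence lengths) through unchanged.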
class PermuteAudio(torch.nn.Module):
def forward(self, x):
return (x[0].permute(2, 0, 1), *x[1:])
train_augmentations = torch.nn.Sequential(
features.SpecAugment(optim_level=args.amp, **train_specaugm_kw) if train_specaugm_kw else torch.nn.Identity(),
features.FrameSplicing(optim_level=args.amp, **train_splicing_kw),
PermuteAudio(),
)
val_augmentations = torch.nn.Sequential(
features.SpecAugment(optim_level=args.amp, **val_specaugm_kw) if val_specaugm_kw else torch.nn.Identity(),
features.FrameSplicing(optim_level=args.amp, **val_splicing_kw),
PermuteAudio(),
)
if args.mlperf:
logging.log_event(logging.constants.DATA_TRAIN_NUM_BUCKETS, value=args.num_buckets)
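# Bucketing batches utterances of similar duration to reduce padding; it is seeded
# with the shared np_rng so every worker builds identical buckets.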
if args.num_buckets is not None:
sampler = dali_sampler.BucketingSampler(
args.num_buckets,
batch_size,
world_size,
args.epochs,
np_rng
)
else:
sampler = dali_sampler.SimpleSampler()
train_loader = DaliDataLoader(gpu_id=args.local_rank,
dataset_path=args.dataset_dir,
config_data=train_dataset_kw,
config_features=train_features_kw,
json_names=args.train_manifests,
batch_size=batch_size,
sampler=sampler,
grad_accumulation_steps=args.grad_accumulation_steps,
pipeline_type="train",
device_type=args.dali_device,
tokenizer=tokenizer)
val_loader = DaliDataLoader(gpu_id=args.local_rank,
dataset_path=args.dataset_dir,
config_data=val_dataset_kw,
config_features=val_features_kw,
json_names=args.val_manifests,
batch_size=args.val_batch_size,
sampler=dali_sampler.SimpleSampler(),
pipeline_type="val",
device_type=args.dali_device,
tokenizer=tokenizer)
train_feat_proc = train_augmentations
val_feat_proc = val_augmentations
train_feat_proc.cuda()
val_feat_proc.cuda()
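# len(train_loader) counts micro-batches; one optimizer step spans
# grad_accumulation_steps of them.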
steps_per_epoch = len(train_loader) // args.grad_accumulation_steps
if args.mlperf:
logging.log_event(logging.constants.TRAIN_SAMPLES, value=train_loader.dataset_size)
logging.log_event(logging.constants.EVAL_SAMPLES, value=val_loader.dataset_size)
# set up the model
rnnt_config = config.rnnt(cfg)
rnnt_config['mlperf'] = args.mlperf
if args.mlperf:
logging.log_event(logging.constants.MODEL_WEIGHTS_INITIALIZATION_SCALE, value=args.weights_init_scale)
if args.weights_init_scale is not None:
rnnt_config['weights_init_scale'] = args.weights_init_scale
if args.hidden_hidden_bias_scale is not None:
rnnt_config['hidden_hidden_bias_scale'] = args.hidden_hidden_bias_scale
model = RNNT(n_classes=tokenizer.num_labels + 1, **rnnt_config)
model.cuda()
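# The blank symbol is appended after the tokenizer's labels, hence
# n_classes = num_labels + 1 above and blank_idx = num_labels below.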
blank_idx = tokenizer.num_labels
loss_fn = RNNTLoss(blank_idx=blank_idx)
if args.mlperf:
logging.log_event(logging.constants.EVAL_MAX_PREDICTION_SYMBOLS, value=args.max_symbol_per_sample)
greedy_decoder = RNNTGreedyDecoder(blank_idx=blank_idx,
max_symbol_per_sample=args.max_symbol_per_sample)
print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')
opt_eps = 1e-9
if args.mlperf:
logging.log_event(logging.constants.OPT_NAME, value='lamb')
logging.log_event(logging.constants.OPT_BASE_LR, value=args.lr)
logging.log_event(logging.constants.OPT_LAMB_EPSILON, value=opt_eps)
logging.log_event(logging.constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=args.lr_exp_gamma)
logging.log_event(logging.constants.OPT_LR_WARMUP_EPOCHS, value=args.warmup_epochs)
logging.log_event(logging.constants.OPT_LAMB_LR_HOLD_EPOCHS, value=args.hold_epochs)
logging.log_event(logging.constants.OPT_LAMB_BETA_1, value=args.beta1)
logging.log_event(logging.constants.OPT_LAMB_BETA_2, value=args.beta2)
logging.log_event(logging.constants.OPT_GRADIENT_CLIP_NORM, value=args.clip_norm)
logging.log_event(logging.constants.OPT_LR_ALT_DECAY_FUNC, value=True)
logging.log_event(logging.constants.OPT_LR_ALT_WARMUP_FUNC, value=True)
logging.log_event(logging.constants.OPT_LAMB_LR_MIN, value=args.min_lr)
logging.log_event(logging.constants.OPT_WEIGHT_DECAY, value=args.weight_decay)
# optimization
kw = {'params': model.param_groups(args.lr), 'lr': args.lr,
'weight_decay': args.weight_decay}
initial_lrs = [group['lr'] for group in kw['params']]
print_once(f'Starting with LRs: {initial_lrs}')
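# FusedLAMB (NVIDIA apex) runs LAMB with fused CUDA kernels; max_grad_norm makes it
# clip gradients by global norm (logged above as OPT_GRADIENT_CLIP_NORM).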
optimizer = FusedLAMB(betas=(args.beta1, args.beta2), eps=opt_eps, max_grad_norm=args.clip_norm, **kw)
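# lr_policy ramps the LR up over warmup_epochs, holds it for hold_epochs, then
# decays it (governed by lr_exp_gamma) toward min_lr.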
adjust_lr = lambda step, epoch: lr_policy(
step, epoch, initial_lrs, optimizer, steps_per_epoch=steps_per_epoch,
warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
min_lr=args.min_lr, exp_gamma=args.lr_exp_gamma)
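# apex AMP opt_level O1 casts whitelisted ops to FP16 and applies dynamic loss
# scaling, capped here at 512.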
if args.amp:
model, optimizer = amp.initialize(
models=model,
optimizers=optimizer,
opt_level='O1',
max_loss_scale=512.0)
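# When --ema > 0, keep an exponential-moving-average copy of the weights;
# apply_ema refreshes it after every optimizer step and evaluation uses it.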
if args.ema > 0:
ema_model = copy.deepcopy(model).cuda()
else:
ema_model = None
if args.mlperf:
logging.log_event(logging.constants.MODEL_EVAL_EMA_FACTOR, value=args.ema)
if multi_gpu:
model = DistributedDataParallel(model)
# load checkpoint
meta = {'best_wer': 10**6, 'start_epoch': 0}
checkpointer = Checkpointer(args.output_dir, 'RNN-T',
args.keep_milestones, args.amp)
if args.resume:
args.ckpt = checkpointer.last_checkpoint() or args.ckpt
if args.ckpt is not None:
checkpointer.load(args.ckpt, model, ema_model, optimizer, meta)
start_epoch = meta['start_epoch']
best_wer = meta['best_wer']
last_wer = meta['best_wer']
epoch = 1
step = start_epoch * steps_per_epoch + 1
# FB5: log training for a bounded amount of wall-clock time (MAX_TIME seconds); the loops below break out early once that budget is exhausted.
if args.fb5logger is not None:
fb5logger.run_start()
total_batches = 0
start_time = time.time()
MAX_TIME = 120.0
# training loop
model.train()
for epoch in range(start_epoch + 1, args.epochs + 1):
if args.mlperf:
logging.log_start(logging.constants.BLOCK_START,
metadata=dict(first_epoch_num=epoch,
epoch_count=1))
logging.log_start(logging.constants.EPOCH_START,
metadata=dict(epoch_num=epoch))
epoch_utts = 0
accumulated_batches = 0
epoch_start_time = time.time()
for batch in train_loader:
if accumulated_batches == 0:
adjust_lr(step, epoch)
optimizer.zero_grad()
step_utts = 0
step_start_time = time.time()
all_feat_lens = []
audio, audio_lens, txt, txt_lens = batch
feats, feat_lens = train_feat_proc([audio, audio_lens])
all_feat_lens += feat_lens
log_probs, log_prob_lens = model(feats, feat_lens, txt, txt_lens)
loss = loss_fn(log_probs[:, :log_prob_lens.max().item()],
log_prob_lens, txt, txt_lens)
loss /= args.grad_accumulation_steps
del log_probs, log_prob_lens
if torch.isnan(loss).any():
print_once('WARNING: loss is NaN; skipping update')
else:
if args.amp:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
loss_item = loss.item()
del loss
step_utts += batch[0].size(0) * world_size
epoch_utts += batch[0].size(0) * world_size
accumulated_batches += 1
total_batches += 1
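# Once grad_accumulation_steps micro-batches have contributed gradients, optionally
# log the gradient norm, take the optimizer step, and update the EMA weights.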
if accumulated_batches % args.grad_accumulation_steps == 0:
total_norm = 0.0
try:
if args.log_norm:
for p in getattr(model, 'module', model).parameters():
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** (1. / 2)
except AttributeError as e:
print_once(f'Exception happened: {e}')
total_norm = 0.0
optimizer.step()
apply_ema(model, ema_model, args.ema)
if step % args.log_frequency == 0 or (time.time() - start_time) > MAX_TIME:
if args.prediction_frequency is None or step % args.prediction_frequency == 0:
preds = greedy_decoder.decode(model, feats, feat_lens)
wer, pred_utt, ref = greedy_wer(preds,
txt,
txt_lens,
tokenizer.detokenize)
print_once(f' Decoded: {pred_utt[:90]}')
print_once(f' Reference: {ref[:90]}')
wer = {'wer': 100 * wer}
else:
wer = {}
step_time = time.time() - step_start_time
log((epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
step, 'train',
{'loss': loss_item,
**wer, # optional entry
'throughput': step_utts / step_time,
'took': step_time,
'grad-norm': total_norm,
'seq-len-min': min(all_feat_lens).item(),
'seq-len-max': max(all_feat_lens).item(),
'lrate': optimizer.param_groups[0]['lr']})
# FB5: stop mid-epoch once the wall-clock budget is exhausted.
if (time.time() - start_time) > MAX_TIME:
break
step_start_time = time.time()
step += 1
accumulated_batches = 0
# end of step
if args.mlperf:
logging.log_end(logging.constants.EPOCH_STOP,
metadata=dict(epoch_num=epoch))
epoch_time = time.time() - epoch_start_time
log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
'took': epoch_time})
# FB5: stop training once the wall-clock budget is exhausted.
if (time.time() - start_time) > MAX_TIME:
break
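# Periodic dev-set evaluation: the WER of the EMA model drives best-checkpoint
# selection and the early-stop-on-target check below.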
if epoch % args.val_frequency == 0:
wer = evaluate(epoch, step, val_loader, val_feat_proc,
tokenizer.detokenize, ema_model, loss_fn,
greedy_decoder, args.amp, args)
last_wer = wer
if wer < best_wer and epoch >= args.save_best_from:
checkpointer.save(model, ema_model, optimizer, epoch,
step, best_wer, is_best=True)
best_wer = wer
save_this_epoch = (args.save_frequency is not None and epoch % args.save_frequency == 0) \
or (epoch in args.keep_milestones)
if save_this_epoch:
checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
if args.mlperf:
logging.log_end(logging.constants.BLOCK_STOP, metadata=dict(first_epoch_num=epoch))
if last_wer <= args.target:
if args.mlperf:
logging.log_end(logging.constants.RUN_STOP, metadata={'status': 'success'})
if args.fb5logger is not None:
fb5logger.run_stop(total_batches, args.batch_size)
print_once(f'Reached WER target of {args.target} after {epoch} epochs.')
break
if 0 < args.epochs_this_job <= epoch - start_epoch:
print_once(f'Finished after {args.epochs_this_job} epochs.')
break
# end of epoch
log((), None, 'train_avg', {'throughput': epoch_utts / epoch_time})
if last_wer > args.target:
if args.mlperf:
logging.log_end(logging.constants.RUN_STOP, metadata={'status': 'aborted'})
if args.fb5logger is not None:
fb5logger.run_stop(total_batches, args.batch_size)
if epoch == args.epochs:
evaluate(epoch, step, val_loader, val_feat_proc, tokenizer.detokenize,
ema_model, loss_fn, greedy_decoder, args.amp, args)
flush_log()
if args.save_at_the_end:
checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)