# models/nlp/albert/run_pretraining.py
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."
    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
    parse_bool = lambda arg: arg == "true"

    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    is_sagemaker = path_args.filesystem_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
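    # max_grad_norm is exposed as a module-level global; presumably the gradient-clipping
    # logic inside train_step reads it rather than taking it as an argument.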
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
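    # Pin each Horovod process to a single GPU, indexed by its local rank on the node.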
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")

    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)
    if hvd.rank() == 0:
        # The run name should only be generated on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        if log_args.run_name is None:
            metadata = (
                f"{model_args.model_type}"
                f"-{model_args.model_size}"
                f"-{model_args.load_from}"
                f"-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.learning_rate}maxlr"
                f"-{train_args.learning_rate_decay_power}power"
                f"-{train_args.optimizer}opt"
                f"-{train_args.total_steps}steps"
                f"-{'preln' if pre_layer_norm else 'postln'}"
                f"{loss_str}"
                f"-{model_args.hidden_dropout_prob}dropout"
            )
            run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
        else:
            run_name = log_args.run_name
        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(
                os.path.join(path_args.filesystem_prefix, path_args.log_dir, f"{run_name}.log")
            ),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments were passed in properly, only after registering the alert_func and logging
        assert not (skip_sop and skip_mlm), "Cannot use --skip_sop and --skip_mlm"
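    # wrap_global_functions presumably re-wraps the tf.function-decorated train/validation
    # steps for the chosen gradient-accumulation mode; it is called again after each SQuAD
    # evaluation below to avoid autograph argument mismatches.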
    wrap_global_functions(do_gradient_accumulation)

    # Create optimizer and enable AMP loss scaling.
    if train_args.optimizer == "lamb":
        optimizer = get_lamb_optimizer(train_args)
    elif train_args.optimizer == "adamw":
        optimizer = get_adamw_optimizer(train_args)
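    # The mixed-precision graph rewrite wraps the optimizer with dynamic loss scaling;
    # the current scale is read from optimizer.loss_scale inside the training loop below.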
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale="dynamic"
    )
    gradient_accumulator = GradientAccumulator()

    loaded_optimizer_weights = None

    model = create_model(model_class=TFAutoModelForPreTraining, model_args=model_args)
    tokenizer = create_tokenizer(model_args.model_type)
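    # When resuming from a checkpoint, only rank 0 loads the weights here; they are
    # broadcast to the other ranks after the first training step via hvd.broadcast_variables.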
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path)
        model_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path)
        if hvd.rank() == 0:
            model.load_weights(model_ckpt)
            if model_args.load_optimizer_state == "true":
                loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True)
                # We do not set the weights yet; we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], val filenames are [0]. Note the different subdirectories.
    # TODO: Move to the same folder structure and remove the if/else.
    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord")
    validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)
    train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
    )  # Of shape [per_gpu_batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, per_gpu_batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)

    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)
    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)
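        # The progress bar, summary writer, and run_name exist only on rank 0; every later
        # use of them is guarded by an hvd.rank() == 0 check.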
        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")
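    # Main training loop. `i` counts optimizer steps; after the first step it is reset from
    # optimizer.get_weights()[0] in the `if i == 1` block below, presumably so that a run
    # resumed from a checkpoint continues from its saved step count.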
    i = 1
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = optimizer.learning_rate(step=tf.constant(i, dtype=tf.float32))
        # weight_decay = wd_schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = optimizer.loss_scale
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            optimizer=optimizer,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )
        # Don't want to wrap broadcast_variables() in a tf.function; it can lead to asynchronous errors
        if i == 1:
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            i = optimizer.get_weights()[0]

        is_final_step = i >= train_args.total_steps
        do_squad = (log_args.squad_frequency != 0) and (
            (i % log_args.squad_frequency == 0) or is_final_step
        )
        # SQuAD requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                tokenizer=tokenizer,
                model_size=model_args.model_size,
                filesystem_prefix=path_args.filesystem_prefix,
                step=i,
                dataset=data_args.squad_version,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results["f1"]
                logger.info(f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}")
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)
            gc.collect()
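        # All logging, checkpointing, validation, and metric reporting happen on rank 0 only.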
        if hvd.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (log_args.checkpoint_frequency != 0) and (
                (i % log_args.checkpoint_frequency == 0) or is_final_step
            )
            do_validation = (log_args.validation_frequency != 0) and (
                (i % log_args.validation_frequency == 0) or is_final_step
            )

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 1:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}")
                start_time = time.perf_counter()
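            # Periodically save the model weights and the optimizer state so that
            # pretraining can later be resumed (see the `load_from == "checkpoint"` branch above).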
            if do_checkpoint:
                checkpoint_prefix = os.path.join(
                    path_args.filesystem_prefix, path_args.checkpoint_dir, f"{run_name}-step{i}"
                )
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                optimizer_ckpt = f"{checkpoint_prefix}-optimizer.npy"
                logger.info(f"Saving model at {model_ckpt}, optimizer at {optimizer_ckpt}")
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                optimizer_weights = optimizer.get_weights()
                np.save(optimizer_ckpt, optimizer_weights)
                # optimizer.set_weights(optimizer_weights)
            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")
            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                }
                if is_wandb_available():
                    wandb.init(config=config, project=model_args.model_type)
                    wandb.run.save()
                    wandb_run_name = wandb.run.name
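            # Gather scalar metrics for this step; validation and SQuAD metrics are merged
            # in only when they were computed on this iteration.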
            train_metrics = {
                "weight_norm": weight_norm,
                "grad_norm": grad_norm,
                "loss_scale": loss_scale,
                "learning_rate": learning_rate,
                "train/loss": loss,
                "train/mlm_loss": mlm_loss,
                "train/mlm_acc": mlm_acc,
                "train/sop_loss": sop_loss,
                "train/sop_acc": sop_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_loss,
                    "val/mlm_loss": val_mlm_loss,
                    "val/mlm_acc": val_mlm_acc,
                    "val/sop_loss": val_sop_loss,
                    "val/sop_acc": val_sop_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_squad:
                squad_metrics = {
                    "squad/f1": squad_f1,
                    "squad/exact": squad_exact,
                }
                all_metrics = {**all_metrics, **squad_metrics}
            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=i)
            # Log to Weights & Biases
            if is_wandb_available():
                wandb.log({"step": i, **all_metrics})

        i += 1
        if is_final_step:
            break
    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")