in models/nlp/electra/run_pretraining.py [0:0]
def main():
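    """Pretrain an ELECTRA generator/discriminator pair from TFRecord shards.

    Horovod drives data-parallel training; rank 0 handles logging, validation,
    Weights & Biases / TensorBoard reporting, and checkpointing.
    """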
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."
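    # Initialize Horovod and pin each worker process to its local GPU.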
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    if train_args.eager == "true":
        # Run tf.functions eagerly for easier debugging
        # (newer TF releases expose this as tf.config.run_functions_eagerly).
        tf.config.experimental_run_functions_eagerly(True)
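    # ELECTRA pairs a small masked-LM generator with a discriminator trained to
    # detect replaced tokens; both models are freshly initialized from the
    # google/electra-* configs (no pretrained weights are loaded here).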
    tokenizer = ElectraTokenizerFast.from_pretrained("bert-base-uncased")
    gen_config = ElectraConfig.from_pretrained(f"google/electra-{model_args.model_size}-generator")
    dis_config = ElectraConfig.from_pretrained(
        f"google/electra-{model_args.model_size}-discriminator"
    )
    gen = TFElectraForMaskedLM(config=gen_config)
    dis = TFElectraForPreTraining(config=dis_config)
    optimizer = get_adamw_optimizer(train_args)
    # Tie the generator and discriminator token embeddings, as in the original ELECTRA setup
    if model_args.electra_tie_weights == "true":
        gen.electra.embeddings = dis.electra.embeddings
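    # Optionally resume from an earlier run: only rank 0 reads the checkpoint
    # files here; variables are broadcast to the other ranks on the first step.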
    loaded_optimizer_weights = None
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path)
        dis_ckpt, gen_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path)
        if hvd.rank() == 0:
            dis.load_weights(dis_ckpt)
            gen.load_weights(gen_ckpt)
            loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True)
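    # Wall-clock timer for throughput logging; log setup itself is rank-0 only.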
    start_time = time.perf_counter()
    if hvd.rank() == 0:
        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            TqdmLoggingHandler(),
        ]
        summary_writer = None  # Only create a writer if we make it through a successful step
        logging.basicConfig(level=level, format=format, handlers=handlers)
        wandb_run_name = None
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if log_args.run_name is None:
            metadata = (
                f"electra-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.total_steps}steps"
            )
            run_name = (
                f"{current_time}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
            )
        else:
            run_name = log_args.run_name
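    # Locate the sharded TFRecord files and build the training input pipeline.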
logger.info(f"Training with dataset at {path_args.train_dir}")
logger.info(f"Validating with dataset at {path_args.val_dir}")
train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord*")
validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord*")
train_filenames = glob.glob(train_glob)
validation_filenames = glob.glob(validation_glob)
logger.info(
f"Number of train files {len(train_filenames)}, number of validation files {len(validation_filenames)}"
)
tf_train_dataset = get_dataset_from_tfrecords(
model_type=model_args.model_type,
filenames=train_filenames,
per_gpu_batch_size=train_args.per_gpu_batch_size,
max_seq_length=data_args.max_seq_length,
)
tf_train_dataset = tf_train_dataset.prefetch(buffer_size=8)
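    # Validation runs only on rank 0, so only that worker builds the val pipeline.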
    if hvd.rank() == 0:
        tf_val_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
            max_seq_length=data_args.max_seq_length,
        )
        tf_val_dataset = tf_val_dataset.prefetch(buffer_size=8)
    wandb_run_name = None
    step = 1
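    # Main training loop: `step` starts at 1 and the loop exits once
    # train_args.total_steps is reached (is_final_step below).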
    for batch in tf_train_dataset:
        learning_rate = optimizer.learning_rate(step=tf.constant(step, dtype=tf.float32))
        ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        train_result = train_step(
            optimizer=optimizer,
            gen=gen,
            dis=dis,
            ids=ids,
            attention_mask=attention_mask,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
            max_seq_length=data_args.max_seq_length,
            mask_token_id=tokenizer.mask_token_id,
        )
        if step == 1:
            # Horovod broadcast: after the first step, restore the optimizer state on
            # rank 0 (if resuming) and sync model/optimizer variables to all ranks.
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(gen.variables, root_rank=0)
            hvd.broadcast_variables(dis.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            step = optimizer.get_weights()[0]  # first optimizer weight is the iteration count, so a resumed run keeps its step
        is_final_step = step >= train_args.total_steps
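        # Everything below (logging, validation, wandb/TensorBoard, checkpoints)
        # runs on rank 0 only.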
        if hvd.rank() == 0:
            do_log = step % log_args.log_frequency == 0
            do_checkpoint = (step > 1) and (
                (step % log_args.checkpoint_frequency == 0) or is_final_step
            )
            do_validation = step % log_args.validation_frequency == 0
            if do_log:
                elapsed_time = time.perf_counter() - start_time  # Off for the first log, since start_time was set before the loop
                it_s = log_args.log_frequency / elapsed_time
                start_time = time.perf_counter()
                description = f"Step {step} -- gen_loss: {train_result.gen_loss:.3f}, dis_loss: {train_result.dis_loss:.3f}, gen_acc: {train_result.gen_acc:.3f}, dis_acc: {train_result.dis_acc:.3f}, it/s: {it_s:.3f}\n"
                logger.info(description)
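            # Periodically evaluate on a single validation batch and log a decoded example.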
            if do_validation:
                for val_batch in tf_val_dataset.take(1):
                    val_ids = val_batch["input_ids"]
                    val_attention_mask = val_batch["attention_mask"]
                    val_result = val_step(
                        gen=gen,
                        dis=dis,
                        ids=val_ids,
                        attention_mask=val_attention_mask,
                        per_gpu_batch_size=train_args.per_gpu_batch_size,
                        max_seq_length=data_args.max_seq_length,
                        mask_token_id=tokenizer.mask_token_id,
                    )
                    log_example(
                        tokenizer,
                        val_ids,
                        val_result.masked_ids,
                        val_result.corruption_mask,
                        val_result.gen_ids,
                        val_result.dis_preds,
                    )
                    description = f"VALIDATION, Step {step} -- val_gen_loss: {val_result.gen_loss:.3f}, val_dis_loss: {val_result.dis_loss:.3f}, val_gen_acc: {val_result.gen_acc:.3f}, val_dis_acc: {val_result.dis_acc:.3f}\n"
                    logger.info(description)
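            # Collect scalar metrics for this step; throughput and validation
            # metrics are merged in only when they were computed above.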
            train_metrics = {
                "learning_rate": learning_rate,
                "train/loss": train_result.loss,
                "train/gen_loss": train_result.gen_loss,
                "train/dis_loss": train_result.dis_loss,
                "train/gen_acc": train_result.gen_acc,
                "train/dis_acc": train_result.dis_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_result.loss,
                    "val/gen_loss": val_result.gen_loss,
                    "val/dis_loss": val_result.dis_loss,
                    "val/gen_acc": val_result.gen_acc,
                    "val/dis_acc": val_result.dis_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_log:
                all_metrics = {"it_s": it_s, **all_metrics}
            if is_wandb_available():
                if wandb_run_name is None:
                    config = {
                        **asdict(model_args),
                        **asdict(data_args),
                        **asdict(train_args),
                        **asdict(log_args),
                        **asdict(path_args),
                        "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                        "n_gpus": hvd.size(),
                    }
                    wandb.init(config=config, project="electra")
                    wandb.run.save()
                    wandb_run_name = wandb.run.name
                wandb.log({"step": step, **all_metrics})
            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    **asdict(path_args),
                    "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                    "n_gpus": hvd.size(),
                }
            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=step)
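            # Periodically checkpoint: discriminator and generator weights via
            # save_weights, optimizer weights via np.save (reloaded above with
            # allow_pickle=True).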
            if do_checkpoint:
                dis_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-discriminator.ckpt",
                )
                gen_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-generator.ckpt",
                )
                optimizer_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-optimizer.npy",
                )
                logger.info(
                    f"Saving discriminator model at {dis_model_ckpt}, generator model at {gen_model_ckpt}, optimizer at {optimizer_ckpt}"
                )
                dis.save_weights(dis_model_ckpt)
                gen.save_weights(gen_model_ckpt)
                np.save(optimizer_ckpt, optimizer.get_weights())
        step += 1
        if is_final_step:
            break