def main()

in models/nlp/electra/run_pretraining.py

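This entry point wires together ELECTRA pre-training: argument parsing into
dataclasses, Horovod data parallelism, generator/discriminator construction,
the training loop, and rank-0 logging, validation, and checkpointing.

A sketch of a typical launch, assuming one process per GPU under Horovod
(flag names are inferred from the dataclass fields referenced below and may
differ from the actual argument definitions):

    horovodrun -np 8 python run_pretraining.py \
        --model_type electra \
        --model_size small \
        --per_gpu_batch_size 32 \
        --total_steps 1000000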

def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."

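    # Initialize Horovod and pin each worker to a single GPU; memory growth
    # keeps TensorFlow from reserving all GPU memory at startup.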
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    if train_args.eager == "true":
        tf.config.experimental_run_functions_eagerly(True)

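    # ELECTRA shares BERT's WordPiece vocabulary, so the fast tokenizer is
    # loaded from the bert-base-uncased checkpoint.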
    tokenizer = ElectraTokenizerFast.from_pretrained("bert-base-uncased")

    gen_config = ElectraConfig.from_pretrained(f"google/electra-{model_args.model_size}-generator")
    dis_config = ElectraConfig.from_pretrained(
        f"google/electra-{model_args.model_size}-discriminator"
    )

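    # Build the generator (masked LM) and discriminator (replaced-token
    # detection) with fresh weights; only the configs come from the hub.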
    gen = TFElectraForMaskedLM(config=gen_config)
    dis = TFElectraForPreTraining(config=dis_config)
    optimizer = get_adamw_optimizer(train_args)

    # Tie the weights
    if model_args.electra_tie_weights == "true":
        gen.electra.embeddings = dis.electra.embeddings

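    # Optionally resume from a checkpoint. Only rank 0 loads weights here; they
    # are broadcast to the other workers after the first training step.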
    loaded_optimizer_weights = None
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path)
        dis_ckpt, gen_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path)
        if hvd.rank() == 0:
            dis.load_weights(dis_ckpt)
            gen.load_weights(gen_ckpt)
            loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True)

    start_time = time.perf_counter()

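    # Rank 0 owns the Python logger, the run name, and (later) the TensorBoard
    # writer and W&B run.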
    if hvd.rank() == 0:
        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            TqdmLoggingHandler(),
        ]
        summary_writer = None  # Only create a writer if we make it through a successful step
        logging.basicConfig(level=level, format=format, handlers=handlers)

        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if log_args.run_name is None:
            metadata = (
                f"electra-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.total_steps}steps"
            )
            run_name = (
                f"{current_time}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
            )
        else:
            run_name = log_args.run_name

    logger.info(f"Training with dataset at {path_args.train_dir}")
    logger.info(f"Validating with dataset at {path_args.val_dir}")

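    # Every worker resolves the full list of TFRecord shards from the
    # prefixed train/validation directories.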
    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord*")
    validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord*")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)
    logger.info(
        f"Number of train files {len(train_filenames)}, number of validation files {len(validation_filenames)}"
    )

    tf_train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
        max_seq_length=data_args.max_seq_length,
    )

    tf_train_dataset = tf_train_dataset.prefetch(buffer_size=8)

    if hvd.rank() == 0:
        tf_val_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
            max_seq_length=data_args.max_seq_length,
        )
        tf_val_dataset = tf_val_dataset.prefetch(buffer_size=8)

    wandb_run_name = None

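    # Training loop: one iteration per batch. `step` starts at 1 and is
    # re-synchronized with the optimizer's iteration counter after the first step.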
    step = 1
    for batch in tf_train_dataset:
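        # Evaluate the LR schedule at the current step; the value is only used
        # for logging, since the optimizer applies the schedule internally.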
        learning_rate = optimizer.learning_rate(step=tf.constant(step, dtype=tf.float32))
        ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        train_result = train_step(
            optimizer=optimizer,
            gen=gen,
            dis=dis,
            ids=ids,
            attention_mask=attention_mask,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
            max_seq_length=data_args.max_seq_length,
            mask_token_id=tokenizer.mask_token_id,
        )

        if step == 1:
            # After the first step has created all variables, broadcast model and
            # optimizer state from rank 0 so every worker starts identically
            # (including any checkpoint weights loaded on rank 0 above)
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(gen.variables, root_rank=0)
            hvd.broadcast_variables(dis.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            # get_weights()[0] is the optimizer's iteration counter, so a resumed
            # run continues counting from the checkpointed step
            step = optimizer.get_weights()[0]

        is_final_step = step >= train_args.total_steps
        if hvd.rank() == 0:
            do_log = step % log_args.log_frequency == 0
            do_checkpoint = (step > 1) and (
                (step % log_args.checkpoint_frequency == 0) or is_final_step
            )
            do_validation = step % log_args.validation_frequency == 0

            if do_log:
                elapsed_time = time.perf_counter() - start_time  # First interval also includes startup time
                it_s = log_args.log_frequency / elapsed_time
                start_time = time.perf_counter()
                description = f"Step {step} -- gen_loss: {train_result.gen_loss:.3f}, dis_loss: {train_result.dis_loss:.3f}, gen_acc: {train_result.gen_acc:.3f}, dis_acc: {train_result.dis_acc:.3f}, it/s: {it_s:.3f}\n"
                logger.info(description)

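            # Validate on a single held-out batch and log one decoded example of
            # generator replacements and discriminator predictions.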
            if do_validation:
                for val_batch in tf_val_dataset.take(1):
                    val_ids = val_batch["input_ids"]
                    val_attention_mask = val_batch["attention_mask"]
                    val_result = val_step(
                        gen=gen,
                        dis=dis,
                        ids=val_ids,
                        attention_mask=val_attention_mask,
                        per_gpu_batch_size=train_args.per_gpu_batch_size,
                        max_seq_length=data_args.max_seq_length,
                        mask_token_id=tokenizer.mask_token_id,
                    )
                    log_example(
                        tokenizer,
                        val_ids,
                        val_result.masked_ids,
                        val_result.corruption_mask,
                        val_result.gen_ids,
                        val_result.dis_preds,
                    )
                    description = f"VALIDATION, Step {step} -- val_gen_loss: {val_result.gen_loss:.3f}, val_dis_loss: {val_result.dis_loss:.3f}, val_gen_acc: {val_result.gen_acc:.3f}, val_dis_acc: {val_result.dis_acc:.3f}\n"
                    logger.info(description)

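            # Gather scalar metrics; validation and throughput entries are only
            # added on steps where they were computed.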
            train_metrics = {
                "learning_rate": learning_rate,
                "train/loss": train_result.loss,
                "train/gen_loss": train_result.gen_loss,
                "train/dis_loss": train_result.dis_loss,
                "train/gen_acc": train_result.gen_acc,
                "train/dis_acc": train_result.dis_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_result.loss,
                    "val/gen_loss": val_result.gen_loss,
                    "val/dis_loss": val_result.dis_loss,
                    "val/gen_acc": val_result.gen_acc,
                    "val/dis_acc": val_result.dis_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_log:
                all_metrics = {"it_s": it_s, **all_metrics}

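            # Initialize Weights & Biases lazily on the first logged step, using
            # the full argument set as the run config.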
            if is_wandb_available():
                if wandb_run_name is None:
                    config = {
                        **asdict(model_args),
                        **asdict(data_args),
                        **asdict(train_args),
                        **asdict(log_args),
                        **asdict(path_args),
                        "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                        "n_gpus": hvd.size(),
                    }
                    wandb.init(config=config, project="electra")
                    wandb.run.save()
                    wandb_run_name = wandb.run.name
                wandb.log({"step": step, **all_metrics})

            # Create the summary writer lazily, after the first successful step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=step)

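            # Save generator/discriminator weights as TF checkpoints and the
            # optimizer state as a NumPy archive (reloaded on resume above).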
            if do_checkpoint:
                dis_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-discriminator.ckpt",
                )
                gen_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-generator.ckpt",
                )
                optimizer_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-optimizer.npy",
                )
                logger.info(
                    f"Saving discriminator model at {dis_model_ckpt}, generator model at {gen_model_ckpt}, optimizer at {optimizer_ckpt}"
                )
                dis.save_weights(dis_model_ckpt)
                gen.save_weights(gen_model_ckpt)
                np.save(optimizer_ckpt, optimizer.get_weights())

        step += 1
        if is_final_step:
            break