def main()

in custom/gpt2/run_gpt2.py


def main():
    parser = argparse.ArgumentParser(description='openGPT-2 analysis')

    parser.add_argument('--mode', choices=['train', 'eval-singletoken', 'eval-completion', 'eval-both'], default='eval-singletoken')
    parser.add_argument('--eval-split', choices=['train', 'valid', 'test'])
    parser.add_argument('--model-name', choices=['gpt2', 'gpt2-medium', 'gpt2-large'], default='gpt2-medium')
    parser.add_argument('--model-load-dir', type=str, default=None)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--data-base', type=str)
    parser.add_argument('--num-train-epochs', type=int, default=1)
    parser.add_argument('--batch-size-singletoken', type=int, default=1024)
    parser.add_argument('--batch-size-completion', type=int, default=300)
    parser.add_argument("--output-dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    # eval-completion
    parser.add_argument('--prefix-length', type=int, default=50)
    parser.add_argument('--continuation-length', type=int, default=100)
    parser.add_argument('--top-k', type=int, default=1)
    parser.add_argument('--top-p', type=float, default=0.0)

    # custom training
    parser.add_argument('--sequence-tune-rate', type=float, default=0.5)
    parser.add_argument('--train-batch-size', type=int, default=300)
    parser.add_argument('--report-metrics-every', type=int, default=10)
    parser.add_argument('--save-every', type=int, default=1000)
    parser.add_argument('--sequence-ngram-n', type=int, default=4)
    parser.add_argument('--train-n-steps', type=int, default=10000)
    parser.add_argument('--validate-every', type=int, default=10000)

    # training loop
    parser.add_argument("--adam-epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max-grad-norm', type=int, default=1)
    parser.add_argument("--max-steps", default=-1, type=int,
                        help="If > 0: set total number of training \
                            steps to perform. Override num_train_epochs.")
    parser.add_argument('--gradient-accumulation-steps', type=int, default=1,
                        help="Number of update steps to accumulate before "
                             "performing a backward/update pass.")
    parser.add_argument('--learning-rate', type=float, default=6.25e-5)
    parser.add_argument("--warmup-steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--lr-schedule', type=str, default='warmup_linear')
    parser.add_argument('--weight-decay', type=float, default=0.01)
    parser.add_argument('--lm-coef', type=float, default=0.9)

    args = parser.parse_args()
    print(args)

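    # Seed Python, NumPy, and PyTorch (CPU and all GPUs) RNGs for reproducibility.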
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {}, n_gpu {}".format(device, n_gpu))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

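    # Paths to the pre-tokenized (GPT-2 BPE) train/valid/test splits under --data-base.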
    dataset_paths = {
        'train': os.path.join(args.data_base, 'train_tokens_bpe_gpt2.pt'),
        'valid': os.path.join(args.data_base, 'valid_tokens_bpe_gpt2.pt'),
        'test': os.path.join(args.data_base, 'test_tokens_bpe_gpt2.pt'),
    }

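    # Load either a previously fine-tuned checkpoint or the pretrained GPT-2 weights.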
    if args.model_load_dir:
        model = GPT2LMHeadModel.from_pretrained(args.model_load_dir)
    else:
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
    model.to(device)

    if args.mode == 'eval-singletoken' or args.mode == 'eval-both':
        eval_singletoken(model, args, dataset_paths)

    if args.mode == 'eval-completion' or args.mode == 'eval-both':
        datasets = get_datasets(dataset_paths, max_len=args.batch_size_completion)
        eval_sampler = SequentialSampler(datasets[args.eval_split])
        eval_dataloader = DataLoader(datasets[args.eval_split], sampler=eval_sampler, batch_size=1)

        model.eval()

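        # Generate continuations from fixed-length prefixes and track n-gram repetition
        # metrics at both the BPE and word level.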
        with torch.no_grad():
            all_text_completions = []

            bpe_ngram_metrics = Metrics(pad=-1)
            word_ngram_metrics = Metrics(pad=-1)

            for i, batch in tqdm(enumerate(eval_dataloader), desc="Evaluating", total=len(eval_dataloader)):
                input_sequence = batch[0].to(device)
                if input_sequence.size(1) < args.prefix_length:
                    continue

                # Predict the completions.
                batch = batch_input_sequence_by_prefix_length(input_sequence, args.prefix_length)
                bpe_completions, _ = sample_sequence(model, batch, args.prefix_length, args.continuation_length, args.top_k, args.top_p)
                bpe_completions = bpe_completions.tolist()

                # Extract continuations from the predicted completions.
                bpe_continuations = []
                text_continuations = []
                for bpe_completion in bpe_completions:
                    bpe_continuations.append(bpe_completion[args.prefix_length:])
                    text_continuations.append(get_text_continuation(bpe_completion, tokenizer, args))
                    all_text_completions.append(tokenizer.decode(bpe_completion))

                # Only keep continuations with at least one 4-gram
                # (a continuation can end up short after detokenization, e.g. when the model
                #  predicts mostly whitespace, even though it has full length in BPE tokens).
                text_continuations = [c for c in text_continuations if len(c) > 3]

                # Update metrics with this batch of continuations.
                bpe_ngram_metrics.update(bpe_continuations)
                word_ngram_metrics.update(text_continuations)

                # Save the (possibly intermediate) metrics.
                save_completion_metrics(bpe_metrics=bpe_ngram_metrics.report('bpe_%s' % args.eval_split),
                                        word_metrics=word_ngram_metrics.report('word_%s' % args.eval_split),
                                        text_completions=all_text_completions,
                                        config=model.config.to_dict(),
                                        args=args)

    if args.mode == 'train':
        if not os.path.exists(os.path.join(args.output_dir, 'best')):
            os.makedirs(os.path.join(args.output_dir, 'best'))

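        # The token-level objective is standard MLE; the sequence-level unlikelihood
        # loss (ul_seq) is mixed in during the loop below.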
        token_loss = mle_loss
        datasets = get_datasets(dataset_paths, max_len=args.train_batch_size)
        train_sampler = RandomSampler(datasets['train'])
        train_seq_dataloader = DataLoader(datasets['train'], sampler=train_sampler, batch_size=1)

        # Setup optimizer
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (len(train_seq_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_seq_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

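        # Group parameters so that biases and LayerNorm weights are excluded from weight decay.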
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

        total_steps = 0
        best_ppl = 1e20
        for _ in trange(args.num_train_epochs, desc="Epoch"):
            logging_outputs = []
            epoch_loss = 0
            epoch_steps = 0
            tqdm_bar = tqdm(train_seq_dataloader, desc="Training", total=args.train_n_steps)
            for step, batch in enumerate(tqdm_bar):
                optimizer.zero_grad()

                # Sequence loss
                if torch.rand(1).item() < args.sequence_tune_rate:
                    if batch[0].size(1) < args.prefix_length:
                        continue
                    loss, batch_metrics = ul_seq(model, batch, args)

                # Token loss
                else:
                    loss, batch_metrics = token_loss(model, batch, args)

                loss.backward()
                optimizer.step()
                scheduler.step()
                epoch_loss += loss.item()
                epoch_steps += 1
                total_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(epoch_loss/epoch_steps, scheduler.get_lr()[0])

                logging_outputs.append(batch_metrics)

                if epoch_steps % args.report_metrics_every == 0:
                    logging_average = CrossEntropyCriterionWCustomMetrics.aggregate_logging_outputs(logging_outputs)
                    temp = SequencePenaltyCriterion.aggregate_logging_outputs(logging_outputs)
                    for k, v in temp.items():
                        logging_average[k] = v
                    # ppl assumes the loss is logged in base 2 (fairseq convention in these criteria)
                    logging_average['ppl'] = 2 ** logging_average['loss']
                    print(logging_average)
                    logging_outputs = []

                if step == args.train_n_steps:
                    break

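                # Periodically checkpoint the current model weights, config, and tokenizer vocabulary.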
                if epoch_steps % args.save_every == 0:
                    model_to_save = model.module if hasattr(model, 'module') else model
                    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)

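                # Periodically run single-token evaluation and keep the best-perplexity
                # checkpoint under output_dir/best.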
                if total_steps % args.validate_every == 0:
                    print("Validating...")
                    validation_outputs = eval_singletoken(model, args, dataset_paths, train_iter=total_steps)
                    if validation_outputs['ppl'] < best_ppl:
                        best_ppl = validation_outputs['ppl']
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = os.path.join(args.output_dir, 'best', WEIGHTS_NAME)
                        output_config_file = os.path.join(args.output_dir, 'best', CONFIG_NAME)
                        torch.save(model_to_save.state_dict(), output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(os.path.join(args.output_dir, 'best'))
                        save_singletoken_metrics(validation_outputs, model.config.to_dict(), args,
                                                 train_iter=total_steps, best=True)