def main()

in train.py

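For reference, a sketch of the imports this excerpt appears to rely on; the grouping is a guess and the import paths of the project-specific helpers are assumptions, not taken from the source:

import json
import random
import sys

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm

# Project-specific / third-party names referenced below (exact import paths assumed):
# parse_train_arg, init_distributed_mode, init_logger, init_signal_handler,
# task_dict, Trainer, GroupedIterator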

def main():
    args = parse_train_arg()
    task = task_dict[args.task]

    init_distributed_mode(args)
    logger = init_logger(args)

    if hasattr(args, 'base_model_name'):
        logger.warning('Argument base_model_name is deprecated! Use `--table-bert-extra-config` instead!')

    init_signal_handler()

    train_data_dir = args.data_dir / 'train'
    dev_data_dir = args.data_dir / 'dev'
    table_bert_config = task['config'].from_file(
        args.data_dir / 'config.json', **args.table_bert_extra_config)

    if args.is_master:
        args.output_dir.mkdir(exist_ok=True, parents=True)
        with (args.output_dir / 'train_config.json').open('w') as f:
            json.dump(vars(args), f, indent=2, sort_keys=True, default=str)

        logger.info(f'Table Bert Config: {table_bert_config.to_log_string()}')

        # copy the table bert config file to the working directory
        # shutil.copy(args.data_dir / 'config.json', args.output_dir / 'tb_config.json')
        # save table BERT config
        table_bert_config.save(args.output_dir / 'tb_config.json')

    assert args.data_dir.is_dir(), \
        "--data_dir should point to the folder of files made by pregenerate_training_data.py!"

    if args.cpu:
        device = torch.device('cpu')
    else:
        device = torch.device(f'cuda:{torch.cuda.current_device()}')

    logger.info("device: {} gpu_id: {}, distributed training: {}, 16-bits training: {}".format(
        device, args.local_rank, bool(args.multi_gpu), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    real_batch_size = args.train_batch_size  # // args.gradient_accumulation_steps
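    # Note: real_batch_size is the per-process batch size handed to the DataLoader; judging by
    # the commented-out division above, gradient accumulation is handled inside the Trainer
    # rather than by shrinking the batch here.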

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.cpu:
        torch.cuda.manual_seed_all(args.seed)

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logger.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare model
    if args.multi_gpu and args.global_rank != 0:
        torch.distributed.barrier()

    if args.no_init:
        raise NotImplementedError
    else:
        model = task['model'](table_bert_config)

    if args.multi_gpu and args.global_rank == 0:
        torch.distributed.barrier()
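    # The two barriers above follow the usual "rank 0 first" pattern: non-zero ranks wait at the
    # first barrier while rank 0 builds the model (e.g. so any one-time download of pretrained
    # weights into a shared cache happens only once), then rank 0's barrier releases the other
    # ranks to construct their own copies.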

    if args.fp16:
        model = model.half()

    model = model.to(device)
    if args.multi_gpu:
        if args.ddp_backend == 'pytorch':
            model = nn.parallel.DistributedDataParallel(
                model,
                find_unused_parameters=True,
                device_ids=[args.local_rank], output_device=args.local_rank,
                broadcast_buffers=False
            )
        else:
            import apex
            model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)

        model_ptr = model.module
    else:
        model_ptr = model
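    # model_ptr always refers to the unwrapped module, so the tokenizer access and the
    # per-epoch state_dict saving below bypass the DistributedDataParallel wrapper.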

    # set up update parameters for LR scheduler
    dataset_cls = task['dataset']

    train_set_info = dataset_cls.get_dataset_info(train_data_dir, args.max_epoch)
    total_num_updates = train_set_info['total_size'] // args.train_batch_size // args.world_size // args.gradient_accumulation_steps
    args.max_epoch = train_set_info['max_epoch']
    logger.info(f'Train data size: {train_set_info["total_size"]} for {args.max_epoch} epochs, total num. updates: {total_num_updates}')

    args.total_num_update = total_num_updates
    args.warmup_updates = int(total_num_updates * 0.1)
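    # Worked example with hypothetical numbers: total_size=1,000,000, train_batch_size=8,
    # world_size=16, gradient_accumulation_steps=4 gives
    # total_num_updates = 1_000_000 // 8 // 16 // 4 = 1953 optimizer steps,
    # and warmup_updates = int(1953 * 0.1) = 195.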

    trainer = Trainer(model, args)

    checkpoint_file = args.output_dir / 'model.ckpt.bin'
    is_resumed = False
    # trainer.save_checkpoint(checkpoint_file)
    if checkpoint_file.exists():
        logger.info(f'Loading checkpoint file {checkpoint_file}')
        is_resumed = True
        trainer.load_checkpoint(checkpoint_file)

    model.train()

    # we also partition the dev set for every local process
    logger.info('Loading dev set...')
    sys.stdout.flush()
    dev_set = dataset_cls(epoch=0, training_path=dev_data_dir, tokenizer=model_ptr.tokenizer, config=table_bert_config,
                          multi_gpu=args.multi_gpu, debug=args.debug_dataset)

    logger.info("***** Running training *****")
    logger.info(f"  Current config: {args}")

    if trainer.num_updates > 0:
        logger.info(f'Resume training at epoch {trainer.epoch}, '
                    f'epoch step {trainer.in_epoch_step}, '
                    f'global step {trainer.num_updates}')

    start_epoch = trainer.epoch
    for epoch in range(start_epoch, args.max_epoch):  # start_epoch is inclusive; runs through epoch args.max_epoch - 1
        model.train()

        with torch.random.fork_rng(devices=None if args.cpu else [device.index]):
            torch.random.manual_seed(131 + epoch)
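            # Forking the RNG above keeps this per-epoch shuffling seed (131 + epoch) from
            # perturbing the global torch RNG state used elsewhere in training.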

            epoch_dataset = dataset_cls(epoch=trainer.epoch, training_path=train_data_dir, config=table_bert_config,
                                        tokenizer=model_ptr.tokenizer, multi_gpu=args.multi_gpu, debug=args.debug_dataset)
            train_sampler = RandomSampler(epoch_dataset)
            train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=real_batch_size,
                                          num_workers=0,
                                          collate_fn=epoch_dataset.collate)

        samples_iter = GroupedIterator(iter(train_dataloader), args.gradient_accumulation_steps)
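        # Each item yielded by samples_iter is a group of gradient_accumulation_steps mini-batches,
        # and trainer.train_step below appears to consume one group per optimizer update.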
        trainer.resume_batch_loader(samples_iter)

        with tqdm(total=len(samples_iter), initial=trainer.in_epoch_step,
                  desc=f"Epoch {epoch}", file=sys.stdout, disable=not args.is_master, miniters=100) as pbar:

            for samples in samples_iter:
                logging_output = trainer.train_step(samples)

                pbar.update(1)
                pbar.set_postfix_str(', '.join(f"{k}: {v:.4f}" for k, v in logging_output.items()))

                if (
                    0 < trainer.num_updates and
                    trainer.num_updates % args.save_checkpoint_every_niter == 0 and
                    args.is_master
                ):
                    # Save model checkpoint
                    logger.info("** ** * Saving checkpoint file ** ** * ")
                    trainer.save_checkpoint(checkpoint_file)

            logger.info(f'Epoch {epoch} finished.')

            if args.is_master:
                # Save a trained table_bert
                logger.info("** ** * Saving fine-tuned table_bert ** ** * ")
                model_to_save = model_ptr  # Only save the table_bert itself
                output_model_file = args.output_dir / f"pytorch_model_epoch{epoch:02d}.bin"
                torch.save(model_to_save.state_dict(), str(output_model_file))

            # perform validation
            logger.info("** ** * Perform validation ** ** * ")
            dev_results = trainer.validate(dev_set)

            if args.is_master:
                logger.info('** ** * Validation Results ** ** * ')
                logger.info(f'Epoch {epoch} Validation Results: {dev_results}')

            # flush logging information to disk
            sys.stderr.flush()

        trainer.next_epoch()