def main()

in train.py [0:0]


def main(logger, args):
    """Tensorize the training data and, unless only tensorizing, train a model.

    The function runs in two phases:

    1. Load the training split for ``args.task``, log per-task example
       counts, and tensorize the data through ``MetaICLData``. When
       ``args.do_tensorize`` is set, the function returns right after this
       phase.
    2. Seed the RNGs from ``args.train_seed``, build a ``MetaICLModel``
       (optionally from ``args.init_checkpoint``), and run ``do_train``.

    Args:
        logger: Object exposing ``info(msg)``; also handed to ``MetaICLData``
            and ``MetaICLModel``.
        args: Namespace-like object. Attributes read here: ``gpt2``,
            ``batch_size``, ``use_demonstrations``, ``k``, ``test_k``,
            ``task``, ``seed``, ``train_seed``, ``local_rank``, ``method``,
            ``init_checkpoint``, ``do_tensorize``, ``tensorize_dir``,
            ``n_process``, ``n_gpu``, ``num_training_steps``, ``no_masking``,
            ``out_dir``, ``fp16``, ``optimization``, ``lr``,
            ``weight_decay`` and ``warmup_steps``.

    Raises:
        FileNotFoundError: If ``args.init_checkpoint`` is given but the path
            does not exist.
    """
    if args.gpt2.startswith("gpt2"):
        tokenizer = GPT2Tokenizer.from_pretrained(args.gpt2)
    else:
        # Any other model name falls back to the stock "gpt2" tokenizer.
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    max_length_per_example = 256
    max_length = 256
    if args.use_demonstrations:
        # Scale the total length with the number of demonstrations, capped at
        # 1024 (presumably the model's context window -- TODO confirm).
        max_length = min(max_length * args.k, 1024)

    logger.info("batch_size=%d\tmax_length=%d\tmax_length_per_example=%d" % (
        args.batch_size, max_length, max_length_per_example))

    train_data = load_data(args.task, "train", args.k, seed=args.seed)

    # Number of training examples per task name.
    train_counter = Counter(dp["task"] for dp in train_data)
    # Only log when local_rank <= 0 (presumably the main or a non-distributed
    # process -- verify against the launcher).
    if args.local_rank <= 0:
        for task_name, n_examples in train_counter.items():
            logger.info("[Train] %s\t%d" % (task_name, n_examples))
        logger.info("%s on %s (%d train)" % (args.method, args.task, len(train_counter)))

    # Validate explicitly instead of using `assert`, which is removed when
    # Python runs with -O and would let a bad path slip through to load().
    if args.init_checkpoint is not None and not os.path.exists(args.init_checkpoint):
        raise FileNotFoundError(
            "init_checkpoint does not exist: %s" % args.init_checkpoint)

    ######### load tensorize data
    metaicl_data = MetaICLData(logger, tokenizer, args.method, args.use_demonstrations,
                               args.test_k, max_length, max_length_per_example,
                               do_tensorize=args.do_tensorize,
                               tensorize_dir=args.tensorize_dir,
                               n_process=args.n_process, n_gpu=args.n_gpu, local_rank=args.local_rank)
    metaicl_data.tensorize_for_training(train_data, keyword=args.task, seed=args.seed)

    if args.do_tensorize:
        # Tensorize-only run: stop before any model is built.
        return

    ######## actual training part

    # Seed every RNG in use from the same training seed.
    random.seed(args.train_seed)
    np.random.seed(args.train_seed)
    torch.manual_seed(args.train_seed)
    if torch.cuda.device_count() > 0:
        torch.cuda.manual_seed_all(args.train_seed)

    num_training_steps = args.num_training_steps
    save_period = 10000
    log_period = 10000

    if args.no_masking:
        # Overwrite token_type_ids with all ones, same shape as input_ids
        # (presumably so every token counts toward the loss -- TODO confirm).
        metaicl_data.tensorized_inputs["token_type_ids"] = torch.ones_like(
            metaicl_data.tensorized_inputs["input_ids"])
    metaicl_data.print_tensorized_example()

    logger.info(args.out_dir)

    # exist_ok avoids the check-then-create race when several processes
    # reach this point at the same time.
    os.makedirs(args.out_dir, exist_ok=True)

    metaicl_model = MetaICLModel(logger, args.out_dir, args.fp16, args.local_rank)
    metaicl_model.load(args.init_checkpoint, args.gpt2)
    metaicl_model.to_device()
    metaicl_model.setup_optimizer(args.optimization, num_training_steps, args.lr,
                                  args.weight_decay, args.warmup_steps)
    metaicl_model.parallel()
    metaicl_model.train()
    metaicl_model.do_train(metaicl_data, args.batch_size, num_training_steps, save_period, log_period)