in train.py [0:0]
def main(logger, args):
    """Tensorize the training data and, unless only tensorizing, train a model.

    Steps, in order:
      1. Build a tokenizer and derive the sequence-length limits.
      2. Load the training data and log a per-task example count.
      3. Tensorize the data via ``MetaICLData``; stop here when
         ``args.do_tensorize`` is set.
      4. Seed the RNGs, build a ``MetaICLModel``, and run ``do_train``.

    Args:
        logger: Object exposing ``info(msg)``; used for progress messages and
            handed to ``MetaICLData`` / ``MetaICLModel``.
        args: Namespace-like object. Attributes read here: ``gpt2``,
            ``batch_size``, ``use_demonstrations``, ``k``, ``test_k``,
            ``task``, ``seed``, ``train_seed``, ``local_rank``, ``method``,
            ``init_checkpoint``, ``do_tensorize``, ``tensorize_dir``,
            ``n_process``, ``n_gpu``, ``num_training_steps``, ``no_masking``,
            ``out_dir``, ``fp16``, ``optimization``, ``lr``,
            ``weight_decay``, ``warmup_steps``.

    Raises:
        FileNotFoundError: If ``args.init_checkpoint`` is given but the path
            does not exist.
    """
    # Names starting with "gpt2" are loaded directly by name; anything else
    # falls back to the stock "gpt2" tokenizer.
    if args.gpt2.startswith("gpt2"):
        tokenizer = GPT2Tokenizer.from_pretrained(args.gpt2)
    else:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    max_length_per_example = 256
    max_length = 256
    if args.use_demonstrations:
        # Scale the total length with the number of demonstrations, capped
        # at 1024 tokens.
        max_length = min(max_length * args.k, 1024)

    logger.info("batch_size=%d\tmax_length=%d\tmax_length_per_example=%d" % (
        args.batch_size, max_length, max_length_per_example))

    train_data = load_data(args.task, "train", args.k, seed=args.seed)

    # Number of training examples per task name.
    train_counter = Counter(dp["task"] for dp in train_data)

    # Only log the breakdown when local_rank <= 0 (presumably the main or
    # non-distributed process -- confirm against the launcher).
    if args.local_rank <= 0:
        for task_name, n_examples in train_counter.items():
            logger.info("[Train] %s\t%d" % (task_name, n_examples))
        # NOTE: the count reported here is the number of distinct tasks.
        logger.info("%s on %s (%d train)" % (args.method, args.task, len(train_counter)))

    # Fail fast, before the tensorization work, if the checkpoint is missing.
    # An explicit raise is used because ``assert`` is stripped under ``-O``.
    if args.init_checkpoint is not None and not os.path.exists(args.init_checkpoint):
        raise FileNotFoundError(
            "init_checkpoint does not exist: %s" % args.init_checkpoint)

    ######### load tensorize data
    metaicl_data = MetaICLData(logger, tokenizer, args.method, args.use_demonstrations,
                               args.test_k, max_length, max_length_per_example,
                               do_tensorize=args.do_tensorize,
                               tensorize_dir=args.tensorize_dir,
                               n_process=args.n_process, n_gpu=args.n_gpu,
                               local_rank=args.local_rank)
    metaicl_data.tensorize_for_training(train_data, keyword=args.task, seed=args.seed)

    if args.do_tensorize:
        # Tensorize-only run: nothing more to do.
        return

    ######## actual training part
    random.seed(args.train_seed)
    np.random.seed(args.train_seed)
    torch.manual_seed(args.train_seed)
    if torch.cuda.device_count() > 0:
        torch.cuda.manual_seed_all(args.train_seed)

    num_training_steps = args.num_training_steps
    save_period = 10000
    log_period = 10000

    if args.no_masking:
        # Overwrite token_type_ids with all ones, matching input_ids in shape.
        metaicl_data.tensorized_inputs["token_type_ids"] = torch.ones_like(
            metaicl_data.tensorized_inputs["input_ids"])
    metaicl_data.print_tensorized_example()

    logger.info(args.out_dir)

    # exist_ok avoids a check-then-create race when several processes run
    # this function at the same time with the same out_dir.
    os.makedirs(args.out_dir, exist_ok=True)

    metaicl_model = MetaICLModel(logger, args.out_dir, args.fp16, args.local_rank)
    metaicl_model.load(args.init_checkpoint, args.gpt2)
    metaicl_model.to_device()
    metaicl_model.setup_optimizer(args.optimization, num_training_steps, args.lr,
                                  args.weight_decay, args.warmup_steps)
    metaicl_model.parallel()
    metaicl_model.train()
    metaicl_model.do_train(metaicl_data, args.batch_size, num_training_steps,
                           save_period, log_period)