in test.py [0:0]

import os
import json

import numpy as np
from collections import Counter

from transformers import GPT2Tokenizer, AutoTokenizer

# Module paths below are assumed from the MetaICL repository layout; the run()
# helper used in main() is defined elsewhere in test.py.
from metaicl.data import MetaICLData
from metaicl.model import MetaICLModel
from utils.data import load_data

def main(logger, args):
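    """Evaluate MetaICL on the target tasks in args.task.

    For every seed, load k training demonstrations and a dev split, run the
    model on each task (optionally conditioning on the demonstrations), and
    log the macro-F1 averaged over all (task, seed) pairs.
    """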
    if args.gpt2.startswith("gpt2"):
        tokenizer = GPT2Tokenizer.from_pretrained(args.gpt2)
    else:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    add_newlines = True

    ### checkpoint ...
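    # Non-zero-shot runs load a fine-tuned checkpoint: either an explicit path
    # (args.checkpoint) or one derived from args.out_dir and args.global_step.
    # Zero-shot runs use the raw pretrained LM, so no checkpoint is loaded.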
    if not args.do_zeroshot:
        if args.checkpoint is not None:
            checkpoint = args.checkpoint
            assert args.global_step is None
        else:
            assert args.global_step is not None
            checkpoint = os.path.join(args.out_dir, "model-{}.pt".format(args.global_step))
        assert os.path.exists(checkpoint)
    else:
        checkpoint = None
        add_newlines = args.gpt2=="gpt-j-6B"

    metaicl_model = MetaICLModel(logger, args.out_dir)

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    # setup hyperparams for data
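    # max_length_per_example bounds a single (input, output) pair; with k in-context
    # demonstrations the concatenated prompt grows roughly k-fold, so the total
    # context is capped at GPT-2's 1024-token limit.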
    max_length_per_example = 256
    max_length = 256
    if args.use_demonstrations:
        max_length = min(max_length * args.k, 1024)

    logger.info("batch_size=%d\tmax_length=%d\tmax_length_per_example=%d" % (
        args.test_batch_size, max_length, max_length_per_example))

    metaicl_data = MetaICLData(logger, tokenizer, args.method, args.use_demonstrations, args.k,
                               max_length, max_length_per_example)

    results = []
    errors = []
    seeds = args.seed.split(",")
    config_split = "unseen_domain_test" if args.unseen_domain_only else "test"

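    # Each seed corresponds to a different random draw of the k training
    # demonstrations; every task is evaluated once per seed.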
    for seed in seeds:

        ### data ...
        train_data = load_data(args.task, "train", args.k, seed=seed, config_split=config_split)
        dev_data = load_data(args.task, args.split, args.k, seed=seed, config_split=config_split, is_null=args.is_null)

        train_counter = Counter()
        dev_counter = Counter()
        for dp in train_data:
            train_counter[dp["task"]] += 1
        for dp in dev_data:
            dev_counter[dp["task"]] += 1
        for k, v in train_counter.items():
            logger.info("[Train] %s\t%d" % (k, v))
        for k, v in dev_counter.items():
            logger.info("[Dev] %s\t%d" % (k, v))

        logger.info("%s on %s (%d train, %d dev)" % (args.method, args.task, len(train_counter), len(dev_counter)))

        for test_task in dev_counter:
            curr_dev_data = [dp for dp in dev_data if dp["task"]==test_task]
            curr_train_data = [dp for dp in train_data if dp["task"]==test_task]
            assert len(curr_dev_data)>0
            assert not args.use_demonstrations or len(curr_train_data)==args.k, \
                (args.use_demonstrations, len(curr_train_data), args.k)

            config_file = "config/tasks/{}.json".format(test_task)
            assert os.path.exists(config_file), config_file
            with open(config_file, "r") as f:
                config = json.load(f)
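            # Classification tasks come with a fixed option set that must be
            # identical across all of the task's examples.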
            is_classification = config["task_type"]=="classification"
            if is_classification:
                options = curr_dev_data[0]["options"]
                assert np.all([d["options"]==options for d in curr_dev_data+curr_train_data])

            result = run(logger, test_task, metaicl_data, metaicl_model,
                         curr_train_data, curr_dev_data, seed, checkpoint, is_classification, add_newlines)

            if result is None:
                errors.append("%s/%s" % (test_task, seed))
            else:
                results.append(result)

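    # args.is_null runs only compute the null-input baseline, so skip the report;
    # otherwise results holds one score per (task, seed) pair, and dividing by the
    # number of seeds recovers the number of target tasks.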
    if args.is_null:
        return

    logger.info("Macro-F1 of %s over %d target tasks: %.1f" % (
        args.task, len(results) // len(seeds), 100*np.mean(results)))

    if len(errors)>0:
        logger.info("You had errors with datasets: %s" % ",".join(errors))
        logger.info("Please see the error messages")