in experiments/scripts/eval_supervised.py [0:0]
def eval_continual_comp(big_config, seed=-1, mode="hard"):
    """Evaluate continual learning with unique composition functions.

    For each experiment descriptor in ``big_config["representation"][mode]``
    the comma-separated ``config.general.train_rule`` worlds are visited in
    order.  For the world at position ``i`` an evaluator is prepared with
    ``epoch=i`` and evaluated on that world.  For every earlier world at
    position ``j < i`` a further evaluator is prepared with ``epoch=i``, then
    ``reset(epoch=j)`` is called on it before evaluating on that earlier
    world; the resulting ``"accuracy"`` and ``"acc_std"`` values are averaged
    over all earlier worlds.

    Args:
        big_config: Nested mapping; ``big_config["representation"][mode]`` is
            a sequence of experiment descriptors.  For each descriptor,
            element 0 is recorded as ``rep_fn``, element 1 as ``comp_fn`` and
            the last element is used as the ``config_id``.
        seed: Forwarded unchanged to every ``CheckpointableTestTube``.
        mode: Key selecting which descriptor list of ``big_config`` to use.

    Returns:
        A ``pandas.DataFrame`` with one row per (experiment, train world).
        Each row starts as a copy of ``row`` and gets ``rep_fn``, ``comp_fn``,
        ``current_world`` and ``accuracy``; from the second world onwards
        ``past_accuracy`` and ``acc_std`` (mean over the earlier worlds) are
        written as well.

    NOTE(review): ``row`` is not defined inside this function -- presumably a
    module-level template dict; confirm it exists in the enclosing module.
    """
    experiments = big_config["representation"][mode]
    collected_rows = []

    # The test data is loaded once, through the config of the first
    # experiment, and handed to every evaluator prepared below.
    experiment = CheckpointableTestTube(config_id=experiments[0][-1], seed=seed)
    experiment.config.general.test_rule = "rule_0"
    all_test_data = experiment.initialize_data(
        mode="test", override_mode="train"
    )
    label2id = experiment.label2id
    del experiment

    def _experiment_for(config_id, world, epoch):
        # Build a fresh experiment whose test rule is the single given world
        # and whose evaluator is prepared for the given epoch.
        fresh = CheckpointableTestTube(config_id=config_id, seed=seed)
        fresh.config.general.test_rule = world + ","
        fresh.prepare_evaluator(
            epoch=epoch,
            override_mode="train",
            test_data=all_test_data,
            label2id=label2id,
        )
        return fresh

    for descriptor in experiments:
        print("evaluating experiment {}".format(json.dumps(descriptor)))
        result = copy.deepcopy(row)
        result["rep_fn"] = descriptor[0]
        result["comp_fn"] = descriptor[1]
        config_id = descriptor[-1]

        # This instance is only used to read the list of training worlds.
        experiment = CheckpointableTestTube(config_id=config_id, seed=seed)
        worlds = experiment.config.general.train_rule.split(",")
        print("evaluating on {} train worlds".format(len(worlds)))

        for world_idx, current_world in enumerate(worlds):
            experiment = _experiment_for(config_id, current_world, world_idx)
            current_scores = experiment.evaluator.evaluate()

            past_summary = {}
            if world_idx > 0:
                earlier_worlds = worlds[:world_idx]
                past_scores = []
                print("loading {} test worlds".format(len(earlier_worlds)))
                for past_idx, past_world in enumerate(earlier_worlds):
                    # Prepared with the current epoch first ("load current
                    # rep function" in the original comments) ...
                    experiment = _experiment_for(
                        config_id, past_world, world_idx
                    )
                    # ... then reset to the past epoch ("load past comp
                    # function" in the original comments).
                    # NOTE(review): the flag set here is
                    # use_representation_fn, while an earlier variant toggled
                    # use_composition_fn -- confirm this is the intended flag.
                    experiment.config.model.use_representation_fn = True
                    experiment.evaluator.reset(epoch=past_idx)
                    past_scores.append(experiment.evaluator.evaluate())
                past_summary = {
                    "accuracy": np.mean(
                        [scores["accuracy"] for scores in past_scores]
                    ),
                    "acc_std": np.mean(
                        [scores["acc_std"] for scores in past_scores]
                    ),
                }

            result["current_world"] = current_world
            result["accuracy"] = current_scores["accuracy"]
            if past_summary:
                result["past_accuracy"] = past_summary["accuracy"]
                result["acc_std"] = past_summary["acc_std"]
            # Snapshot the row: ``result`` keeps being mutated for the
            # remaining worlds of this experiment.
            collected_rows.append(copy.deepcopy(result))

    return pd.DataFrame(collected_rows)