def eval_continual()

in experiments/scripts/eval_supervised.py [0:0]


def eval_continual(big_config, seed=-1, mode="hard"):
    """Evaluate continual-learning experiments, one training world at a time.

    The experiments are read from ``big_config["representation"][mode]``.
    Each entry is indexed positionally: element 0 is stored as ``rep_fn``,
    element 1 as ``comp_fn``, and the last element is used as the
    ``config_id`` of the experiment.

    For every experiment, the comma-separated ``train_rule`` of its config
    is split into worlds.  For the world at position ``i`` a fresh
    ``CheckpointableTestTube`` is built and evaluated with ``epoch=i``:

    * once with ``test_rule`` set to that world alone, and
    * for ``i > 0``, once more with ``test_rule`` set to all worlds that
      precede it.

    Args:
        big_config: Nested mapping holding the experiment lists.
        seed: Seed forwarded to every ``CheckpointableTestTube``.
        mode: Key selecting which experiment list to evaluate.

    Returns:
        A ``pandas.DataFrame`` with one row per (experiment, world) pair.
        Every row carries ``rep_fn``, ``comp_fn``, ``current_world`` and
        ``accuracy``; ``past_accuracy`` and ``acc_std`` are written only
        when the past-world evaluation returned a non-empty result.

    NOTE(review): ``row`` is not defined in this function; it is presumably
    a module-level row template -- confirm against the rest of the file.
    """
    experiments = big_config["representation"][mode]

    # The test data and label mapping are loaded a single time, through the
    # config of the first experiment, and shared by every evaluation below.
    loader = CheckpointableTestTube(config_id=experiments[0][-1], seed=seed)
    loader.config.general.test_rule = "rule_0"
    all_test_data = loader.initialize_data(mode="test", override_mode="train")
    label2id = loader.label2id
    del loader

    def _evaluate_at(tube, epoch):
        # Prepare the evaluator for the given epoch on the shared test data
        # and return whatever its evaluate() call produces.
        tube.prepare_evaluator(
            epoch=epoch,
            override_mode="train",
            test_data=all_test_data,
            label2id=label2id,
        )
        return tube.evaluator.evaluate()

    records = []
    for spec in experiments:
        print("evaluating experiment {}".format(json.dumps(spec)))
        # One record per experiment; it is updated in place for each world
        # and a snapshot of it is appended to the output.
        record = copy.deepcopy(row)
        record["rep_fn"] = spec[0]
        record["comp_fn"] = spec[1]

        # This instance is only used to read the list of training worlds.
        tube = CheckpointableTestTube(config_id=spec[-1], seed=seed)
        worlds = tube.config.general.train_rule.split(",")
        print("evaluating on {} train worlds".format(len(worlds)))

        for idx, world in enumerate(worlds):
            # A fresh instance is created for every world.
            tube = CheckpointableTestTube(config_id=spec[-1], seed=seed)

            # Score on the world at the current position.
            tube.config.general.test_rule = world + ","
            current_scores = _evaluate_at(tube, idx)

            # Score on all earlier worlds; there are none for the first one.
            past_scores = {}
            if idx > 0:
                seen = worlds[:idx]
                tube.config.general.test_rule = ",".join(seen) + ","
                print("loading {} test worlds".format(len(seen)))
                past_scores = _evaluate_at(tube, idx)

            record["current_world"] = world
            record["accuracy"] = current_scores["accuracy"]
            if len(past_scores) > 0:
                record["past_accuracy"] = past_scores["accuracy"]
                record["acc_std"] = past_scores["acc_std"]
            records.append(copy.deepcopy(record))

    return pd.DataFrame(records)