def eval_continual_comp()

in experiments/scripts/eval_supervised.py


import copy
import json

import numpy as np
import pandas as pd

# CheckpointableTestTube is assumed to be imported or defined elsewhere in
# eval_supervised.py.


def eval_continual_comp(big_config, seed=-1, mode="hard"):
    """Evaluate continual learning with unique composition functions.

    For each experiment listed under ``big_config["representation"][mode]``,
    the model is evaluated on the world it was just trained on and, for later
    stages, on all previously seen worlds; the per-world accuracies are
    collected into a pandas DataFrame.
    """
    results_table = []
    # load all test data
    ev_exp = CheckpointableTestTube(
        config_id=big_config["representation"][mode][0][-1], seed=seed
    )
    ev_exp.config.general.test_rule = "rule_0"
    all_test_data = ev_exp.initialize_data(mode="test", override_mode="train")
    label2id = ev_exp.label2id
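    # the temporary experiment above is only needed to build the shared test
    # data and the label mapping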
    del ev_exp
    for exp in big_config["representation"][mode]:
        print("evaluating experiment {}".format(json.dumps(exp)))
        # start a fresh result row for this experiment
        res_row = {}
        res_row["rep_fn"] = exp[0]
        res_row["comp_fn"] = exp[1]
        ev_exp = CheckpointableTestTube(config_id=exp[-1], seed=seed)
        # enumerate the worlds this experiment was trained on, in order
        train_worlds = ev_exp.config.general.train_rule.split(",")
        print("evaluating on {} train worlds".format(len(train_worlds)))
        for wi, current_world in enumerate(train_worlds):
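            # re-create the experiment for this training stage and evaluate
            # only on the world that was just trained on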
            ev_exp = CheckpointableTestTube(config_id=exp[-1], seed=seed)
            ev_exp.config.general.test_rule = current_world + ","
            ev_exp.prepare_evaluator(
                epoch=wi,
                override_mode="train",
                test_data=all_test_data,
                label2id=label2id,
            )
            pr_current = ev_exp.evaluator.evaluate()
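            # for later stages, also measure retention on every previously
            # seen world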
            if wi > 0:
                pr_past = {}
                pr_past_w = []
                print("loading {} test worlds".format(len(train_worlds[:wi])))
                for pi, past_world in enumerate(train_worlds[:wi]):
                    ev_exp = CheckpointableTestTube(config_id=exp[-1], seed=seed)
                    ev_exp.config.general.test_rule = past_world + ","
                    # load current rep function
                    ev_exp.prepare_evaluator(
                        epoch=wi,
                        override_mode="train",
                        test_data=all_test_data,
                        label2id=label2id,
                    )
                    # load past comp function
                    # ev_exp.config.model.use_composition_fn = True
                    ev_exp.config.model.use_representation_fn = True
                    ev_exp.evaluator.reset(epoch=pi)
                    # eval
                    pr_past_w.append(ev_exp.evaluator.evaluate())
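                # average accuracy (and its reported std) over all past worlds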
                pr_past["accuracy"] = np.mean([pw["accuracy"] for pw in pr_past_w])
                pr_past["acc_std"] = np.mean([pw["acc_std"] for pw in pr_past_w])
            else:
                pr_past = {}
            res_row["current_world"] = current_world
            res_row["accuracy"] = pr_current["accuracy"]
            if len(pr_past) > 0:
                res_row["past_accuracy"] = pr_past["accuracy"]
                res_row["acc_std"] = pr_past["acc_std"]
            results_table.append(copy.deepcopy(res_row))
    return pd.DataFrame(results_table)
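
Below is a minimal sketch of how eval_continual_comp might be called. The shape of big_config and the entry values ("lstm", "gcn", "config_rep_lstm_gcn", ...) are illustrative assumptions, not values taken from the repository; the function only requires that each entry expose exp[0] (rep_fn), exp[1] (comp_fn), and exp[-1] (a config id understood by CheckpointableTestTube).

# hypothetical usage sketch; the entries below are placeholders, not real
# config ids from the repository
big_config = {
    "representation": {
        "hard": [
            # (rep_fn, comp_fn, config_id)
            ("lstm", "gcn", "config_rep_lstm_gcn"),
            ("lstm", "gat", "config_rep_lstm_gat"),
        ]
    }
}

results_df = eval_continual_comp(big_config, seed=42, mode="hard")
print(results_df.head())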