def parse_n_prepare_postproc()

in shaDow/utils.py


def parse_n_prepare_postproc(dir_load, f_config, name_graph, dir_log, arch_gnn, logger):
    if f_config is not None:
        # Load the post-processing YAML config and register a log directory for this
        # run (log_dir, git_rev and timestamp are module-level names in shaDow/utils.py).
        with open(f_config) as f:
            config_postproc = yaml.load(f, Loader=yaml.FullLoader)
        name_key = f"postproc-{arch_gnn['aggr']}_{arch_gnn['num_layers']}"
        log_dir('postproc', config_postproc, name_key, dir_log, name_graph, git_rev, timestamp)
    skip_instantiate = []
    # Whether to also decode the accuracy records stored next to the cached matrices
    # (defaults to True when 'check_record' is absent from the config).
    load_acc_record = config_postproc.get('check_record', True)
    if config_postproc['method'] == 'cs':               # Correct & Smooth (C&S)
        acc_record = [] if load_acc_record else None
        if dir_load is not None:
            # Include dir_load among the prediction-matrix directories, without
            # duplicating a path that the config already lists.
            if 'dir_pred_mat' not in config_postproc:
                config_postproc['dir_pred_mat'] = [dir_load]
            elif os.path.realpath(dir_load) not in [os.path.realpath(pc) for pc in config_postproc['dir_pred_mat']]:
                config_postproc['dir_pred_mat'].append(dir_load)
        config_postproc['pred_mat'] = [None] * len(config_postproc['dir_pred_mat'])
        for i, di in enumerate(config_postproc['dir_pred_mat']):
            if load_acc_record:
                acc_record.append(logger.decode_csv('final', di))
            # Pick up the first cached prediction matrix ('pred_mat*.cs') in this directory.
            for f in os.listdir(di):
                if f.startswith('pred_mat') and f.split('.')[-1] == 'cs':
                    config_postproc['pred_mat'][i] = torch.load(f"{di}/{f}")
                    break
        if all(m is not None for m in config_postproc['pred_mat']):
            # Every prediction matrix was found on disk, so neither the graph data
            # nor the model needs to be instantiated for C&S.
            skip_instantiate = ['data', 'model']
    elif config_postproc['method'] == 'ensemble':       # subgraph-ensemble variant run as post-processing
        acc_record = {s: [] for s in config_postproc['dir_emb_mat']} if load_acc_record else None
        assert dir_load is None
        # One embedding-matrix slot per (sname, run directory) pair.
        config_postproc['emb_mat'] = {k: [None] * len(v) for k, v in config_postproc['dir_emb_mat'].items()}
        for sname, dirs_l in config_postproc['dir_emb_mat'].items():
            for i, di in enumerate(dirs_l):
                if load_acc_record:
                    acc_record[sname].append(logger.decode_csv('final', di))
                # Pick up the first cached embedding matrix ('emb_mat*.ens') in this directory.
                for f in os.listdir(di):
                    if f.startswith('emb_mat') and f.split('.')[-1] == 'ens':
                        config_postproc['emb_mat'][sname][i] = torch.load(f"{di}/{f}")
                        break
        if all(m is not None for mat_l in config_postproc['emb_mat'].values() for m in mat_l):
            skip_instantiate = ['model']        # the data (node roles, labels) still has to be loaded
    return config_postproc, acc_record, skip_instantiate
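
For reference, the sketch below shows the shape of the parsed YAML config (i.e. the dict produced by yaml.load) that this function expects for each method. Only the keys read above are included; the directory paths and the keys of dir_emb_mat ('ppr', 'khop') are hypothetical values for illustration, and any extra keys in the YAML are passed through to the caller untouched.

# Minimal sketch of the two config shapes (hypothetical values).
config_cs = {
    'method': 'cs',                       # Correct & Smooth post-processing
    'check_record': True,                 # optional; treated as True when absent
    'dir_pred_mat': [                     # dirs expected to hold cached 'pred_mat*.cs' files
        'results/run_0',
        'results/run_1',
    ],
}

config_ensemble = {
    'method': 'ensemble',                 # subgraph-ensemble post-processing
    'dir_emb_mat': {                      # name -> dirs expected to hold cached 'emb_mat*.ens' files
        'ppr':  ['results/ppr_run_0'],
        'khop': ['results/khop_run_0'],
    },
}

The returned skip_instantiate list ('data' and/or 'model') lets the caller avoid rebuilding components whose outputs were fully recovered from the cached matrices.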