def postprocessing()

in shaDow/main.py [0:0]


def postprocessing(data_post, model, minibatch, logger, config, acc_record):
    """
    Detailed instructions to run post-processing to be added soon. 
    Post-processing is not described in our paper. So this part of code is WIP and
    only meant for experimentation. 
    If acc_record is None, then we don't check accuracy. This enables CS for still running jobs. 
    """
    from shaDow.utils import merge_stat_record
    logger.init_log2file(status='final')
    def _common_setup(dmodel):
        logger.set_loader_path(dmodel)
        logger.load_model(model, optimizer=None, copy=False)
        logger.info_batch[TRAIN].PERIOD_LOG = 1
        for md in [TRAIN, VALID, TEST]:
            minibatch.disable_cache(md)
    if config['method'].lower() == 'cs':
        from shaDow.postproc_CnS import correct_smooth
        # NOTE: setting the TRAIN evaluation period to > 1 will only make the 
        #   log / print message "appear" to be nondeterministic. However, the 
        #   full prediction matrix `pred_mat` is always deterministic regardless
        #   of the evaluation frequency. So PERIOD_LOG has no effect on the C&S output. 
        assert acc_record is None or (type(acc_record) == list and len(acc_record) == len(config['dir_pred_mat']))
        # generate and store prediction matrix if not yet available from external file
        for i, dmodel in enumerate(config['dir_pred_mat']):
            if config['pred_mat'][i] is None:
                _common_setup(dmodel)
                if minibatch.name_data not in ['arxiv', 'products']:
                    logger.printf(f"POSTPROC OF CS ONLY DOES NOT SUPPORT {minibatch.name_data} YET")
                    raise NotImplementedError
                pred_mat = torch.zeros(minibatch.label_full.shape).to(config['dev_torch'])
                for md in [TRAIN, VALID, TEST]:
                    one_epoch(0, md, model, minibatch, logger, status='final', pred_mat=pred_mat)
                fname_pred = 'pred_mat_{}.cs' if acc_record is not None else '__pred_mat_{}.cs'
                logger.save_tensor(pred_mat, fname_pred, use_path_loader=True)
                config['pred_mat'][i] = pred_mat
                logger.reset()
        if acc_record is not None:
            acc_record = merge_stat_record(acc_record)
        acc_orig, acc_post = correct_smooth(
            config['name_data'], 
            config['dev_torch'], 
            config['pred_mat'], 
            config['hyperparameter']['norm_sym'], 
            config['hyperparameter']['alpha']
        )
        # double check if acc calulated by C&S matches with the record (i.e., acc_orig & acc_record)
        if acc_record is not None:
            for md in [VALID, TEST]:
                acc_orig_m = [round(a, 4) for a in acc_orig[md]]
                acc_recd_m = [round(a, 4) for a in acc_record['accuracy'][md]]
                assert all(abs(acc_orig_m[i] - acc_recd_m[i]) <= 0.0001 for i in range(len(acc_orig_m))),\
                            "[ACC MISMATCH] MAYBE YOU WANT TO REMOVE THE STORED IN THIS RUN. "
    elif config['method'].lower() == 'ensemble':
        from shaDow.postproc_ens import ensemble_multirun
        assert acc_record is None or (type(acc_record) == dict and len(acc_record) == len(config['dir_emb_mat']))
        # the below 'for' loop is eval / inference only (no need to reset model)
        for sname, dirs_l in config['dir_emb_mat'].items(): # ppr: [,,], khop: [,,]
            for i, dmodel in enumerate(dirs_l):             # [,,]
                if config['emb_mat'][sname][i] is None:     # single model
                    # inference
                    _common_setup(dmodel)
                    N, F = minibatch.feat_full.shape[0], model.dim_hidden
                    emb_mat = [torch.zeros((N, F)).to(config['dev_torch']) for i in range(model.num_ensemble)]
                    for md in [TRAIN, VALID, TEST]:
                        one_epoch(0, md, model, minibatch, logger, status='final', emb_ens=emb_mat)
                    fname_emb = 'emb_mat_{}.ens' if acc_record is not None else '__emb_mat_{}.ens'
                    _fname = logger.save_tensor(emb_mat, fname_emb, use_path_loader=True)
                    config['emb_mat'][sname][i] = emb_mat
                    logger.reset()
        # ensemble and train
        acc_orig, acc_post = ensemble_multirun(
            data_post['node_set'], 
            config['emb_mat'], 
            data_post['label'], 
            config['architecture'], 
            config['hyperparameter'], 
            logger, 
            config['dev_torch'], 
            acc_record
        )
    # wrap up
    logger.print_table_postproc(acc_orig, acc_post)