in shaDow/main.py [0:0]
def postprocessing(data_post, model, minibatch, logger, config, acc_record):
"""
    Detailed instructions for running post-processing will be added later.
    Post-processing is not described in our paper, so this part of the code is
    WIP and only meant for experimentation.
    If acc_record is None, we skip the accuracy check. This makes it possible to
    run C&S on jobs that are still in progress.
"""
from shaDow.utils import merge_stat_record
logger.init_log2file(status='final')
def _common_setup(dmodel):
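        """Load the trained model stored under `dmodel` and prepare the loaders
        for deterministic inference (log every period, no sampler cache)."""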
logger.set_loader_path(dmodel)
logger.load_model(model, optimizer=None, copy=False)
logger.info_batch[TRAIN].PERIOD_LOG = 1
for md in [TRAIN, VALID, TEST]:
minibatch.disable_cache(md)
if config['method'].lower() == 'cs':
from shaDow.postproc_CnS import correct_smooth
# NOTE: setting the TRAIN evaluation period to > 1 will only make the
# log / print message "appear" to be nondeterministic. However, the
# full prediction matrix `pred_mat` is always deterministic regardless
# of the evaluation frequency. So PERIOD_LOG has no effect on the C&S output.
        assert acc_record is None or (isinstance(acc_record, list) and len(acc_record) == len(config['dir_pred_mat']))
# generate and store prediction matrix if not yet available from external file
for i, dmodel in enumerate(config['dir_pred_mat']):
if config['pred_mat'][i] is None:
_common_setup(dmodel)
if minibatch.name_data not in ['arxiv', 'products']:
                    logger.printf(f"C&S POSTPROC DOES NOT SUPPORT {minibatch.name_data} YET")
raise NotImplementedError
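                # full-graph soft-prediction buffer (same shape as label_full);
                # one_epoch below fills in the rows of each split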
pred_mat = torch.zeros(minibatch.label_full.shape).to(config['dev_torch'])
for md in [TRAIN, VALID, TEST]:
one_epoch(0, md, model, minibatch, logger, status='final', pred_mat=pred_mat)
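                # a '__' prefix marks a matrix generated while the training job was
                # still running (i.e., there is no accuracy record to verify against)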
fname_pred = 'pred_mat_{}.cs' if acc_record is not None else '__pred_mat_{}.cs'
logger.save_tensor(pred_mat, fname_pred, use_path_loader=True)
config['pred_mat'][i] = pred_mat
logger.reset()
if acc_record is not None:
acc_record = merge_stat_record(acc_record)
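        # `correct_smooth` applies Correct & Smooth (Huang et al., 2021). A minimal
        # sketch of the propagation, assuming the standard formulation (the actual
        # implementation lives in shaDow/postproc_CnS.py):
        #     A_hat = D^{-1/2} A D^{-1/2}  if norm_sym  else  D^{-1} A
        #     correct: E <- alpha * A_hat @ E + (1 - alpha) * E0   # E0 = train residual Y - P
        #     smooth:  Z <- alpha * A_hat @ Z + (1 - alpha) * Z0   # Z0 = corrected predictions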
acc_orig, acc_post = correct_smooth(
config['name_data'],
config['dev_torch'],
config['pred_mat'],
config['hyperparameter']['norm_sym'],
config['hyperparameter']['alpha']
)
        # double check that the acc calculated by C&S matches the record (i.e., acc_orig vs. acc_record)
if acc_record is not None:
for md in [VALID, TEST]:
acc_orig_m = [round(a, 4) for a in acc_orig[md]]
acc_recd_m = [round(a, 4) for a in acc_record['accuracy'][md]]
                assert all(abs(a - b) <= 1e-4 for a, b in zip(acc_orig_m, acc_recd_m)),\
                    "[ACC MISMATCH] MAYBE YOU WANT TO REMOVE THE pred_mat STORED IN THIS RUN. "
elif config['method'].lower() == 'ensemble':
from shaDow.postproc_ens import ensemble_multirun
        assert acc_record is None or (isinstance(acc_record, dict) and len(acc_record) == len(config['dir_emb_mat']))
# the below 'for' loop is eval / inference only (no need to reset model)
        for sname, dirs_l in config['dir_emb_mat'].items():    # e.g. {'ppr': [dir0, ...], 'khop': [dir0, ...]}
            for i, dmodel in enumerate(dirs_l):
                if config['emb_mat'][sname][i] is None:        # embeddings of this model not yet computed
                    # run inference to generate the embedding matrices
_common_setup(dmodel)
N, F = minibatch.feat_full.shape[0], model.dim_hidden
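                    # one (N x F) embedding buffer per ensemble branch of the model;
                    # one_epoch fills them in via `emb_ens`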
                    emb_mat = [torch.zeros((N, F)).to(config['dev_torch']) for _ in range(model.num_ensemble)]
for md in [TRAIN, VALID, TEST]:
one_epoch(0, md, model, minibatch, logger, status='final', emb_ens=emb_mat)
fname_emb = 'emb_mat_{}.ens' if acc_record is not None else '__emb_mat_{}.ens'
                    logger.save_tensor(emb_mat, fname_emb, use_path_loader=True)
config['emb_mat'][sname][i] = emb_mat
logger.reset()
# ensemble and train
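        # `ensemble_multirun` presumably trains the final aggregation / classification
        # layers specified by config['architecture'] on top of the frozen embeddings,
        # returning accuracy before and after ensembling.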
acc_orig, acc_post = ensemble_multirun(
data_post['node_set'],
config['emb_mat'],
data_post['label'],
config['architecture'],
config['hyperparameter'],
logger,
config['dev_torch'],
acc_record
)
    else:
        raise NotImplementedError(f"Unknown post-processing method: {config['method']}")
    # wrap up
    logger.print_table_postproc(acc_orig, acc_post)