in shaDow/main.py [0:0]
def one_epoch(ep, mode, model, minibatch, logger, status='running', pred_mat=None, emb_ens=None):
"""
NOTE that pred_mat and emb_ens are ONLY used for post-processing.
For all experiments in our main paper, we have pred_mat = emb_ens = None
Also, for ensemble, we implement two algorithms.
1. ensemble during training: so no post-processing is needed.
2. ensemble during post-processing: so train a few models first, and then
launch another trainer just to train the ensembler during post-proc
The algorithm described in our paper (and appendix) follow algorithm 1.
"""
    assert status in ['running', 'final'] and mode in [TRAIN, VALID, TEST]
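    # Per-epoch setup: reset sampler / logger state and shuffle the target entities.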
    minibatch.epoch_start_reset(ep, mode)
    minibatch.shuffle_entity(mode)
    logger.epoch_start_reset(ep, mode, minibatch.entity_epoch[mode].shape[0])
    t1 = time.time()
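    # Loop over minibatches until all target entities of this epoch are consumed.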
    while not minibatch.is_end_epoch(mode):
        input_batch = minibatch.one_batch(
            mode=mode, ret_raw_idx=(pred_mat is not None or emb_ens is not None)
        )
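        # When post-processing buffers are given, keep the raw (full-graph) IDs of
        # the target nodes so the batch outputs can be scattered back into them.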
        if pred_mat is not None or emb_ens is not None:
            idx_pred_raw = input_batch.pop_idx_raw()[0][input_batch.target_ens[0]]
        output_batch = model.step(mode, status, input_batch)
        if pred_mat is not None:        # prepare for C&S
            pred_mat[idx_pred_raw] = output_batch['preds']
        if emb_ens is not None:         # prepare for subgraph ensemble
            assert len(emb_ens) == len(output_batch['emb_ens'])
            for ie, e in enumerate(emb_ens):
                e[idx_pred_raw] = output_batch['emb_ens'][ie]
        logger.update_batch(mode, minibatch.batch_num, output_batch)
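    # End of epoch: report profiling stats, then aggregate the per-batch logs.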
    minibatch.profiler.print_summary()
    t2 = time.time()
    minibatch.epoch_end_reset(mode)
    logger.update_epoch(ep, mode)
    return logger.log_key_step(mode, status=status, time=t2 - t1)
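
For context, a minimal sketch of how a post-processing caller could preallocate pred_mat and emb_ens and run one final inference pass to fill them. The sizes below and the torch buffers are illustrative assumptions, not taken from the repo; model, minibatch and logger are assumed to be constructed elsewhere in main.py.

import torch

# Hypothetical caller sketch (not part of one_epoch); all sizes are placeholders.
num_nodes, num_classes, dim_emb, num_branches = 10000, 40, 256, 2
pred_mat = torch.zeros(num_nodes, num_classes)    # row i <- prediction for raw node i
emb_ens = [torch.zeros(num_nodes, dim_emb) for _ in range(num_branches)]
# num_branches must match the number of subgraph-ensemble branches the model emits.
one_epoch(0, TEST, model, minibatch, logger, status='final',
          pred_mat=pred_mat, emb_ens=emb_ens)
# pred_mat can now feed C&S smoothing; emb_ens feeds the post-processing ensembler.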