in shaDow/main.py [0:0]
def main(task, args, args_logger):
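    # Note (assumption, not visible in this excerpt): meta_config, timestamp and device are
    # referenced but never defined in this function; they are presumably module-level globals
    # of shaDow/main.py set up before main() is called.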
    assert task in ['train', 'inference', 'postproc']
    dataset = args.dataset
    dir_log = meta_config['logging']['dir']['local']
    os_ = meta_config['device']['software']['os']
    (
        params_train,
        config_sampler_preproc,
        config_sampler_train,
        config_data,
        arch_gnn,
        dir_log_full
    ) = parse_n_prepare(task, args, dataset, dir_log, os_=os_)
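    # Metrics tracks the dataset-specific evaluation metric (DATA_METRIC); config_term presumably
    # configures the window-based termination criterion (window size and aggregation rule).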
    metrics = Metrics(dataset, (arch_gnn['loss'] == 'sigmoid'), DATA_METRIC[dataset], params_train['term_window_size'])
    config_term = {'window_size': params_train['term_window_size'], 'window_aggr': params_train['term_window_aggr']}
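    # The Logger bundles experiment bookkeeping: it records the full config, prints progress
    # (printf), saves/loads model checkpoints and cleans up log files when a run ends.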
    logger = Logger(
        task,
        {
            "args"           : args,
            "arch_gnn"       : arch_gnn,
            "data"           : config_data,
            "hyperparams"    : params_train,
            "sampler_preproc": config_sampler_preproc,
            "sampler_train"  : config_sampler_train
        },
        dir_log_full,
        metrics,
        config_term,
        no_log=args.no_log,
        timestamp=timestamp,
        log_test_convergence=args.log_test_convergence,
        period_batch_train=args.eval_train_every,
        no_pbar=args.no_pbar,
        **args_logger
    )
    if task == 'postproc':
        config_postproc, acc_record, skip_instantiate = parse_n_prepare_postproc(
            args.postproc_dir,
            args.postproc_configs,
            dataset,
            dir_log,
            arch_gnn,
            logger
        )
    else:
        skip_instantiate = []
    # skip_instantiate lists which steps of model instantiation can be skipped:
    # e.g., C&S post-processing does not need to load the model if the generated embeddings are already stored.
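    # For illustration only (hypothetical value, not from the source): a C&S post-processing config
    # might yield skip_instantiate = ['model'], so that only the data is loaded below while
    # model and minibatch stay None.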
    dir_data = meta_config['data']['dir']
    if 'data' not in skip_instantiate:
        data_train = load_data(dir_data, dataset, config_data, printf=logger.printf)
    else:
        data_train = None
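    # Build the GNN model and the minibatch handler (subgraph sampler) unless the
    # post-processing config says this step can be skipped.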
    if 'model' not in skip_instantiate:
        assert 'data' not in skip_instantiate
        model, minibatch = instantiate(
            dataset,
            dir_data,
            data_train,
            params_train,
            arch_gnn,
            config_sampler_preproc,
            config_sampler_train,
            meta_config['device']['cpu']['max_threads'],
            args.full_tensor_on_gpu,
            args.no_pbar,
            args.seed
        )
        logger.printf(f"TOTAL NUM OF PARAMS = {sum(p.numel() for p in model.parameters())}", style="yellow")
    else:
        model = minibatch = None
    # Now handle the specific tasks
    if task == 'train':
        try:
            nocache = args.nocache.lower() if isinstance(args.nocache, str) else args.nocache
            if args.reload_model_dir is not None:
                logger.set_loader_path(args.reload_model_dir)
                logger.load_model(model, optimizer=model.optimizer, copy=False, device=device)
            train(model, minibatch, params_train["end"], logger, nocache=nocache)
            status = 'finished'
        except KeyboardInterrupt:
            status = 'killed'
            print("Pressed CTRL-C! Stopping.")
        except Exception as err:
            status = 'crashed'
            import traceback
            traceback.print_tb(err.__traceback__)
        finally:
            # the logger only removes log files when running the test *.yml configs
            logger.end_training(status)     # clean up the unwanted log files
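    # Inference either evaluates a previously trained model (loaded from args.inference_dir)
    # or only reports the computational complexity under args.inference_budget.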
    elif task == 'inference':
        if not args.compute_complexity_only:
            logger.set_loader_path(args.inference_dir)
            inference(model, minibatch, logger, device=device, inf_train=args.is_inf_train)
        else:
            compute_complexity(model, minibatch, args.inference_budget, logger)
    else:   # post-processing
        config_postproc['dev_torch'] = device
        config_postproc['name_data'] = dataset
        if minibatch is not None:
            assert minibatch.prediction_task == 'node'
            data_postproc = {"label": minibatch.label_full, "node_set": minibatch.entity_epoch}
        elif data_train is not None:
            data_postproc = {"label": data_train['label_full'], "node_set": data_train['node_set']}
        else:
            data_postproc = None
        postprocessing(data_postproc, model, minibatch, logger, config_postproc, acc_record)