def main()

in shaDow/main.py
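
Entry point of shaDow. It validates `task` (one of `train`, `inference`, `postproc`), parses the run configuration, sets up metrics and logging, optionally loads data and instantiates the model, and then dispatches to the requested task.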


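# NOTE: `meta_config`, `timestamp`, and `device` referenced below are assumed to be
# module-level names defined elsewhere in shaDow/main.py; they are not part of this
# excerpt.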
def main(task, args, args_logger):
    assert task in ['train', 'inference', 'postproc']
    dataset = args.dataset
    dir_log = meta_config['logging']['dir']['local']
    os_ = meta_config['device']['software']['os']
    (
        params_train, 
        config_sampler_preproc,
        config_sampler_train, 
        config_data, 
        arch_gnn, 
        dir_log_full
    ) = parse_n_prepare(task, args, dataset, dir_log, os_=os_)
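    # Evaluation metrics for the dataset; the `term_window_*` hyperparameters
    # presumably control the sliding window used for convergence checks.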
    metrics = Metrics(
        dataset,
        (arch_gnn['loss'] == 'sigmoid'),
        DATA_METRIC[dataset],
        params_train['term_window_size']
    )
    config_term = {
        'window_size': params_train['term_window_size'],
        'window_aggr': params_train['term_window_aggr']
    }
    logger = Logger(
        task,
        {
            "args"         : args, 
            "arch_gnn"     : arch_gnn, 
            "data"         : config_data,
            "hyperparams"  : params_train, 
            "sampler_preproc": config_sampler_preproc,
            "sampler_train"  : config_sampler_train
        }, 
        dir_log_full, 
        metrics, 
        config_term,
        no_log=args.no_log, 
        timestamp=timestamp, 
        log_test_convergence=args.log_test_convergence,
        period_batch_train=args.eval_train_every, 
        no_pbar=args.no_pbar,
        **args_logger
    )
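    # Only the postproc task carries extra post-processing configs; it may also
    # request that parts of model instantiation be skipped.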
    if task == 'postproc':
        config_postproc, acc_record, skip_instantiate = parse_n_prepare_postproc(
            args.postproc_dir, 
            args.postproc_configs, 
            dataset, dir_log, 
            arch_gnn, 
            logger
        )
    else:
        skip_instantiate = []

    # skip_instantiate lists the instantiation steps to skip:
    # e.g., for C&S postproc, the model need not be loaded if the generated
    # embeddings have already been stored.
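    # (e.g., skip_instantiate == ['model'] keeps the loaded data but leaves
    # model = minibatch = None below)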
    dir_data = meta_config['data']['dir']
    if 'data' not in skip_instantiate:
        data_train = load_data(dir_data, dataset, config_data, printf=logger.printf)
    else:
        data_train = None
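    # The model cannot be built without the loaded data (hence the assert below);
    # `instantiate` also constructs the subgraph sampler wrapped by `minibatch`.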
    if 'model' not in skip_instantiate:
        assert 'data' not in skip_instantiate
        model, minibatch = instantiate(
            dataset, 
            dir_data, 
            data_train, 
            params_train, arch_gnn, 
            config_sampler_preproc, config_sampler_train,
            meta_config['device']['cpu']['max_threads'],
            args.full_tensor_on_gpu,
            args.no_pbar,
            args.seed
        )
        logger.printf(f"TOTAL NUM OF PARAMS = {sum(p.numel() for p in model.parameters())}", style="yellow")
    else:
        model = minibatch = None
    
    # Now handle the specific tasks
    if task == 'train':
        try:
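            # `args.nocache` may be a bool or a string flag; normalize strings to
            # lowercase before handing it to train().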
            nocache = args.nocache.lower() if isinstance(args.nocache, str) else args.nocache
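            # Optionally resume from a previously saved checkpoint before training.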
            if args.reload_model_dir is not None:
                logger.set_loader_path(args.reload_model_dir)
                logger.load_model(model, optimizer=model.optimizer, copy=False, device=device)
            train(model, minibatch, params_train["end"], logger, nocache=nocache)
            status = 'finished'
        except KeyboardInterrupt:
            status = 'killed'
            print("Pressed CTRL-C! Stopping. ")        
        except Exception:
            status = 'crashed'
            import traceback
            traceback.print_exc()   # print the stack trace together with the error itself
        finally:
            # the logger removes log files only when running the test *.yml configs
            logger.end_training(status)     # clean up unwanted log files
    elif task == 'inference':
        if not args.compute_complexity_only:
            logger.set_loader_path(args.inference_dir)
            inference(model, minibatch, logger, device=device, inf_train=args.is_inf_train)
        else:
            compute_complexity(model, minibatch, args.inference_budget, logger)
    else:       # postprocessing
        config_postproc['dev_torch'] = device
        config_postproc['name_data'] = dataset
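        # Gather labels and node splits for post-processing, preferring the live
        # minibatch state over the raw loaded data.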
        if minibatch is not None:
            assert minibatch.prediction_task == 'node'
            data_postproc = {"label": minibatch.label_full, "node_set": minibatch.entity_epoch}
        elif data_train is not None:
            data_postproc = {"label": data_train['label_full'], "node_set": data_train['node_set']}
        else:
            data_postproc = None
        postprocessing(data_postproc, model, minibatch, logger, config_postproc, acc_record)
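
For reference, a minimal hypothetical driver (NOT the repo's real CLI, whose argparse setup lives elsewhere in shaDow/main.py): it only illustrates the contract `main(task, args, args_logger)` using a few of the attribute names that main() actually reads above.

import argparse

# Hypothetical sketch: the real namespace must provide *every* attribute that
# main() touches (no_log, no_pbar, nocache, seed, eval_train_every, ...); only a
# subset is wired up here for illustration.
parser = argparse.ArgumentParser()
parser.add_argument('--task', choices=['train', 'inference', 'postproc'], default='train')
parser.add_argument('--dataset', type=str, required=True)
parser.add_argument('--no_log', action='store_true')
parser.add_argument('--no_pbar', action='store_true')
parser.add_argument('--nocache', default=None)
parser.add_argument('--seed', type=int, default=None)
args = parser.parse_args()
main(args.task, args, args_logger={})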