def main()

in python/dglke/train.py [0:0]
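
Entry point for DGL-KE training: it loads the dataset, builds the training and
evaluation samplers, constructs the model, then runs (optionally multi-process)
training, validation, and testing.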


def main():
    args = ArgParser().parse_args()
    prepare_save_path(args)

    init_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path,
                          args.dataset,
                          args.format,
                          args.delimiter,
                          args.data_files,
                          args.has_edge_importance)

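    # A negative neg_sample_size_eval means "use every entity as a negative sample"
    # during evaluation, so it is expanded to the full entity count here.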
    if args.neg_sample_size_eval < 0:
        args.neg_sample_size_eval = dataset.n_entities
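    # get_compatible_batch_size presumably rounds the batch size up so that each batch
    # divides evenly into negative-sampling chunks (an inference from the helper's
    # name and how the resulting sizes are used by the samplers below).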
    args.batch_size = get_compatible_batch_size(args.batch_size, args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)
    # Mixed CPU-GPU training is required for multi-GPU training.
    if len(args.gpu) > 1:
        args.mix_cpu_gpu = True
        if args.num_proc < len(args.gpu):
            args.num_proc = len(args.gpu)
    # The number of processes must be divisible by the number of GPUs.
    if len(args.gpu) > 1 and args.num_proc > 1:
        assert args.num_proc % len(args.gpu) == 0, \
                'The number of processes needs to be divisible by the number of GPUs'
    # For multiprocessing training, we need to ensure that training processes are synchronized periodically.
    if args.num_proc > 1:
        args.force_sync_interval = 1000

    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "with degree-based negative sampling we cannot filter positive edges."

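    # Relation partitioning (inferred from the flags used below): rel_part assigns each
    # relation to one training process so its embedding can be updated locally; the
    # "soft" variant tolerates cross-partition relations, tracked via cross_rels.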
    args.soft_rel_part = args.mix_cpu_gpu and args.rel_part
    g = ConstructGraph(dataset, args)
    train_data = TrainDataset(g, dataset, args, ranks=args.num_proc, has_importance=args.has_edge_importance)
    # If there are no cross-partition relations, fall back to strict_rel_part.
    args.strict_rel_part = args.mix_cpu_gpu and not train_data.cross_part
    args.num_workers = 8 # fix the number of dataloader workers at 8

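    # Each training iterator alternates between head-corrupted and tail-corrupted
    # negative batches; that alternation is what NewBidirectionalOneShotIterator wraps.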
    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            # for each GPU, allocate num_proc // num_GPU processes
            train_sampler_head = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_sample_size,
                                                           mode='head',
                                                           num_workers=args.num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_sample_size,
                                                           mode='tail',
                                                           num_workers=args.num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                                  args.neg_sample_size, args.neg_sample_size,
                                                                  True, dataset.n_entities,
                                                                  args.has_edge_importance))

    else: # single-process path, also used for debugging
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_sample_size,
                                                       mode='head',
                                                       num_workers=args.num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_sample_size,
                                                       mode='tail',
                                                       num_workers=args.num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_sample_size, args.neg_sample_size,
                                                        True, dataset.n_entities,
                                                        args.has_edge_importance)


    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
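    # The samplers built above keep references to the data they need, so the
    # TrainDataset wrapper itself can be dropped here to reclaim memory.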
    train_data = None
    gc.collect()

    if args.valid or args.test:
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        if args.valid:
            assert dataset.valid is not None, 'validation set is not provided'
        if args.test:
            assert dataset.test is not None, 'test set is not provided'
        eval_dataset = EvalDataset(g, dataset, args)


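    # Evaluation samplers: the 'chunk-head' / 'chunk-tail' modes corrupt the head or
    # tail of each triple and, judging by the mode names, share one set of negative
    # entities across a chunk of positives to keep evaluation tractable.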
    if args.valid:
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            if args.dataset == "wikikg90M":
                for i in range(args.num_proc):
                    valid_sampler_tail = eval_dataset.create_sampler_wikikg90M('valid', args.batch_size_eval,
                                                                               mode='tail',
                                                                               rank=i, ranks=args.num_proc)
                    valid_sampler_tails.append(valid_sampler_tail)
            else:
                for i in range(args.num_proc):
                    valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                      args.neg_sample_size_eval,
                                                                      args.neg_sample_size_eval,
                                                                      args.eval_filter,
                                                                      mode='chunk-head',
                                                                      num_workers=args.num_workers,
                                                                      rank=i, ranks=args.num_proc)
                    valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                      args.neg_sample_size_eval,
                                                                      args.neg_sample_size_eval,
                                                                      args.eval_filter,
                                                                      mode='chunk-tail',
                                                                      num_workers=args.num_workers,
                                                                      rank=i, ranks=args.num_proc)
                    valid_sampler_heads.append(valid_sampler_head)
                    valid_sampler_tails.append(valid_sampler_tail)
        else: # single-process path, also used for debugging
            if args.dataset == "wikikg90M":
                valid_sampler_tail = eval_dataset.create_sampler_wikikg90M('valid', args.batch_size_eval,
                                                                           mode='tail',
                                                                           rank=0, ranks=1)
            else:
                valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.eval_filter,
                                                                 mode='chunk-head',
                                                                 num_workers=args.num_workers,
                                                                 rank=0, ranks=1)
                valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.eval_filter,
                                                                 mode='chunk-tail',
                                                                 num_workers=args.num_workers,
                                                                 rank=0, ranks=1)
    if args.test:
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            if args.dataset == "wikikg90M":
                for i in range(args.num_test_proc):
                    test_sampler_tail = eval_dataset.create_sampler_wikikg90M('test', args.batch_size_eval,
                                                                              mode='tail',
                                                                              rank=i, ranks=args.num_test_proc)
                    test_sampler_tails.append(test_sampler_tail)
            else:
                for i in range(args.num_test_proc):
                    test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                     args.neg_sample_size_eval,
                                                                     args.neg_sample_size_eval,
                                                                     args.eval_filter,
                                                                     mode='chunk-head',
                                                                     num_workers=args.num_workers,
                                                                     rank=i, ranks=args.num_test_proc)
                    test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                     args.neg_sample_size_eval,
                                                                     args.neg_sample_size_eval,
                                                                     args.eval_filter,
                                                                     mode='chunk-tail',
                                                                     num_workers=args.num_workers,
                                                                     rank=i, ranks=args.num_test_proc)
                    test_sampler_heads.append(test_sampler_head)
                    test_sampler_tails.append(test_sampler_tail)
        else:
            if args.dataset == "wikikg90M":
                test_sampler_tail = eval_dataset.create_sampler_wikikg90M('test', args.batch_size_eval,
                                                                          mode='tail',
                                                                          rank=0, ranks=1)
            else:
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=args.num_workers,
                                                                rank=0, ranks=1)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=args.num_workers,
                                                                rank=0, ranks=1)

    # load model
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    emap_file = dataset.emap_fname
    rmap_file = dataset.rmap_fname

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    gc.collect()

    model = load_model(args, n_entities, n_relations)
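    # With multiple processes or asynchronous updates, the embeddings must live in
    # shared memory so every worker reads and writes the same parameter tensors.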
    if args.num_proc > 1 or args.async_update:
        model.share_memory()

    print('Total initialization time {:.3f} seconds'.format(time.time() - init_time_start))

    # train
    start = time.time()
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
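        # A single barrier shared by all trainers; train_mp presumably uses it for the
        # periodic synchronization configured through force_sync_interval above.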
        for i in range(args.num_proc):
            if args.dataset == "wikikg90M":
                valid_sampler = [valid_sampler_tails[i]] if args.valid else None
            else:
                valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
            proc = mp.Process(target=train_mp, args=(args,
                                                     model,
                                                     train_samplers[i],
                                                     valid_sampler,
                                                     i,
                                                     rel_parts,
                                                     cross_rels,
                                                     barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        if args.dataset == "wikikg90M":
            valid_samplers = [valid_sampler_tail] if args.valid else None
        else:
            valid_samplers = [valid_sampler_head, valid_sampler_tail] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)

    print('training takes {:.3f} seconds'.format(time.time() - start))

    if not args.no_save_emb:
        save_model(args, model, emap_file, rmap_file)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
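            # Each test process pushes its list of per-triple metric logs into this
            # queue; the parent gathers and averages them below.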
            procs = []
            for i in range(args.num_test_proc):
                if args.dataset == "wikikg90M":
                    proc = mp.Process(target=test_mp, args=(args,
                                                            model,
                                                            [test_sampler_tails[i]],
                                                            i,
                                                            'Test',
                                                            queue))
                else:
                    proc = mp.Process(target=test_mp, args=(args,
                                                            model,
                                                            [test_sampler_heads[i], test_sampler_tails[i]],
                                                            i,
                                                            'Test',
                                                            queue))
                procs.append(proc)
                proc.start()

            if args.dataset == "wikikg90M":
                print('The prediction results have been saved to {}'.format(args.save_path))
            else:
                metrics = {}
                logs = []
                for i in range(args.num_test_proc):
                    log = queue.get()
                    logs = logs + log

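                # Each log entry holds the metrics for one test triple, so this is a
                # micro-average over all triples scored across the worker processes.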
                for metric in logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
                print("-------------- Test result --------------")
                for k, v in metrics.items():
                    print('Test average {} : {}'.format(k, v))
                print("-----------------------------------------")

            for proc in procs:
                proc.join()
        else:
            if args.dataset == "wikikg90M":
                test(args, model, [test_sampler_tail])
                print('The prediction results have been saved to {}'.format(args.save_path))
            else:
                test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))
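
# A plausible invocation, based on the arguments this function reads (values are
# illustrative only; flag names follow DGL-KE's dglke_train CLI):
#
#   dglke_train --model_name TransE_l2 --dataset FB15k \
#       --batch_size 1000 --neg_sample_size 200 --hidden_dim 400 \
#       --batch_size_eval 16 --valid --test --num_proc 8 --gpu 0 1
#
# Note that --num_proc must be divisible by the number of GPUs (8 % 2 == 0 here),
# matching the assertion near the top of main().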