def train()

in python/dglke/train_mxnet.py [0:0]


def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, barrier=None):
    """Run the single-process MXNet KGE training loop.

    Draws ``args.max_step`` ``(pos_g, neg_g)`` batches from ``train_sampler``,
    computes the loss under ``mx.autograd.record()``, back-propagates it and
    calls ``model.update``. Every ``args.log_interval`` steps, the per-key
    average of the ``log`` dicts returned by ``model.forward`` is printed,
    followed by the wall-clock time since the previous log. When validation
    is enabled, ``test(..., mode='Valid')`` runs every ``args.eval_interval``
    steps (for steps > 1).

    Args:
        args: Parsed options. Reads ``num_proc``, ``rel_part``, ``gpu``,
            ``mix_cpu_gpu``, ``strict_rel_part``, ``max_step``,
            ``log_interval``, ``valid`` and ``eval_interval``. ``args.step``
            is overwritten with the current step on every iteration.
        model: KGE model exposing ``forward(pos_g, neg_g, gpu_id)`` returning
            ``(loss, log)``, ``update(gpu_id)`` and, when
            ``args.strict_rel_part`` is set, ``prepare_relation(ctx)`` and
            ``writeback_relation(rank, rel_parts)``.
        train_sampler: Iterator yielding ``(pos_g, neg_g)`` pairs via ``next``.
        valid_samplers: Passed to ``test`` for validation. If ``None``,
            validation is skipped.
        rank: Process rank. Used for GPU selection and ``writeback_relation``.
        rel_parts: Relation partitions forwarded to ``writeback_relation``.
        barrier: Not used by this implementation. Presumably kept so the
            signature matches other backends (TODO confirm).

    Raises:
        AssertionError: If ``args.num_proc > 1`` or ``args.rel_part`` is truthy.
    """
    assert args.num_proc <= 1, "MXNet KGE does not support multi-process now"
    assert not args.rel_part, "No need for relation partition in single process for MXNet KGE"
    logs = []

    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

    if len(args.gpu) > 0:
        # The assert above enforces num_proc <= 1, so this always picks
        # args.gpu[0]. The rank-based branch mirrors multi-process code paths.
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        # No GPU configured. -1 is passed to model.forward/update;
        # presumably interpreted there as "use CPU" (confirm in the model).
        gpu_id = -1

    if args.strict_rel_part:
        # NOTE(review): when no GPU is configured this passes mx.gpu(-1);
        # confirm prepare_relation is never reached in that configuration.
        model.prepare_relation(mx.gpu(gpu_id))

    if mxprofiler:
        from mxnet import profiler
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename='profile_output.json')
    start = time.time()
    for step in range(0, args.max_step):
        pos_g, neg_g = next(train_sampler)
        args.step = step
        # Profiling starts at step 1 rather than 0, presumably to skip the
        # one-time setup cost of the first iteration.
        if step == 1 and mxprofiler:
            profiler.set_state('run')
        with mx.autograd.record():
            loss, log = model.forward(pos_g, neg_g, gpu_id)
        loss.backward()
        logs.append(log)
        model.update(gpu_id)

        if step % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(entry[k] for entry in logs) / len(logs)
                print('[Train]({}/{}) average {}: {}'.format(step, args.max_step, k, v))
            logs = []
            print(time.time() - start)
            start = time.time()

        if args.valid and step % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            # Use a separate clock for validation so the training-interval
            # timer `start` (set at the last log) is not overwritten.
            valid_start = time.time()
            test(args, model, valid_samplers, mode='Valid')
            print('test:', time.time() - valid_start)
    if args.strict_rel_part:
        model.writeback_relation(rank, rel_parts)
    if mxprofiler:
        # MXNet executes asynchronously; wait for all pending ops so the
        # profile covers the complete run before stopping and dumping it.
        nd.waitall()
        profiler.set_state('stop')
        profiler.dump()
        print(profiler.dumps())