# python/dglke/train_mxnet.py
def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, barrier=None):
    """Run the single-process MXNet KGE training loop.

    Parameters
    ----------
    args : argparse.Namespace
        Run configuration (``max_step``, ``log_interval``, ``gpu``,
        ``valid``, ``eval_interval``, ...). ``args.step`` is written as a
        side effect on every iteration.
    model : KGE model
        Must provide ``forward``/``update`` and, when
        ``args.strict_rel_part`` is set, ``prepare_relation`` /
        ``writeback_relation``.
    train_sampler : iterator
        Yields ``(pos_g, neg_g)`` positive/negative graph pairs per step.
    valid_samplers : list, optional
        Samplers passed to ``test`` for periodic validation.
    rank : int
        Process rank; used for GPU selection and relation writeback.
    rel_parts, barrier
        Unused in single-process MXNet training; kept for signature
        parity with the multi-process backends.
    """
    # The MXNet backend only supports single-process, unpartitioned training.
    assert args.num_proc <= 1, "MXNet KGE does not support multi-process now"
    assert not args.rel_part, "No need for relation partition in single process for MXNet KGE"
    logs = []

    # Dump the full run configuration for reproducibility.
    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

    # Pick a device: round-robin over GPUs only in mix_cpu_gpu multi-process
    # mode, otherwise the first configured GPU; -1 means CPU.
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

    if args.strict_rel_part:
        model.prepare_relation(mx.gpu(gpu_id))

    if mxprofiler:
        from mxnet import profiler
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename='profile_output.json')

    start = time.time()
    for step in range(0, args.max_step):
        pos_g, neg_g = next(train_sampler)
        args.step = step

        # Start profiling from step 1 so step-0 warmup/initialization is
        # excluded from the profile.
        if step == 1 and mxprofiler:
            profiler.set_state('run')

        with mx.autograd.record():
            loss, log = model.forward(pos_g, neg_g, gpu_id)
        loss.backward()
        logs.append(log)
        model.update(gpu_id)

        # Periodically report metrics averaged over the window since the
        # last report, then reset the window and the timer.
        if step % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(l[k] for l in logs) / len(logs)
                print('[Train]({}/{}) average {}: {}'.format(step, args.max_step, k, v))
            logs = []
            print(time.time() - start)
            start = time.time()

        # NOTE(review): this branch reuses `start`, so the next training-time
        # print above measures from the start of validation rather than the
        # last log — preserved as-is to keep output identical.
        if args.valid and step % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            start = time.time()
            test(args, model, valid_samplers, mode='Valid')
            print('test:', time.time() - start)

    if args.strict_rel_part:
        model.writeback_relation(rank, rel_parts)

    if mxprofiler:
        # Block until all async MXNet ops finish so the profile is complete.
        nd.waitall()
        profiler.set_state('stop')
        profiler.dump()
        print(profiler.dumps())
    # clear cache
    logs = []