in python/dglke/train_pytorch.py
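# Note: this excerpt relies on module-level imports defined elsewhere in the file
# (import logging, import time, import torch as th) and on the test() evaluation
# routine from the same module.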
def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None, client=None):
    logs = []
    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1
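    # Optionally enable asynchronous embedding updates and, when relation
    # embeddings are partitioned (strict/soft), stage them on this worker's device.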
    if args.async_update:
        model.create_async_update()
    if args.strict_rel_part or args.soft_rel_part:
        model.prepare_relation(th.device('cuda:' + str(gpu_id)))
    if args.soft_rel_part:
        model.prepare_cross_rels(cross_rels)
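    # Track wall-clock time spent in each phase of the training loop.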
    train_start = start = time.time()
    sample_time = 0
    update_time = 0
    forward_time = 0
    backward_time = 0
    for step in range(0, args.max_step):
        start1 = time.time()
        pos_g, neg_g = next(train_sampler)
        sample_time += time.time() - start1
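        # Distributed mode: pull the embeddings touched by this mini-batch
        # before computing the loss.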
        if client is not None:
            model.pull_model(client, pos_g, neg_g)

        start1 = time.time()
        loss, log = model.forward(pos_g, neg_g, gpu_id)
        forward_time += time.time() - start1

        start1 = time.time()
        loss.backward()
        backward_time += time.time() - start1
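        # Apply the update: in distributed mode push gradients back to the
        # remote store, otherwise update the embeddings locally.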
        start1 = time.time()
        if client is not None:
            model.push_gradient(client)
        else:
            model.update(gpu_id)
        update_time += time.time() - start1
        logs.append(log)

        # force synchronize embedding across processes every X steps
        if args.force_sync_interval > 0 and \
                (step + 1) % args.force_sync_interval == 0:
            barrier.wait()
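        # Periodically report averaged training metrics and per-phase timings
        # (in distributed mode, only machine 0 prints).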
        if (step + 1) % args.log_interval == 0:
            if client is None or client.get_machine_id() == 0:
                for k in logs[0].keys():
                    v = sum(l[k] for l in logs) / len(logs)
                    print('[proc {}][Train]({}/{}) average {}: {}'.format(rank, (step + 1), args.max_step, k, v))
                logs = []
                print('[proc {}][Train] {} steps take {:.3f} seconds'.format(rank, args.log_interval,
                                                                             time.time() - start))
                print('[proc {}]sample: {:.3f}, forward: {:.3f}, backward: {:.3f}, update: {:.3f}'.format(
                    rank, sample_time, forward_time, backward_time, update_time))
                sample_time = 0
                update_time = 0
                forward_time = 0
                backward_time = 0
                start = time.time()
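        # Periodically evaluate on the validation set, synchronizing relation
        # embeddings and processes around the evaluation.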
        if args.valid and (step + 1) % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            valid_start = time.time()
            if args.strict_rel_part or args.soft_rel_part:
                model.writeback_relation(rank, rel_parts)
            # forced sync for validation
            if barrier is not None:
                barrier.wait()
            test(args, model, valid_samplers, rank, mode='Valid')
            print('[proc {}]validation takes {:.3f} seconds'.format(rank, time.time() - valid_start))
            if args.soft_rel_part:
                model.prepare_cross_rels(cross_rels)
            if barrier is not None:
                barrier.wait()
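    # Training finished: report total time, flush pending async updates, and
    # write partitioned relation embeddings back to the global embedding.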
    print('proc {} takes {:.3f} seconds'.format(rank, time.time() - train_start))
    if args.async_update:
        model.finish_async_update()
    if args.strict_rel_part or args.soft_rel_part:
        model.writeback_relation(rank, rel_parts)