# Source: python/dglke/train_pytorch.py — dist_train_test
def dist_train_test(args, model, train_sampler, entity_pb, relation_pb, l2g, rank=0, rel_parts=None, cross_rels=None, barrier=None):
if args.num_proc > 1:
th.set_num_threads(args.num_thread)
client = connect_to_kvstore(args, entity_pb, relation_pb, l2g)
client.barrier()
train_time_start = time.time()
train(args, model, train_sampler, None, rank, rel_parts, cross_rels, barrier, client)
total_train_time = time.time() - train_time_start
client.barrier()
# Release the memory of local model
model = None
if (client.get_machine_id() == 0) and (client.get_id() % args.num_client == 0): # pull full model from kvstore
# Pull model from kvstore
args.num_test_proc = args.num_client
dataset_full = dataset = get_dataset(args.data_path,
args.dataset,
args.format,
args.delimiter,
args.data_files)
args.train = False
args.valid = False
args.test = True
args.strict_rel_part = False
args.soft_rel_part = False
args.async_update = False
args.eval_filter = not args.no_eval_filter
if args.neg_deg_sample_eval:
assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."
print('Full data n_entities: ' + str(dataset_full.n_entities))
print("Full data n_relations: " + str(dataset_full.n_relations))
eval_dataset = EvalDataset(dataset_full, args)
if args.neg_sample_size_eval < 0:
args.neg_sample_size_eval = dataset_full.n_entities
args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)
model_test = load_model(args, dataset_full.n_entities, dataset_full.n_relations)
print("Pull relation_emb ...")
relation_id = F.arange(0, model_test.n_relations)
relation_data = client.pull(name='relation_emb', id_tensor=relation_id)
model_test.relation_emb.emb[relation_id] = relation_data
print("Pull entity_emb ... ")
# split model into 100 small parts
start = 0
percent = 0
entity_id = F.arange(0, model_test.n_entities)
count = int(model_test.n_entities / 100)
end = start + count
while True:
print("Pull model from kvstore: %d / 100 ..." % percent)
if end >= model_test.n_entities:
end = -1
tmp_id = entity_id[start:end]
entity_data = client.pull(name='entity_emb', id_tensor=tmp_id)
model_test.entity_emb.emb[tmp_id] = entity_data
if end == -1:
break
start = end
end += count
percent += 1
if not args.no_save_emb:
print("save model to %s ..." % args.save_path)
save_model(args, model_test)
print('Total train time {:.3f} seconds'.format(total_train_time))
if args.test:
model_test.share_memory()
start = time.time()
test_sampler_tails = []
test_sampler_heads = []
for i in range(args.num_test_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-head',
num_workers=args.num_workers,
rank=i, ranks=args.num_test_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
args.neg_sample_size_eval,
args.neg_sample_size_eval,
args.eval_filter,
mode='chunk-tail',
num_workers=args.num_workers,
rank=i, ranks=args.num_test_proc)
test_sampler_heads.append(test_sampler_head)
test_sampler_tails.append(test_sampler_tail)
eval_dataset = None
dataset_full = None
print("Run test, test processes: %d" % args.num_test_proc)
queue = mp.Queue(args.num_test_proc)
procs = []
for i in range(args.num_test_proc):
proc = mp.Process(target=test_mp, args=(args,
model_test,
[test_sampler_heads[i], test_sampler_tails[i]],
i,
'Test',
queue))
procs.append(proc)
proc.start()
total_metrics = {}
metrics = {}
logs = []
for i in range(args.num_test_proc):
log = queue.get()
logs = logs + log
for metric in logs[0].keys():
metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
print("-------------- Test result --------------")
for k, v in metrics.items():
print('Test average {} : {}'.format(k, v))
print("-----------------------------------------")
for proc in procs:
proc.join()
print('testing takes {:.3f} seconds'.format(time.time() - start))
client.shut_down() # shut down kvserver