in python/dglke/kvclient.py [0:0]
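# NOTE: reconstructed import block -- the exact module paths below are
# assumptions inferred from the names used in start_client and the usual
# dglke package layout, not confirmed from the source; WAIT_TIME and the
# get_machine_count / get_local_machine_id / get_long_tail_partition helpers
# are presumably defined elsewhere in this file.
import time

import dgl
import dgl.backend as F
import torch.multiprocessing as mp

from .dataloader import TrainDataset, NewBidirectionalOneShotIterator
from .dataloader import ConstructGraph, get_partition_dataset
from .train_pytorch import load_model, dist_train_test
from .utils import get_compatible_batch_size
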
def start_client(args):
"""Start kvclient for training
"""
init_time_start = time.time()
    time.sleep(WAIT_TIME) # wait for the launch script to bring the servers up
    # We don't support GPU-based distributed training yet, so force CPU mode
    args.gpu = [-1]
    args.mix_cpu_gpu = False
    args.async_update = False
    # We don't use relation partitioning in distributed training yet
    args.rel_part = False
    args.strict_rel_part = False
    args.soft_rel_part = False
    # We don't support validation in distributed training
    args.valid = False
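    # Read the shared ip_config to find out how many machines participate
    # and which partition this machine owns.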
total_machine = get_machine_count(args.ip_config)
server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)
machine_id = get_local_machine_id(server_namebook)
dataset, entity_partition_book, local2global = get_partition_dataset(
args.data_path,
args.dataset,
machine_id)
n_entities = dataset.n_entities
n_relations = dataset.n_relations
    print('Partition %d n_entities: %d' % (machine_id, n_entities))
    print('Partition %d n_relations: %d' % (machine_id, n_relations))
entity_partition_book = F.tensor(entity_partition_book)
relation_partition_book = get_long_tail_partition(dataset.n_relations, total_machine)
relation_partition_book = F.tensor(relation_partition_book)
local2global = F.tensor(local2global)
relation_partition_book.share_memory_()
entity_partition_book.share_memory_()
local2global.share_memory_()
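    # Build the local training graph and dataset view; one sampler pair will
    # be created for each of the num_client trainer processes.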
g = ConstructGraph(dataset, args)
train_data = TrainDataset(g, dataset, args, ranks=args.num_client)
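    # A negative eval sample size means corrupting against all entities;
    # the batch sizes are then rounded so they stay compatible with the
    # negative sample sizes.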
if args.neg_sample_size_eval < 0:
args.neg_sample_size_eval = dataset.n_entities
args.batch_size = get_compatible_batch_size(args.batch_size, args.neg_sample_size)
args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)
    args.num_workers = 8 # hard-code the sampler worker count for distributed training
train_samplers = []
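    # One pair of samplers per trainer process: one corrupts head entities,
    # the other corrupts tail entities; the bidirectional iterator below
    # alternates between them.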
for i in range(args.num_client):
train_sampler_head = train_data.create_sampler(args.batch_size,
args.neg_sample_size,
args.neg_sample_size,
mode='head',
num_workers=args.num_workers,
shuffle=True,
exclude_positive=False,
rank=i)
train_sampler_tail = train_data.create_sampler(args.batch_size,
args.neg_sample_size,
args.neg_sample_size,
mode='tail',
num_workers=args.num_workers,
shuffle=True,
exclude_positive=False,
rank=i)
train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
args.neg_sample_size, args.neg_sample_size,
True, n_entities))
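    # Drop the raw dataset reference; the graph and samplers hold what they
    # need, and this frees memory before the trainer processes start.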
dataset = None
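    # Load the model once and place its parameters in shared memory so all
    # trainer processes update a single copy.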
model = load_model(args, n_entities, n_relations)
model.share_memory()
    print('Total initialization time {:.3f} seconds'.format(time.time() - init_time_start))
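    # Note: strict_rel_part and soft_rel_part were forced to False above, so
    # rel_parts and cross_rels are always None in distributed mode.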
rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
cross_rels = train_data.cross_rels if args.soft_rel_part else None
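    # Spawn one trainer process per client; each runs dist_train_test with
    # its own sampler and rank.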
procs = []
for i in range(args.num_client):
proc = mp.Process(target=dist_train_test, args=(args,
model,
train_samplers[i],
entity_partition_book,
relation_partition_book,
local2global,
i,
rel_parts,
cross_rels))
procs.append(proc)
proc.start()
for proc in procs:
proc.join()
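
# --- Hypothetical usage sketch (not part of the original file) ---
# start_client is normally invoked by dglke's distributed launch path after
# the command-line arguments have been parsed. Assuming an ArgParser that
# defines the flags read above (ip_config, data_path, dataset, num_client,
# batch_size, neg_sample_size, ...), a minimal driver could look like:
#
#     if __name__ == '__main__':
#         args = ArgParser().parse_args()
#         start_client(args)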