benchmarks/api/bench_dist_neighbor_loader.py (137 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import argparse import os.path as osp import time import torch import torch.distributed as dist import graphlearn_torch as glt if __name__ == "__main__": print('*** DistNeighborLoader Benchmarks ***') parser = argparse.ArgumentParser('DistRandomSampler benchmarks.') parser.add_argument('--dataset', type=str, default='products', help='name of the dataset for benchmark') parser.add_argument('--num_nodes', type=int, default=2, help='number of worker nodes') parser.add_argument('--node_rank', type=int, default=0, help='worker node rank') parser.add_argument('--sample_nprocs', type=int, default=2, help='number of processes for sampling') parser.add_argument('--epochs', type=int, default=1, help='repeat epochs for sampling') parser.add_argument('--batch_size', type=int, default=2048, help='batch size for sampling') parser.add_argument('--shuffle', action="store_true", help='whether to shuffle input seeds at each epoch') parser.add_argument('--with_edge', action="store_true", help='whether to sample with edge ids') parser.add_argument('--collect_features', action='store_true', help='whether to collect features for sampled results') parser.add_argument('--worker_concurrency', type=int, default=4, help='concurrency for each sampling worker') parser.add_argument('--channel_size', type=str, default='4GB', help='memory used for shared-memory channel') parser.add_argument('--master_addr', type=str, default='localhost', help='master ip address for synchronization across all training nodes') parser.add_argument('--master_port', type=str, default='11234', help='port for synchronization across all training nodes') args = parser.parse_args() dataset_name = args.dataset num_nodes = args.num_nodes node_rank = args.node_rank sampling_nprocs = args.sample_nprocs device_count = torch.cuda.device_count() epochs = args.epochs batch_size = args.batch_size shuffle = args.shuffle with_edge = args.with_edge collect_features = args.collect_features worker_concurrency = args.worker_concurrency channel_size = args.channel_size master_addr = str(args.master_addr) sampling_master_port = int(args.master_port) torch_pg_master_port = sampling_master_port + 1 print('- dataset: {}'.format(dataset_name)) print('- total nodes: {}'.format(num_nodes)) print('- node rank: {}'.format(node_rank)) print('- device count: {}'.format(device_count)) print('- sampling nprocs per training proc: {}'.format(sampling_nprocs)) print('- epochs: {}'.format(epochs)) print('- batch size: {}'.format(batch_size)) print('- shuffle: {}'.format(shuffle)) print('- sample with edge id: {}'.format(with_edge)) print('- collect remote features: {}'.format(collect_features)) print('- sampling concurrency per worker: {}'.format(worker_concurrency)) print('- channel size: {}'.format(channel_size)) print('- master addr: {}'.format(master_addr)) print('- sampling master port: {}'.format(sampling_master_port)) print('** Loading dist dataset ...') root = osp.join(osp.dirname(osp.realpath(__file__)), '..', '..', 'data', dataset_name) dataset = glt.distributed.DistDataset() dataset.load( root_dir=osp.join(root, 'ogbn-'+dataset_name+'-partitions'), partition_idx=node_rank, graph_mode='ZERO_COPY', device_group_list=[glt.data.DeviceGroup(0, [0]), glt.data.DeviceGroup(1, [1])], # 2 GPUs device=0 ) print('** Loading input seeds ...') seeds_dir = osp.join(root, 'ogbn-'+dataset_name+'-test-partitions') seeds_data = torch.load(osp.join(seeds_dir, f'partition{node_rank}.pt')) print('** Initializing worker group context ...') glt.distributed.init_worker_group( world_size=num_nodes, rank=node_rank, group_name='dist-neighbor-loader-benchmarks' ) dist_context = glt.distributed.get_context() print('** Initializing process group') dist.init_process_group('gloo', rank=dist_context.rank, world_size=dist_context.world_size, init_method='tcp://{}:{}'.format(master_addr, torch_pg_master_port)) print('** Launching dist neighbor loader ...') dist_loader = glt.distributed.DistNeighborLoader( data=dataset, num_neighbors=[15, 10, 5], input_nodes=seeds_data, batch_size=batch_size, shuffle=shuffle, drop_last=True, with_edge=with_edge, collect_features=collect_features, to_device=torch.device('cuda:0'), worker_options=glt.distributed.MpDistSamplingWorkerOptions( num_workers=sampling_nprocs, worker_devices=[torch.device('cuda', i % device_count) for i in range(sampling_nprocs)], worker_concurrency=worker_concurrency, master_addr=master_addr, master_port=sampling_master_port, channel_size=channel_size, pin_memory=True ) ) print('** Benchmarking ...') f = open('benchmark.txt', 'a+') for epoch in range(epochs): num_sampled_nodes = 0 num_sampled_edges = 0 num_collected_features = 0 start = time.time() for i, batch in enumerate(dist_loader): if i % 100 == 0: f.write('Epoch {}, Batch {}\n'.format(epoch, i)) num_sampled_nodes += batch.node.numel() num_sampled_edges += batch.edge_index.size(1) if batch.x is not None: num_collected_features += batch.x.size(0) torch.cuda.synchronize() total_time = time.time() - start f.write('** Epoch {} **\n'.format(epoch)) f.write('- total time: {}s\n'.format(total_time)) f.write('- total sampled nodes: {}\n'.format(num_sampled_nodes)) f.write('- sampling nodes per sec: {} M\n'.format((num_sampled_nodes / total_time) / 1000000)) f.write('- total sampled edges: {}\n'.format(num_sampled_edges)) f.write('- sampling edges per sec: {} M\n'.format((num_sampled_edges / total_time) / 1000000)) f.write('- total collected features: {}\n'.format(num_collected_features)) f.write('- collecting features per sec: {} M\n'.format((num_collected_features / total_time) / 1000000)) dist.barrier() time.sleep(1) print('** Exit ...')