in tools/bandwidth/measure.py [0:0]
def run(network, optimizer, gpus, kv_store, image_shape, disp_batches,
        num_batches, test_results, **kwargs):
    """Measure kvstore push/pull bandwidth for the parameter arrays of a network.

    Every array of ``network`` (as listed by ``get_shapes``) is pushed from
    each GPU into the kvstore and pulled back, ``num_batches + 1`` times.
    Iteration 0 is a warm-up and is never reported.

    Parameters
    ----------
    network : str
        Module name loaded with ``import_module``; it must expose
        ``get_symbol(image_shape=..., **kwargs)``.
    optimizer : str or None
        Optimizer name for ``Optimizer.create_optimizer``. ``None`` or the
        string ``'None'`` means no optimizer is installed in the kvstore.
    gpus : str
        Comma-separated GPU ids, e.g. ``'0,1'``.
    kv_store : str
        Kvstore type passed to ``mx.kv.create``.
    image_shape : str
        Comma-separated input shape without the batch dimension.
    disp_batches : int
        Number of iterations averaged into each reported result; must be >= 1.
    num_batches : int
        Number of iterations after the warm-up iteration.
    test_results : bool
        When true, compare the pulled weights with a CPU reference using
        ``error``; otherwise the reported error is -1.
    **kwargs
        Forwarded to ``get_symbol``.

    Returns
    -------
    list
        One ``Results(iter, time, bandwidth, error)`` namedtuple per display
        interval: ``time`` is the mean seconds per iteration over the
        interval, ``bandwidth`` is in GB/sec per GPU, and ``error`` is the
        value computed on the interval's last iteration.

    Raises
    ------
    ValueError
        If ``disp_batches`` is smaller than 1 (0 would otherwise divide by
        zero only after all devices were set up; negatives give negative times).
    """
    if disp_batches < 1:
        raise ValueError('disp_batches must be >= 1, got %r' % (disp_batches,))

    # create kvstore and optimizer
    devs = [mx.gpu(int(i)) for i in gpus.split(',')]
    kv = mx.kv.create(kv_store)
    updater = None
    if optimizer is None or optimizer == 'None':
        opt = None
    else:
        opt = mx.optimizer.Optimizer.create_optimizer(optimizer)
        kv.set_optimizer(opt)
        # A separate optimizer instance (not `opt`) drives the CPU reference
        # updates that the pulled weights are checked against.
        updater = mx.optimizer.get_updater(
            mx.optimizer.Optimizer.create_optimizer(optimizer))

    # create network
    symbol = import_module(network).get_symbol(image_shape=image_shape, **kwargs)
    # a fake batch size 32, which does not affect the results
    data_shape = (32,) + tuple(int(s) for s in image_shape.split(','))
    shapes = get_shapes(symbol, data_shape)
    # Total payload in MB, counting 4 bytes per element (assumes float32).
    size = float(sum(reduce(lambda x, y: x * y, s, 1) for s in shapes)) * 4 / 1e6
    logging.info('num of arrays = %d, total size = %f MB', len(shapes), size)
    for i, s in enumerate(shapes):
        kv.init(i, mx.nd.zeros(s))

    # One random gradient per (array, device), each copied onto its device.
    grads_val = [[mx.random.uniform(-1, 1, shape=s) for _ in devs] for s in shapes]
    grads = [[g.as_in_context(d) for g, d in zip(gs, devs)] for gs in grads_val]
    weights = [[mx.nd.zeros(s, d) for d in devs] for s in shapes]

    # CPU reference: per-array sum over devices, scaled by the worker count.
    cpu_grads = [mx.nd.array(sum(g.asnumpy() for g in gs)) * kv.num_workers
                 for gs in grads_val]
    cpu_weights = [mx.nd.zeros(s) for s in shapes]

    toc = 0
    Results = namedtuple('Results', ['iter', 'time', 'bandwidth', 'error'])
    res = []
    for b in range(num_batches + 1):
        tic = time.time()
        for i, g in enumerate(grads):
            kv.push(i, g, i)
        for i, w in enumerate(weights):
            kv.pull(i, w, i)
        # Wait for every pulled weight so the timing covers the full transfer.
        for ws in weights:
            for w in ws:
                w.wait_to_read()
        toc += time.time() - tic

        if not test_results:
            err = -1
        elif opt is None:
            # No optimizer: pulled weights are compared to the summed gradients.
            err = error(weights, cpu_grads)
        else:
            # Apply the same update on CPU every iteration to track the kvstore.
            for i, (cpu_w, cpu_g) in enumerate(zip(cpu_weights, cpu_grads)):
                updater(i, cpu_g, cpu_w)
            err = error(weights, cpu_weights)

        if b % disp_batches == 0:
            toc /= disp_batches
            if b != 0:
                # 0 is used for warmup, ignored
                # 2*(n-1)/n scales the payload per device (presumably the
                # allreduce traffic factor); /1e3 converts MB/sec to GB/sec.
                r = Results(iter=b, time=toc, error=err,
                            bandwidth=size * 2 * (len(devs) - 1) / len(devs) / toc / 1e3)
                logging.info('iter %d, %f sec, %f GB/sec per gpu, error %f',
                             r.iter, r.time, r.bandwidth, r.error)
                res.append(r)
            toc = 0
    return res