in lib/datasets/data_input_helper.py [0:0]
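# Imports this function relies on (reconstructed from usage; the non-standard
# module paths below are assumptions about this codebase, not verified):
import collections
import ctypes
import logging

from multiprocessing import Pool
from multiprocessing.sharedctypes import RawArray

from core.config import config as cfg  # assumed config location

logger = logging.getLogger(__name__)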
def _create_execution_context(execution_context, init_pool, worker_ids,
                              expected_data_size, num_processes, batch_size):
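    # NOTE: expected_data_size and batch_size are not referenced in this
    # function; per-scale batch sizes are read from cfg instead.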
    logger.info('CREATING EXECUTION CONTEXT')
    if execution_context is None:
        pools = {}
        shared_data_lists = {}
    else:
        # Reuse the pools and shared buffers from a previous call instead of
        # re-allocating shared memory.
        pools = execution_context.pools
        shared_data_lists = execution_context.shared_data_lists
    logger.info('POOLS: {}'.format(pools))
    logger.info('SHARED DATA LISTS: {}'.format(len(shared_data_lists)))
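    # Allocate one set of shared buffers per distinct crop size; train and
    # test only share a set when the two crop sizes match.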
    if cfg.TRAIN.CROP_SIZE == cfg.TEST.CROP_SIZE:
        scales = [cfg.TRAIN.CROP_SIZE]
    else:
        scales = [cfg.TRAIN.CROP_SIZE, cfg.TEST.CROP_SIZE]
    for worker_id in worker_ids:
        # For each worker_id, create one list of shared buffers per scale.
        shared_data_list = [[] for _ in range(len(scales))]
        shared_data_lists[worker_id] = shared_data_list
        logger.info('worker_id: {} list: {}'.format(
            worker_id, len(shared_data_lists)))
        logger.info('worker_id: {} list keys: {}'.format(
            worker_id, shared_data_lists.keys()))
        # For each scale, pre-allocate one shared float buffer per sample in
        # the batch; parallel worker processes decode clips directly into
        # these buffers.
        for i, scale in enumerate(scales):
            if scale == cfg.TRAIN.CROP_SIZE:
                bz = cfg.TRAIN.BATCH_SIZE
            else:
                bz = cfg.TEST.BATCH_SIZE
            for _ in range(bz):
                # Clip buffer: crop_size^2 pixels x 3 channels x clip length.
                shared_arr = RawArray(
                    ctypes.c_float,
                    scale ** 2 * 3 * cfg.TRAIN.VIDEO_LENGTH)
                one_data_list = [shared_arr]
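                # For the AVA dataset, each sample additionally carries shared
                # buffers for boxes, original boxes, and frame metadata.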
                if cfg.DATASET == 'ava':
                    shared_arr_box = RawArray(
                        ctypes.c_float,
                        cfg.LFB.NUM_LFB_FEAT * 4)
                    shared_arr_original_boxes = RawArray(
                        ctypes.c_float,
                        cfg.LFB.NUM_LFB_FEAT * 4)
                    # Metadata: (height, width).
                    shared_arr_metadata = RawArray(
                        ctypes.c_float, 2)
                    one_data_list += [
                        shared_arr_box,
                        shared_arr_original_boxes,
                        shared_arr_metadata]
                shared_data_list[i].append(one_data_list)
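        # Start this worker's process pool; init_pool runs once in each child
        # process and stores a handle to the shared buffers there.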
        pools[worker_id] = Pool(
            processes=num_processes,
            initializer=init_pool,
            initargs=(shared_data_list,))
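    # Bundle the pools and shared lists into a single context object that the
    # caller passes back into subsequent calls.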
    ExecutionContext = collections.namedtuple(
        'ExecutionContext',
        ['pools', 'shared_data_lists'])
    context = ExecutionContext(
        pools=pools, shared_data_lists=shared_data_lists)
    logger.info('CREATED POOL: {}'.format(pools))
    logger.info('CREATED LISTS: {}'.format(len(shared_data_lists)))
    logger.info('POOL keys: {}'.format(pools.keys()))
    logger.info('LIST keys: {}'.format(shared_data_lists.keys()))
    return context
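
Below is a minimal usage sketch, assuming a hypothetical init_pool callback and
_fill_clip task (illustrative names, not this codebase's actual definitions).
Each child process stores the shared buffers in a module-level global inside
init_pool and writes into them in place, so the parent can read results back
through shared_data_lists without pickling array data.

# Usage sketch with assumed helpers; not the actual pipeline code.
import numpy as np

_shared_data_list = None  # per-process handle to the shared buffers


def init_pool(shared_data_list):
    # Runs once inside every child process of the Pool.
    global _shared_data_list
    _shared_data_list = shared_data_list


def _fill_clip(args):
    scale_idx, batch_idx, value = args
    shared_arr = _shared_data_list[scale_idx][batch_idx][0]
    # Wrap the RawArray in a NumPy view and write into it in place; the
    # parent sees the result through the same shared memory.
    np.frombuffer(shared_arr, dtype=np.float32)[:] = value
    return batch_idx


context = _create_execution_context(
    None, init_pool, worker_ids=[0], expected_data_size=None,
    num_processes=4, batch_size=cfg.TRAIN.BATCH_SIZE)
context.pools[0].map(
    _fill_clip, [(0, b, float(b)) for b in range(cfg.TRAIN.BATCH_SIZE)])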