# python/mxnet/module/executor_group.py
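# Relies on module-level imports in this file, roughly:
#     import logging
#     from ..io import DataDesc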
def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
             for_training, inputs_need_grad, shared_group=None, logger=logging,
             fixed_param_names=None, grad_req='write', state_names=None):
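    """Set up a group of executors for data-parallel training/inference:
    record the symbol and contexts, normalize shape descriptors, resolve
    per-argument gradient requirements, and bind one executor per context.
    """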
    self.param_names = param_names
    self.arg_names = symbol.list_arguments()
    self.aux_names = symbol.list_auxiliary_states()

    self.symbol = symbol
    self.contexts = contexts
    self.workload = workload

    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad

    self.logger = logger
    # In the future we should have a better way to profile memory per device (haibin)
    self._total_exec_bytes = 0
    self.fixed_param_names = fixed_param_names
    if self.fixed_param_names is None:
        self.fixed_param_names = []

    self.state_names = state_names
    if self.state_names is None:
        self.state_names = []

    if not for_training:
        grad_req = 'null'
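    # Normalize shape specifications into DataDesc descriptors (name, shape,
    # dtype, layout) so downstream code can rely on a uniform interface.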
    data_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in data_shapes]
    if label_shapes is not None:
        label_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in label_shapes]

    data_names = [x.name for x in data_shapes]
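    # Expand grad_req into a per-argument dict: parameters follow the requested
    # req unless frozen via fixed_param_names, data inputs receive gradients
    # only when inputs_need_grad is set, and all other arguments get 'null'.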
    if isinstance(grad_req, str):
        self.grad_req = {}
        for k in self.arg_names:
            if k in self.param_names:
                self.grad_req[k] = 'null' if k in self.fixed_param_names else grad_req
            elif k in data_names:
                self.grad_req[k] = grad_req if self.inputs_need_grad else 'null'
            else:
                self.grad_req[k] = 'null'
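    # A list/tuple must supply one req per argument, in the order returned by
    # symbol.list_arguments().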
    elif isinstance(grad_req, (list, tuple)):
        assert len(grad_req) == len(self.arg_names)
        self.grad_req = dict(zip(self.arg_names, grad_req))
    elif isinstance(grad_req, dict):
        self.grad_req = {}
        for k in self.arg_names:
            if k in self.param_names:
                self.grad_req[k] = 'null' if k in self.fixed_param_names else 'write'
            elif k in data_names:
                self.grad_req[k] = 'write' if self.inputs_need_grad else 'null'
            else:
                self.grad_req[k] = 'null'
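        # Entries from the user-supplied dict override the 'write'/'null'
        # defaults filled in above.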
        self.grad_req.update(grad_req)
    else:
        raise ValueError("grad_req must be one of str, list, tuple, or dict.")
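    # Reuse the per-context input buffers of an existing group (e.g. when
    # bucketing binds several executor groups) so executors can share memory;
    # otherwise start with one empty dict per context.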
    if shared_group is not None:
        self.shared_data_arrays = shared_group.shared_data_arrays
    else:
        self.shared_data_arrays = [{} for _ in contexts]
    # initialize some instance variables
    self.batch_size = None
    self.slices = None
    self.execs = []
    self._default_execs = None
    self.data_arrays = None
    self.label_arrays = None
    self.param_arrays = None
    self.state_arrays = None
    self.grad_arrays = None
    self.aux_arrays = None
    self.input_grad_arrays = None
    self.data_shapes = None
    self.label_shapes = None
    self.data_names = None
    self.label_names = None
    self.data_layouts = None
    self.label_layouts = None
    self.output_names = self.symbol.list_outputs()
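    # Read each output's '__layout__' attribute (e.g. 'NCHW') to locate its
    # batch axis; get_batch_axis returns the index of the 'N' dimension.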
    self.output_layouts = [DataDesc.get_batch_axis(self.symbol[name].attr('__layout__'))
                           for name in self.output_names]
    self.num_outputs = len(self.output_names)
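    # bind_exec decides how each batch is sliced across contexts (proportional
    # to workload) and binds one executor per context.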
    self.bind_exec(data_shapes, label_shapes, shared_group)
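# --- Illustrative usage sketch (not part of the file). Assumes the enclosing
# class is DataParallelExecutorGroup and that `sym` and `params` come from an
# existing network; the contexts and shapes below are hypothetical.
#
#     group = DataParallelExecutorGroup(
#         sym, [mx.cpu(0), mx.cpu(1)], workload=[1, 1],
#         data_shapes=[DataDesc('data', (64, 100))],
#         label_shapes=[DataDesc('softmax_label', (64,))],
#         param_names=params, for_training=True, inputs_need_grad=False)
#     # With equal workloads, each context gets a 32-sample slice of the
#     # 64-sample batch.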