in example/reinforcement-learning/dqn/base.py [0:0]
def switch_bucket(self, bucket_kwargs=None, data_shapes=None):
    """Make a bucket (and an executor for the given data shapes) current.

    If ``bucket_kwargs`` is given, the current bucket key is recomputed from it
    with ``get_bucket_key``; otherwise the current bucket key is kept.

    * If the bucket already exists and ``data_shapes`` is given, an executor
      for those shapes is looked up in the bucket's executor cache.
      - If none is cached, one is created by reshaping ``self.exe``.
      - The bucket's active ``data_shapes`` is then set.
      - With ``data_shapes=None`` nothing besides the bucket key changes.
    * If the bucket does not exist yet, a symbol is obtained from ``sym_gen``.
      - Parameter, gradient and auxiliary-state arrays are allocated on first
        use, and parameters are initialized with ``self.initializer``.
      - A new executor is bound. It shares memory with an executor of the
        first bucket in ``self._buckets`` when one exists.

    Parameters
    ----------
    bucket_kwargs : dict or None
        Keyword arguments describing the bucket; converted to a bucket key by
        ``get_bucket_key``.
    data_shapes : dict or None
        Mapping from input name to its shape, which must be a ``tuple``.
        Required when the bucket does not exist yet. Its ordered items form
        the executor cache key within a bucket.

    Raises
    ------
    AssertionError
        Raised in any of these cases:
        - ``data_shapes`` is missing for a new bucket.
        - A shape is not a tuple.
        - Parameters need initialization but no initializer is set.
        - The inferred argument shapes disagree with the already-allocated
          parameters.
    """
    if bucket_kwargs is not None:
        self.curr_bucket_key = get_bucket_key(bucket_kwargs=bucket_kwargs)

    # 1. The bucket already exists: reuse (or reshape into) one of its executors.
    if self.curr_bucket_key in self._buckets:
        if data_shapes is not None:
            bucket = self._buckets[self.curr_bucket_key]
            # Executors are cached per bucket, keyed by the ordered shape items.
            shapes_key = tuple(data_shapes.items())
            if shapes_key not in bucket['exe']:
                # TODO Optimize the reshaping functionality!
                # `self.exe` is read BEFORE 'data_shapes' is updated below.
                # NOTE(review): presumably it returns the executor of the
                # previously active shapes - confirm against the property.
                bucket['exe'][shapes_key] = \
                    self.exe.reshape(partial_shaping=True, allow_up_sizing=True, **data_shapes)
            bucket['data_shapes'] = data_shapes
        return

    # 2. The bucket does not exist yet: create a new symbol + executor.
    assert data_shapes is not None, "Must set data_shapes for new bucket!"
    if isinstance(self.sym_gen, mx.symbol.Symbol):
        sym = self.sym_gen
    else:
        # The bucket key is turned back into keyword arguments for sym_gen.
        # This assumes the key is a sequence of (name, value) pairs.
        sym = self.sym_gen(**dict(self.curr_bucket_key))
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    # Arguments not fed through data_shapes are treated as parameters.
    # Inputs listed in learn_init_keys count as parameters even when they
    # also appear in data_shapes.
    param_names = [n for n in arg_names
                   if n in self.learn_init_keys or (n not in data_shapes.keys())]
    for k, v in data_shapes.items():
        assert isinstance(v, tuple), "Data_shapes must be tuple! Find k=%s, v=%s, " \
                                     "data_shapes=%s" % (k, str(v), str(data_shapes))
    arg_shapes, _, aux_shapes = sym.infer_shape(**data_shapes)
    arg_name_shape = OrderedDict([(k, s) for k, s in zip(arg_names, arg_shapes)])
    if self.params is None:
        # First bucket: allocate parameters and gradient buffers, then initialize.
        self.params = OrderedDict([(n, nd.empty(arg_name_shape[n], ctx=self.ctx))
                                   for n in param_names])
        self.params_grad = OrderedDict([(n, nd.empty(arg_name_shape[n], ctx=self.ctx))
                                        for n in param_names])
        if len(self.params) > 0:
            assert self.initializer is not None, \
                'We must set the initializer if we do not initialize ' \
                'manually the free parameters of the network!!'
            for k, v in self.params.items():
                self.initializer(k, v)
    else:
        # Later buckets reuse the existing parameters, so every argument shape
        # inferred here must match either a data shape or a parameter shape.
        # dict.items() is a view in Python 3 and cannot be concatenated with a
        # list using `+`, hence the explicit list().
        assert set(arg_name_shape.items()) == \
            set(list(data_shapes.items()) + [(k, v.shape) for k, v in self.params.items()])
    if self.aux_states is None:
        self.aux_states = OrderedDict([(k, nd.empty(s, ctx=self.ctx))
                                       for k, s in zip(aux_names, aux_shapes)])
    # Fresh arrays for pure data inputs; learn_init_keys inputs live in params.
    data_inputs = {k: mx.nd.empty(data_shapes[k], ctx=self.ctx)
                   for k in set(data_shapes.keys()) - set(self.learn_init_keys)}
    if len(self._buckets) > 0:
        # Share memory with an executor of the first bucket in self._buckets.
        shared_exe = list(list(self._buckets.values())[0]['exe'].values())[0]
    else:
        shared_exe = None
    self._buckets[self.curr_bucket_key] = {
        'exe': {tuple(data_shapes.items()):
                sym.bind(ctx=self.ctx,
                         args=dict(self.params, **data_inputs),
                         args_grad=dict(self.params_grad.items()),
                         aux_states=self.aux_states,
                         shared_exec=shared_exe)
                },
        'data_shapes': data_shapes,
        'sym': sym
    }