def switch_bucket()

in example/reinforcement-learning/dqn/base.py [0:0]


    def switch_bucket(self, bucket_kwargs=None, data_shapes=None):
        """Make the bucket identified by ``bucket_kwargs`` the current bucket.

        Behaviour depends on whether the bucket already exists in
        ``self._buckets``:

        * Existing bucket: if ``data_shapes`` is given, the bucket's executor
          cache (keyed by ``tuple(data_shapes.items())``, so key insertion
          order matters) is consulted. A missing entry is created by calling
          ``self.exe.reshape(partial_shaping=True, allow_up_sizing=True,
          **data_shapes)``. The bucket's ``'data_shapes'`` entry is then set to
          ``data_shapes``. Without ``data_shapes`` nothing changes.
        * New bucket: a symbol is taken from ``self.sym_gen``. That is either
          the Symbol itself, or the result of calling it with the bucket key
          as keyword arguments. On first use, ``self.params`` /
          ``self.params_grad`` are allocated and ``self.params`` is passed
          through ``self.initializer``. ``self.aux_states`` is allocated on
          first use. A new executor is bound and stored under the bucket key.

        Args:
            bucket_kwargs: arguments identifying the bucket; turned into a key
                with ``get_bucket_key``. When None, ``self.curr_bucket_key`` is
                left unchanged.
            data_shapes: mapping from input name to shape tuple. Required when
                the bucket does not exist yet.

        Raises:
            AssertionError: if a new bucket is requested without
                ``data_shapes``, if a shape is not a tuple, if parameters must
                be created but ``self.initializer`` is None, or if the shapes
                inferred for a new bucket disagree with ``data_shapes`` plus
                the already-existing parameters.
        """
        if bucket_kwargs is not None:
            self.curr_bucket_key = get_bucket_key(bucket_kwargs=bucket_kwargs)

        # 1. Bucket already exists: at most a reshape of an executor is needed.
        if self.curr_bucket_key in self._buckets:
            if data_shapes is not None:
                bucket = self._buckets[self.curr_bucket_key]
                shapes_key = tuple(data_shapes.items())
                if shapes_key not in bucket['exe']:
                    # TODO Optimize the reshaping functionality!
                    # ``self.exe`` is defined outside this view; it is read
                    # BEFORE 'data_shapes' is updated below. This keeps the
                    # original order, in case it resolves through the
                    # bucket's stored shapes (NOTE(review): confirm).
                    bucket['exe'][shapes_key] = self.exe.reshape(
                        partial_shaping=True, allow_up_sizing=True, **data_shapes)
                bucket['data_shapes'] = data_shapes
            return

        # 2. Bucket does not exist: create a new symbol + executor.
        assert data_shapes is not None, "Must set data_shapes for new bucket!"
        for k, v in data_shapes.items():
            assert isinstance(v, tuple), "Data_shapes must be tuple! Find k=%s, v=%s, " \
                                         "data_shapes=%s" % (k, str(v), str(data_shapes))
        if isinstance(self.sym_gen, mx.symbol.Symbol):
            sym = self.sym_gen
        else:
            sym = self.sym_gen(**dict(self.curr_bucket_key))
        arg_names = sym.list_arguments()
        aux_names = sym.list_auxiliary_states()
        # Free parameters are all arguments not fed through data_shapes, plus
        # any input explicitly listed in ``learn_init_keys``.
        param_names = [n for n in arg_names
                       if n in self.learn_init_keys or n not in data_shapes]
        arg_shapes, _, aux_shapes = sym.infer_shape(**data_shapes)
        arg_name_shape = OrderedDict(zip(arg_names, arg_shapes))

        if self.params is None:
            # First bucket ever: allocate parameters and their gradients.
            self.params = OrderedDict((n, nd.empty(arg_name_shape[n], ctx=self.ctx))
                                      for n in param_names)
            self.params_grad = OrderedDict((n, nd.empty(arg_name_shape[n], ctx=self.ctx))
                                           for n in param_names)
            if self.params:
                assert self.initializer is not None, \
                    'We must set the initializer if we do not initialize ' \
                    'manually the free parameters of the network!!'
            for k, v in self.params.items():
                self.initializer(k, v)
        else:
            # Parameters already exist (from an earlier bucket or set manually):
            # the new symbol's argument shapes must match data + parameters.
            # A set union is used because ``dict.items()`` views cannot be
            # concatenated with ``+`` on Python 3.
            expected_shapes = set(data_shapes.items()) | \
                {(k, v.shape) for k, v in self.params.items()}
            assert set(arg_name_shape.items()) == expected_shapes, \
                'Inferred argument shapes %s do not match data_shapes + ' \
                'parameter shapes %s' % (str(dict(arg_name_shape)), str(expected_shapes))

        if self.aux_states is None:
            # NOTE(review): aux states are allocated with nd.empty and not
            # initialized here; confirm they are filled elsewhere.
            self.aux_states = OrderedDict((k, nd.empty(s, ctx=self.ctx))
                                          for k, s in zip(aux_names, aux_shapes))

        # Inputs listed in learn_init_keys are bound through self.params, so
        # only the remaining data inputs get fresh buffers.
        data_inputs = {k: mx.nd.empty(data_shapes[k], ctx=self.ctx)
                       for k in set(data_shapes) - set(self.learn_init_keys)}

        # Bind against the first executor ever created (if any) via shared_exec.
        if self._buckets:
            first_bucket = next(iter(self._buckets.values()))
            shared_exe = next(iter(first_bucket['exe'].values()))
        else:
            shared_exe = None

        self._buckets[self.curr_bucket_key] = {
            'exe': {tuple(data_shapes.items()):
                    sym.bind(ctx=self.ctx,
                             args=dict(self.params, **data_inputs),
                             args_grad=dict(self.params_grad),
                             aux_states=self.aux_states,
                             shared_exec=shared_exe)
                    },
            'data_shapes': data_shapes,
            'sym': sym
        }