def init_instance()

in automl21/accel/neural_rec.py
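
Sets up the per-instance starting state for the recurrent model: it normalizes init_x (and the optional context) to batched form, builds the LSTM hidden and cell states (zero-initialized, or predicted from the context via init_hidden_net when learn_init_hidden is set), optionally replaces or offsets the initial iterate when learn_init_iterate is set, and returns the initial iterate together with a NeuralLSTMHidden.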


    def init_instance(self, init_x, context):
        # Accept either a single instance (1-D init_x) or a batch (2-D init_x);
        # single instances are temporarily given a batch dimension.
        single = init_x.dim() == 1
        if single:
            init_x = init_x.unsqueeze(0)
            if context is not None:
                context = context.unsqueeze(0)

        assert init_x.dim() == 2
        if context is not None:
            assert context.dim() == 2
            assert init_x.size(0) == context.size(0)

        n_batch = init_x.size(0)

        # Default: zero-initialized LSTM hidden (h) and cell (c) states, stored
        # flat as (n_batch, rec_n_layers * rec_n_hidden).
        h = torch.zeros(n_batch, self.rec_n_layers * self.rec_n_hidden,
                        dtype=init_x.dtype, device=init_x.device)
        c = torch.zeros(n_batch, self.rec_n_layers * self.rec_n_hidden,
                        dtype=init_x.dtype, device=init_x.device)
        if self.learn_init_hidden or self.learn_init_iterate:
            # The learned initialization is predicted from the context by a
            # single network, so a context is required here.
            assert context is not None
            z = self.init_hidden_net(context)
            if self.learn_init_hidden:
                if self.learn_init_iterate:
                    # The section sizes must be given explicitly because
                    # rec_n_hidden can be smaller than the iterate size.
                    hidden_layer_product = self.rec_n_layers * self.rec_n_hidden
                    sections = [hidden_layer_product, hidden_layer_product,
                                z.size(-1) - 2 * hidden_layer_product]
                    h, c, new_init_x = z.split(sections, dim=-1)
                else:
                    h, c = z.split(self.rec_n_layers * self.rec_n_hidden, dim=-1)
            else:
                new_init_x = z

        if self.learn_init_iterate:
            # Apply the predicted initial iterate, either as a delta on the
            # provided init_x or as a full replacement.
            if self.learn_init_iterate_delta:
                init_x = init_x + new_init_x
            else:
                init_x = new_init_x

        if single:
            # Undo the batch dimension added above for a single instance.
            init_x = init_x.squeeze(0)

        # Recover the per-layer hidden and cell states from the flat vectors.
        h = self._extract_layered_hidden_state(h)
        c = self._extract_layered_hidden_state(c)

        return init_x, NeuralLSTMHidden(h, c)
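
For reference, a minimal standalone sketch of the section split performed above when both learn_init_hidden and learn_init_iterate are enabled. The sizes (rec_n_layers, rec_n_hidden, iterate size, batch size) and the random stand-in for init_hidden_net(context) are hypothetical; only the torch.split bookkeeping mirrors the method.

    import torch

    # Hypothetical sizes; in the method these come from the module's attributes.
    rec_n_layers, rec_n_hidden, n_iterate, n_batch = 2, 8, 10, 4
    hidden_layer_product = rec_n_layers * rec_n_hidden

    # Stand-in for z = self.init_hidden_net(context): a joint vector holding the
    # flattened h, the flattened c, and the predicted initial iterate.
    z = torch.randn(n_batch, 2 * hidden_layer_product + n_iterate)

    sections = [hidden_layer_product, hidden_layer_product,
                z.size(-1) - 2 * hidden_layer_product]
    h, c, new_init_x = z.split(sections, dim=-1)

    assert h.shape == (n_batch, hidden_layer_product)
    assert c.shape == (n_batch, hidden_layer_product)
    assert new_init_x.shape == (n_batch, n_iterate)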