def forward()

in archs/models.py


    def forward(self, x, sampled_g=None, t=None, return_feat=False):
        """TODO: Docstring for forward.

        :x: Input data
        :g: Gating tensor (#Task x )#num_layer x #num_mods x #num_mods
        :t: task ID
        :returns: TODO

        """

        if t is None:
            # No task IDs provided: default every sample to task 0
            t_not_set = True
            t = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        else:
            t_not_set = False
            t = t.squeeze()

        if self._start_modules is not None:
            prev_out = [mod(x) for mod in self._start_modules]
        else:
            prev_out = [x]

        if sampled_g is None:
            # Non-gated path: each layer averages its modules' outputs
            # uniformly (equivalent to fixed gating weights of 1 / num_mods)
            prev_out = sum(prev_out) / float(len(prev_out))
            for li in range(len(self._module_list)):
                prev_out = sum([
                    mod(prev_out) for mod in self._module_list[li]
                ]) / float(len(self._module_list[li]))
            features = prev_out
            if self._end_modules is not None:
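                # Single head or unspecified task: use output module 0;
                # otherwise route each sample through its task-specific head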
                if t_not_set or self.single_head:
                    prev_out = self._end_modules[0](prev_out)
                else:
                    prev_out = torch.cat([
                        self._end_modules[tid](prev_out[bi:bi + 1])
                        for bi, tid in enumerate(t)
                    ], 0)
            if return_feat:
                return prev_out, features
            return prev_out
        else:
            # Forward prop with sampled Gs
            for li in range(len(self._module_list)):
                curr_out = []
                for j in range(len(self._module_list[li])):
                    gind = j if not self._chain else 0
                    # Gating weights for this module (#Batch x C, where C is
                    # the number of modules in the previous layer)
                    module_in_wt = sampled_g[li + 1][gind]
                    # Rearrange to C x #Batch so the weights broadcast over
                    # each previous module's output
                    module_in_wt = module_in_wt.transpose(0, 1)
                    add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
                    module_in_wt = module_in_wt.view(*module_in_wt.shape,
                                                     *([1] * add_dims))
                    module_in_wt = module_in_wt.expand(
                        len(prev_out), *prev_out[0].shape)
                    module_in = sum([
                        module_in_wt[i] * prev_out[i]
                        for i in range(len(prev_out))
                    ])
                    mod = self._module_list[li][j]
                    curr_out.append(mod(module_in))
                prev_out = curr_out

            # Output modules (with sampled Gs)
            if self._end_modules is not None:
                li = self.num_layers - 1
                if t_not_set or self.single_head:
                    # Dim: #Batch x C
                    module_in_wt = sampled_g[li + 1][0]
                    # Module input weights rearranged to match inputs
                    module_in_wt = module_in_wt.transpose(0, 1)
                    add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
                    module_in_wt = module_in_wt.view(*module_in_wt.shape,
                                                     *([1] * add_dims))
                    module_in_wt = module_in_wt.expand(
                        len(prev_out), *prev_out[0].shape)
                    module_in = sum([
                        module_in_wt[i] * prev_out[i]
                        for i in range(len(prev_out))
                    ])
                    features = module_in
                    prev_out = self._end_modules[0](module_in)
                else:
                    curr_out = []
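                    # Pick the gating slice and output head for each sample
                    # according to its task ID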
                    for bi, tid in enumerate(t):
                        # Dim: #Batch x C
                        gind = tid if not self._chain else 0
                        module_in_wt = sampled_g[li + 1][gind]
                        # Module input weights rearranged to match inputs
                        module_in_wt = module_in_wt.transpose(0, 1)
                        add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
                        module_in_wt = module_in_wt.view(
                            *module_in_wt.shape, *([1] * add_dims))
                        module_in_wt = module_in_wt.expand(
                            len(prev_out), *prev_out[0].shape)
                        module_in = sum([
                            module_in_wt[i] * prev_out[i]
                            for i in range(len(prev_out))
                        ])
                        features = module_in
                        mod = self._end_modules[tid]
                        curr_out.append(mod(module_in[bi:bi + 1]))
                    prev_out = curr_out
                    prev_out = torch.cat(prev_out, 0)
            if return_feat:
                return prev_out, features
            return prev_out
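
The heart of the gated path is the per-sample mixing of the previous layer's module outputs. The standalone sketch below reproduces that broadcasting arithmetic with toy tensors; the shapes, the softmax used to generate example weights, and all variable names are illustrative assumptions rather than this repository's API.

    import torch

    # Toy sizes (illustrative only): 3 previous modules, batch of 4, feature dim 5
    num_prev_mods, batch, feat = 3, 4, 5
    prev_out = [torch.randn(batch, feat) for _ in range(num_prev_mods)]

    # Example gating slice for one target module, shaped #Batch x C,
    # standing in for sampled_g[li + 1][gind] above
    g = torch.softmax(torch.randn(batch, num_prev_mods), dim=1)

    # Rearrange to C x #Batch, pad with trailing singleton dims, and broadcast
    # over the feature dimensions, mirroring the steps in forward()
    w = g.transpose(0, 1)
    add_dims = prev_out[0].dim() + 1 - w.dim()
    w = w.view(*w.shape, *([1] * add_dims))
    w = w.expand(len(prev_out), *prev_out[0].shape)

    # Weighted sum over previous modules: one mixed input per sample
    module_in = sum(w[i] * prev_out[i] for i in range(len(prev_out)))
    assert module_in.shape == (batch, feat)

Feeding module_in through the chosen module is then a single call, exactly as in the loop over self._module_list[li] above.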