def __gen_forward_code()

in tinynn/graph/tracer.py [0:0]


    def __gen_forward_code(self, inplace=False) -> str:
        """Generate the source code of the ``forward`` method for the traced ``nn.Module``.

        For every node in ``self.forward_nodes`` one assignment line is emitted, except
        for module calls that have neither input nor output tensors. After the last
        consumer of an input value (the consumer with the highest ``forward_order``), a
        ``<name> = None`` line is emitted so the generated code drops its reference to
        that value. ``ConstantNode`` and ``FloatFunctional`` inputs are never cleared.

        Side effects:
            * The namespace of each called function whose full name neither starts with
              ``torch.`` nor ``self.`` is added to ``self.used_namespaces``.
            * In-place dunder operators ``__i<op>__`` whose ``<op>`` is listed in
              ``SPECIAL_OPERATORS`` are rewritten on ``node.module`` to ``__<op>__``.
              An alias is registered so the result is also assigned back to the operand.

        Args:
            inplace: When True, submodules are referenced by ``node.original_name`` and
                the flag is forwarded to every name-resolution helper
                (``prev_node_unique_name``, ``get_tensor_name``, ``extra_expr``).
                When False, each module is referenced by the unique name of the first
                forward node that uses it.

        Returns:
            The ``forward`` method source as a single string, indented to sit inside a
            class body (4 spaces for the ``def`` line, 8 for its statements).
        """
        lines = [f"    def forward(self, {','.join([i.unique_name for i in self.input_nodes])}):"]

        # Maps each module object to the attribute name used to call it (non-inplace mode).
        mod_name_dict = {}
        for node in self.forward_nodes:
            output = node.unique_name
            param = ", ".join([node.prev_node_unique_name(i, inplace) for i in range(len(node.prev_nodes))])

            if type(node.module) is TraceFunction:
                full_name = node.full_name()
                # Remember the namespace of functions that are neither torch APIs nor
                # members of the module, e.g. ``foo.bar.baz`` records ``foo.bar``.
                if not full_name.startswith(('torch.', 'self.')) and '.' in full_name:
                    self.used_namespaces.add(full_name.rsplit('.', 1)[0])

                first_arg = None
                if node.is_class():
                    first_arg = node.prev_node_unique_name(0, inplace)

                # Rewrite ``__i<op>__`` into ``__<op>__`` when ``<op>`` is a special operator,
                # aliasing the result back to the operand that would have been mutated.
                if node.type().startswith('__i') and node.type().endswith('__'):
                    inner_op = node.module.func_type[3:-2]
                    if inner_op in SPECIAL_OPERATORS:
                        node.module.func_type = f'__{inner_op}__'
                        parts = node.module.full_name.split('.')[:-1] + [node.module.func_type]
                        node.module.full_name = '.'.join(parts)
                        if first_arg is not None:
                            alias = first_arg
                        else:
                            alias = node.module.get_tensor_name(0, inplace)
                        node.module.add_alias(alias)

                # Aliases turn the statement into a chained assignment: ``a = b = out = expr``.
                aliases = node.module.get_aliases()
                prefix = ''
                if aliases is not None:
                    prefix = ''.join(f'{x} = ' for x in aliases)
                line = f"        {prefix}{output} = {node.module.extra_expr(first=first_arg, original=inplace)}"
            else:
                if inplace:
                    mod_name = node.original_name
                else:
                    # A module called several times keeps the name of its first caller.
                    mod_name = mod_name_dict.setdefault(node.module, node.unique_name)

                # Module calls with no tensor inputs and no tensor outputs emit no code.
                if len(node.prev_tensors) == 0 and len(node.next_tensors) == 0:
                    continue

                if node.type() is nn.LSTM and len(node.prev_nodes) == 3 and len(node.prev_tensors) == 3:
                    # nn.LSTM is called as ``lstm(input, (h_0, c_0))``: pack inputs 1..2 into a tuple.
                    # ``inplace`` is forwarded here just like for every other call site.
                    first_arg = node.prev_node_unique_name(0, inplace)
                    param = ", ".join(
                        [node.prev_node_unique_name(i, inplace) for i in range(1, len(node.prev_nodes))]
                    )
                    line = f"        {output} = self.{mod_name}({first_arg}, ({param}))"
                else:
                    line = f"        {output} = self.{mod_name}({param})"

            lines.append(line)

            # Drop references to inputs whose last consumer (highest forward_order) is this node.
            # Deduplicate by unique name so each input is cleared at most once per node.
            for pn in {pn.unique_name: pn for pn in node.prev_nodes}.values():
                if node.forward_order == max(n.forward_order for n in pn.next_nodes):
                    if pn.type() not in (ConstantNode, torch.nn.quantized.FloatFunctional):
                        lines.append(f"        {pn.unique_name} = None")

        def _gen_output_node(node):
            # Output nodes with ``rev_index`` set return a list of all their inputs;
            # otherwise their first input is returned directly.
            if node.rev_index:
                names = ", ".join([node.prev_node_unique_name(i, inplace) for i in range(len(node.prev_nodes))])
                return f'[{names}]'
            return node.prev_node_unique_name(0, inplace)

        lines.append(f"        return {', '.join([_gen_output_node(i) for i in self.output_nodes])}")

        return "\n".join(lines)