in tinynn/graph/tracer.py
def __gen_forward_code(self, inplace=False) -> str:
"""Generates the code for the forward function for a `nn.Module`"""
lines = [f" def forward(self, {','.join([i.unique_name for i in self.input_nodes])}):"]
mod_name_dict = {}
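    # Walk the traced graph in execution order, emitting one statement per node.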
    for node in self.forward_nodes:
        output = ", ".join([node.unique_name])
        param = ", ".join([node.prev_node_unique_name(i, inplace) for i in range(len(node.prev_nodes))])

        if type(node.module) is TraceFunction:
            full_name = node.full_name()
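            # Collect the namespaces of non-torch, non-self callables
            # (e.g. `torchvision`) that the generated code will reference.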
            if not full_name.startswith('torch.') and not full_name.startswith('self.') and '.' in full_name:
                ns = '.'.join(full_name.split('.')[:-1])
                self.used_namespaces.add(ns)

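            # For method-style calls (e.g. `x.view(...)`), the first argument is
            # the object the method is invoked on.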
            first_arg = None
            if node.is_class():
                first_arg = node.prev_node_unique_name(0, inplace)

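            # Rewrite in-place dunder operators (e.g. `__iadd__` -> `__add__`)
            # to their out-of-place counterparts, aliasing the result back to
            # the input tensor so the semantics are preserved.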
            if node.type().startswith('__i') and node.type().endswith('__'):
                inner_op = node.module.func_type[3:-2]
                if inner_op in SPECIAL_OPERATORS:
                    node.module.func_type = f'__{inner_op}__'
                    parts = node.module.full_name.split('.')[:-1] + [node.module.func_type]
                    node.module.full_name = '.'.join(parts)
                    if first_arg is not None:
                        alias = first_arg
                    else:
                        alias = node.module.get_tensor_name(0, inplace)
                    node.module.add_alias(alias)

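            # Aliased outputs become chained assignments, e.g. `x = y = op(...)`.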
            aliases = node.module.get_aliases()
            prefix = ''
            if aliases is not None:
                prefix = ''.join([f'{x} = ' for x in aliases])

            line = f"        {prefix}{output} = {node.module.extra_expr(first=first_arg, original=inplace)}"
        else:
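            # `node.module` is an actual `nn.Module`, called via an attribute on
            # `self`. A module reused across calls keeps one shared attribute
            # name unless the original names are preserved (`inplace=True`).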
            if inplace:
                mod_name = node.original_name
            else:
                mod_name_dict.setdefault(node.module, node.unique_name)
                mod_name = mod_name_dict[node.module]

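            # Skip dangling nodes that neither consume nor produce tensors.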
            if len(node.prev_tensors) == 0 and len(node.next_tensors) == 0:
                continue

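            # `nn.LSTM` expects its initial hidden state as a tuple `(h_0, c_0)`,
            # so the second and third inputs are packed back into one argument.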
            if node.type() is nn.LSTM and len(node.prev_nodes) == 3 and len(node.prev_tensors) == 3:
                first_arg = node.prev_node_unique_name(0)
                param = ", ".join([node.prev_node_unique_name(i) for i in range(1, len(node.prev_nodes))])
                line = f"        {output} = self.{mod_name}({first_arg}, ({param}))"
            else:
                line = f"        {output} = self.{mod_name}({param})"

        lines.append(line)

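        # Release each intermediate tensor right after its last consumer runs,
        # so the generated forward function does not pin extra memory.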
        for pn in {pn.unique_name: pn for pn in node.prev_nodes}.values():
            if node.forward_order == max([n.forward_order for n in pn.next_nodes]):
                if pn.type() not in (ConstantNode, torch.nn.quantized.FloatFunctional):
                    lines.append(f"        {pn.unique_name} = None")

    def _gen_output_node(node):
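        # Nodes with `rev_index` gather several predecessors into a list output.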
        if node.rev_index:
            return f'[{", ".join([node.prev_node_unique_name(i) for i in range(len(node.prev_nodes))])}]'
        else:
            return node.prev_node_unique_name(0)
lines.append(f" return {', '.join([_gen_output_node(i) for i in self.output_nodes])}")
block = "\n".join(lines)
return block
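
For reference, a minimal sketch of what this method emits for a toy graph of one
`Conv2d` followed by one `ReLU` (the names `input_0`, `conv2d_0` and `relu_0`
are hypothetical; the actual unique names depend on the traced model):

    def forward(self, input_0):
        conv2d_0 = self.conv2d_0(input_0)
        input_0 = None
        relu_0 = self.relu_0(conv2d_0)
        conv2d_0 = None
        return relu_0

Each `<name> = None` statement comes from the cleanup loop above, which fires
once the last consumer of that tensor (by `forward_order`) has run.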