in crypten/nn/module.py [0:0]
def forward(self, *args):
    """Run the graph on ``args`` and return the requested output value(s).

    ``self._graph`` maps each node name to the list of value names that node
    consumes; the module for a node is ``self._modules[name]``. Starting from
    the graph inputs, nodes are executed one at a time, always picking the
    first unfinished node (in ``self._graph`` iteration order) whose inputs
    have all been produced, so the evaluation order is deterministic.
    Values that no unfinished node consumes and that are not requested
    outputs are dropped along the way to save memory.

    Args:
        *args: one value per entry of ``self.input_names``, in that order.

    Returns:
        The value named by ``self.output_names[0]`` when exactly one output is
        requested, otherwise a tuple of the values in ``self.output_names``
        order.

    Raises:
        AssertionError: if the number of ``args`` does not match
            ``self.input_names``, or a module's results do not match its
            ``_output_names``.
        ValueError: if the requested outputs cannot be produced from the
            inputs (e.g. the graph is unconnected).
    """
    assert len(args) == len(
        self.input_names
    ), f"Expected {len(self.input_names)} inputs but received {len(args)}."

    # values currently held in memory, keyed by value name:
    values = dict(zip(self.input_names, args))
    # names of every value produced so far (graph inputs and node outputs):
    produced = set(self.input_names)
    # names of every node whose module has run; tracked separately because a
    # multi-output node's own name need not be one of its output names, and
    # without this it would be selected and executed again:
    executed = set()
    # requested outputs are never released, so they are still present when
    # the result is assembled (even if no other node consumes them):
    requested = set(self.output_names)

    def _is_finished(name):
        """Return True if node ``name`` must not be (re-)executed."""
        return name in produced or name in executed

    def _find_computable_node():
        """Return the first unfinished node whose inputs are all produced."""
        for name, input_names in self._graph.items():
            # membership test (rather than list.index bookkeeping) so that a
            # node consuming the same value twice, e.g. ["x", "x"], can run:
            if not _is_finished(name) and all(
                input_name in produced for input_name in input_names
            ):
                return name
        return None

    def _clear_unused_values():
        """Drop values that no unfinished node consumes and nobody requested."""
        still_needed = set(requested)
        for name, input_names in self._graph.items():
            if not _is_finished(name):
                still_needed.update(input_names)
        for name in [key for key in values if key not in still_needed]:
            del values[name]

    def _outputs_ready():
        """Return True once every requested output has been produced."""
        return all(name in produced for name in self.output_names)

    # perform forward pass until all requested outputs are available:
    while not _outputs_ready():
        node_to_compute = _find_computable_node()
        if node_to_compute is None:
            # no runnable node left, yet some outputs are still missing:
            raise ValueError("nn.Graph.forward() failed. Is graph unconnected?")

        # compute output of module (unpack single-element input lists):
        module_input = [values[name] for name in self._graph[node_to_compute]]
        if len(module_input) == 1:
            module_input = module_input[0]
        module = self._modules[node_to_compute]
        output = module(module_input)
        executed.add(node_to_compute)

        output_names = getattr(module, "_output_names", None)
        if output_names is None or len(output_names) == 1:
            # single output, stored under the node's own name:
            if output_names is not None:
                assert output_names[0] == node_to_compute, "invalid graph"
            values[node_to_compute] = output
            produced.add(node_to_compute)
        else:
            # multiple outputs, stored under the names in ``_output_names``:
            assert isinstance(
                output, tuple
            ), f"expected outputs {output_names} of {module} to be tuple, not {type(output)}"
            assert len(output_names) == len(
                output
            ), f"expected {len(output_names)} outputs from {module}, received {len(output)}"
            for name, value in zip(output_names, output):
                values[name] = value
                produced.add(name)

        # clean up values we no longer need:
        _clear_unused_values()

    result = [values[name] for name in self.output_names]
    return result[0] if len(result) == 1 else tuple(result)