def forward()

in crypten/nn/module.py [0:0]


    def forward(self, *args):
        """Run a forward pass through the graph and return its outputs.

        The positional inputs are bound, in order, to ``self.input_names``.
        Nodes of ``self._graph`` (a mapping from node name to the list of
        value names that node consumes) are then evaluated one at a time, as
        soon as all of their inputs are available, by calling the matching
        entry of ``self._modules``. A node with a single input receives that
        value directly; a node with several inputs receives them as a list.

        Intermediate values are dropped as soon as no uncomputed node needs
        them, except for values named in ``self.output_names``, which are
        kept until they are returned.

        Args:
            *args: One value per entry of ``self.input_names``.

        Returns:
            The single output value if ``self.output_names`` has exactly one
            entry, otherwise a tuple of output values in the order given by
            ``self.output_names``.

        Raises:
            AssertionError: If the number of inputs does not match
                ``self.input_names``, or if a module declaring several
                ``_output_names`` does not return a tuple of matching length.
            ValueError: If no computable node remains before all outputs
                have been produced.
        """
        assert len(args) == len(
            self.input_names
        ), f"Expected {len(self.input_names)} inputs but received {len(args)}."

        # keep track of all values that have been computed:
        values = dict(zip(self.input_names, args))
        computed = {key: False for key in self._graph}
        inputs_available = {
            key: [False] * len(value_list)
            for key, value_list in self._graph.items()
        }

        # graph outputs must survive clean-up until they are returned:
        protected = set(self.output_names)

        def _mark_as_computed(name):
            """Marks a value as having been computed."""
            computed[name] = True
            for key, value_list in self._graph.items():
                available = inputs_available[key]
                # mark *every* occurrence: a node may consume the same value
                # more than once (list.index would only find the first one)
                for idx, input_name in enumerate(value_list):
                    if input_name == name:
                        available[idx] = True

        def _find_computable_node():
            """Find a node for which all inputs are available."""
            for key, inputs_available_list in inputs_available.items():
                if all(inputs_available_list) and not computed[key]:
                    return key
            return None

        def _clear_unused_values():
            """Clear values that are no longer needed (to save memory)."""
            # a value is still needed if any uncomputed node consumes it:
            still_needed = set()
            for key, value_list in self._graph.items():
                if not computed[key]:
                    still_needed.update(value_list)

            # remove all values we no longer need, but never a graph output:
            remove_keys = [
                name
                for name in values
                if name not in still_needed and name not in protected
            ]
            for remove_key in remove_keys:
                del values[remove_key]
            # NOTE: We maintain inputs_available[remove_key] as True to
            # prevent re-computation of the node.

        # perform forward pass:
        for input_name in self.input_names:
            _mark_as_computed(input_name)
        node_to_compute = _find_computable_node()
        while node_to_compute is not None:

            # compute output of module:
            module_input = [values[name] for name in self._graph[node_to_compute]]
            if len(module_input) == 1:
                module_input = module_input[0]  # unpack iterable if possible
            module = self._modules[node_to_compute]
            output = module(module_input)

            # we may get one output:
            output_names = getattr(module, "_output_names", None)
            if output_names is None or len(output_names) == 1:
                if output_names is not None:
                    assert output_names[0] == node_to_compute, "invalid graph"
                values[node_to_compute] = output
                _mark_as_computed(node_to_compute)

            # or multiple outputs:
            else:
                assert isinstance(
                    output, tuple
                ), f"expected outputs {output_names} of {module} to be tuple, not {type(output)}"
                assert len(output_names) == len(
                    output
                ), f"expected {len(output_names)} outputs from {module}, received {len(output)}"
                for node, value in zip(output_names, output):
                    values[node] = value
                    _mark_as_computed(node)

            # return output if it is available:
            if all(computed[output_name] for output_name in self.output_names):
                result = [values[output_name] for output_name in self.output_names]
                return result[0] if len(result) == 1 else tuple(result)

            # find next node to compute:
            node_to_compute = _find_computable_node()

            # clean up values we no longer need:
            _clear_unused_values()

        # this should never happen:
        raise ValueError("nn.Graph.forward() failed. Is graph unconnected?")