def _get_streams()

in labgraph/graphs/group.py [0:0]


    def _get_streams(self) -> Dict[str, Stream]:
        """
        Returns a dictionary containing a stream for each group of topics that are
        joined to each other. Topic joins are transitive, i.e., if A is joined to B, and
        B is joined to C, then A is joined to C.

        Streams already computed by child groups (their `__streams__`) are kept
        intact as starting points and then merged according to this group's
        `connections()`.

        Returns:
            A dictionary mapping each resulting stream's `id` to its `Stream`.

        Raises:
            LabgraphError: If `connections()` references a topic object that is not
                one of this group's (or its descendants') topics, or if the topics
                in one stream have conflicting (non-`Message`) message types.
        """

        # This is basically a version of the union-find algorithm. We start by putting
        # each topic in its own stream, then for each topic pair that is joined, we merge
        # each topic's stream.

        stream_nums: Dict[str, int] = {}

        for topic_name in self.__topics__.keys():
            if PATH_DELIMITER not in topic_name:
                # Put each topic in this group in its own stream
                stream_nums[topic_name] = len(stream_nums)

        for child_name, child in self.__children__.items():
            # Preserve the streams that child modules have already computed
            for stream in child.__streams__.values():
                # len(stream_nums) only grows, so this number is not used by any
                # stream numbered earlier
                new_stream_num = len(stream_nums)
                for topic_path in stream.topic_paths:
                    stream_nums[
                        PATH_DELIMITER.join((child_name, topic_path))
                    ] = new_stream_num

        # Use `connections()` to merge streams
        connections = list(self.connections())
        typeguard.check_type(
            f"{self.__class__.__name__}.{self.connections.__name__}()",
            connections,
            Connections,
        )

        # Map each topic object (by identity) to its path once, rather than scanning
        # `__topics__` for every connection. `setdefault` keeps the first path seen,
        # matching a first-match linear search if one object appears under
        # several paths.
        topic_path_by_id: Dict[int, str] = {}
        for topic_path, descendant_topic in self.__topics__.items():
            topic_path_by_id.setdefault(id(descendant_topic), topic_path)

        for topic1, topic2 in connections:
            # Find the topic paths for these `Topic` objects
            joined_paths: List[str] = []
            for topic in (topic1, topic2):
                found_path = topic_path_by_id.get(id(topic))
                if found_path is None:
                    # Explicit error instead of `assert`, which is stripped under -O
                    raise LabgraphError(
                        f"{self.__class__.__name__}.{self.connections.__name__}() "
                        "joined a topic that is not a topic of this group or its "
                        f"descendants: {topic!r}"
                    )
                joined_paths.append(found_path)

            # Set the stream numbers of all topics in the two old streams to be
            # the same
            old_stream_nums = (
                stream_nums[joined_paths[0]],
                stream_nums[joined_paths[1]],
            )
            new_stream_num = min(old_stream_nums)
            for topic_path, stream_num in list(stream_nums.items()):
                if stream_num in old_stream_nums:
                    stream_nums[topic_path] = new_stream_num

        # Group the topic paths by final stream number in a single pass
        paths_by_stream_num: Dict[int, List[str]] = {}
        for topic_path, stream_num in stream_nums.items():
            paths_by_stream_num.setdefault(stream_num, []).append(topic_path)

        # Validate each stream's message types and sort its topic names. Stream
        # numbers are visited in sorted order so the result order is deterministic.
        result = []
        for stream_num in sorted(paths_by_stream_num):
            topic_paths = sorted(paths_by_stream_num[stream_num])

            message_types_by_topic: Dict[str, Type[Message]] = {
                topic_path: self.__topics__[topic_path].message_type
                for topic_path in topic_paths
            }

            all_types = list(message_types_by_topic.values())
            # Deduplicate with == (via list.index) instead of set(), which would
            # depend on the message type classes' __hash__ agreeing with __eq__.
            message_types = [
                m for i, m in enumerate(all_types) if all_types.index(m) == i
            ]

            # HACK: Since we don't currently use Cthulhu dynamic types, we don't know
            # what types a CPPNode's topics use. We type them as Message for now, and
            # exclude them from the topics-have-same-type check below.
            message_types = [m for m in message_types if m is not Message]

            if len(message_types) > 1:
                error_message = (
                    "Topics in stream must have matching message types, found the "
                    "following types for the same stream:\n"
                )
                for topic_path in topic_paths:
                    error_message += (
                        f"- {topic_path}: {message_types_by_topic[topic_path]}\n"
                    )
                raise LabgraphError(error_message)

            if len(message_types) == 0:
                warning_message = (
                    "Found no message types for a stream. This could be because the "
                    "topic is never used in Python. Consume a topic with a message "
                    "type in Python to make this stream work. This stream contains the "
                    "following topics:\n"
                )
                for topic_path in topic_paths:
                    warning_message += f"- {topic_path}\n"

                logger.warning(warning_message)

                result.append(Stream(topic_paths=tuple(topic_paths)))
            else:
                # The stream is valid, so create the stream object
                result.append(
                    Stream(
                        topic_paths=tuple(topic_paths),
                        message_type=message_types[0],
                    )
                )

        # Index the streams by id
        return {stream.id: stream for stream in result}