def format_shapes()

in pyro/poutine/trace_struct.py [0:0]


    def format_shapes(self, title='Trace Shapes:', last_site=None):
        """
        Returns a string showing a table of the shapes of all sites in the
        trace.

        Param sites are listed first under a ``Param Sites:`` heading, then
        sample sites under ``Sample Sites:``. Each sample site contributes a
        ``<name> dist`` row (the ``batch_shape`` and ``event_shape`` of its
        ``fn``), a ``value`` row, and, when the site has a ``"log_prob"``
        entry, a ``log_prob`` row. A ``"|"`` cell separates batch dimensions
        (left) from event dimensions (right).

        :param str title: Text of the first table row; returned unchanged when
            the trace has no nodes.
        :param last_site: Name of the last site to include. Each of the two
            listings stops after the node with this name. With the default
            ``None`` no node name is expected to match, so every site is
            listed.
        :returns: ``title`` if the trace is empty, otherwise the result of
            ``_format_table`` applied to the collected rows.
        """
        if not self.nodes:
            return title

        def sizes(shape):
            # Table cells are text, so render each dimension with str().
            return [str(size) for size in shape]

        rows = [[title]]

        rows.append(['Param Sites:'])
        for name, site in self.nodes.items():
            if site["type"] == "param":
                # Same fallback as for sample values below: a value without a
                # ``shape`` attribute is shown with an empty shape rather than
                # raising from inside a formatting helper.
                rows.append([name, None] + sizes(getattr(site["value"], "shape", ())))
            if name == last_site:
                break

        rows.append(['Sample Sites:'])
        for name, site in self.nodes.items():
            if site["type"] == "sample":
                # distribution shape
                batch_shape = getattr(site["fn"], "batch_shape", ())
                event_shape = getattr(site["fn"], "event_shape", ())
                rows.append([name + " dist", None] + sizes(batch_shape) +
                            ["|", None] + sizes(event_shape))

                # value shape: the trailing ``event_dim`` dimensions are the
                # event part, everything before them is the batch part.
                event_dim = len(event_shape)
                shape = getattr(site["value"], "shape", ())
                # Clamp at zero: if the value has fewer dimensions than the
                # distribution's event shape, a negative index would wrap
                # around and report trailing dims as batch dims. With the
                # clamp, all of the value's dims are shown as event dims.
                split = max(len(shape) - event_dim, 0)
                rows.append(["value", None] + sizes(shape[:split]) +
                            ["|", None] + sizes(shape[split:]))

                # log_prob shape (no event part, so nothing right of "|")
                if "log_prob" in site:
                    batch_shape = getattr(site["log_prob"], "shape", ())
                    rows.append(["log_prob", None] + sizes(batch_shape) + ["|", None])
            if name == last_site:
                break

        return _format_table(rows)