in pyro/poutine/trace_struct.py [0:0]
def format_shapes(self, title='Trace Shapes:', last_site=None):
    """
    Returns a string showing a table of the shapes of all sites in the
    trace.

    The table has a "Param Sites:" section with one row per param site
    (its name followed by the dims of ``site["value"].shape``) and a
    "Sample Sites:" section with, for each sample site, a ``"<name> dist"``
    row (the distribution's ``batch_shape | event_shape``), a ``"value"``
    row (the value's shape split into batch and event dims using the
    distribution's event rank) and, when the site has a ``"log_prob"``
    entry, a ``"log_prob"`` row.

    :param str title: Heading of the table; returned unchanged when the
        trace has no nodes.
    :param last_site: Optional site name. Each section stops after this
        site (inclusive), so only sites up to and including it are listed.
    :returns: The table rendered by ``_format_table``.
    :rtype: str
    """
    if not self.nodes:
        return title

    def dims(shape):
        # One table cell per dimension.
        return [str(size) for size in shape]

    rows = [[title]]

    rows.append(['Param Sites:'])
    for name, site in self.nodes.items():
        if site["type"] == "param":
            rows.append([name, None] + dims(site["value"].shape))
        if name == last_site:
            break

    rows.append(['Sample Sites:'])
    for name, site in self.nodes.items():
        if site["type"] == "sample":
            # distribution shape; an fn without these attributes shows no dims
            batch_shape = getattr(site["fn"], "batch_shape", ())
            event_shape = getattr(site["fn"], "event_shape", ())
            rows.append([name + " dist", None] + dims(batch_shape) +
                        ["|", None] + dims(event_shape))

            # value shape: the rightmost len(event_shape) dims are event dims.
            # Clamp the split point at 0: a value with fewer dims than the
            # event rank would otherwise give a negative index, which slices
            # from the end and misplaces dims between batch and event parts.
            shape = getattr(site["value"], "shape", ())
            split = max(len(shape) - len(event_shape), 0)
            rows.append(["value", None] + dims(shape[:split]) +
                        ["|", None] + dims(shape[split:]))

            # log_prob shape, only when the site carries a "log_prob" entry
            if "log_prob" in site:
                log_prob_shape = getattr(site["log_prob"], "shape", ())
                rows.append(["log_prob", None] + dims(log_prob_shape) + ["|", None])
        if name == last_site:
            break

    return _format_table(rows)