def format_model()

in ppo_ewma/torch_util.py [0:0]


def format_model(mod, rms=False):
    """
    Return a str: a formatted table listing parameters and their sizes.

    The table has one row per parameter of ``mod`` with these columns:
    path, shape, element count, percentage of the model's total element count
    (rounded to 0.1), and standard deviation formatted to 5 decimals.

    It also has one summary row per module from ``mod.named_modules()`` that
    owns at least one parameter. The summary row for a module with an empty
    name (the root) is labelled ``"~total"``. All rows are sorted by path.

    :param mod: the module to describe; must provide ``parameters()``,
        ``named_parameters()`` and ``named_modules()``.
    :param rms: if True, append an ``rms`` column holding ``_rms(param)``
        for each parameter row (blank for module summary rows).
    :return: the table rendered with ``DataFrame.to_string(index=False)``.
    """
    import pandas

    ntotal = sum(p.numel() for p in mod.parameters())

    def pct(numel):
        # Guard against a model whose parameters are all zero-sized (ntotal == 0),
        # which previously raised ZeroDivisionError.
        return round(numel / ntotal * 100, 1) if ntotal else 0.0

    rows = []
    for name, param in mod.named_parameters():
        numel = param.numel()
        # detach() so computing the statistic does not record autograd history.
        std = "%0.5f" % float(param.detach().std())
        row = [name, tuple(param.shape), numel, pct(numel), std]
        if rms:
            # Only evaluate _rms when the column is actually requested.
            row.append(_rms(param))
        rows.append(row)

    for name, module in mod.named_modules():
        numel = sum(p.numel() for p in module.parameters())
        if numel == 0:
            continue
        row = [name or "~total", "", numel, pct(numel), ""]
        if rms:
            row.append("")
        rows.append(row)

    columns = ["path", "shape", "numel", "pct", "std"]
    if rms:
        columns.append("rms")
    # Paths are unique, so sorting on them alone fully determines row order.
    rows.sort(key=lambda r: r[0])
    df = pandas.DataFrame(rows, columns=columns)
    # Left-align paths to the longest one; default=1 keeps the format spec
    # valid when there are no rows at all.
    maxlen = max((len(r[0]) for r in rows), default=1)
    return df.to_string(
        index=False, formatters={"path": "{{:<{}s}}".format(maxlen).format}
    )