def log()

in models/expire_span.py [0:0]


def log(args, model, logger, stat_train):
    """Flush per-layer expire-span statistics to ``logger`` and reset them.

    For every layer whose ``args.expire_span`` is truthy, this drains
    ``attn.attn.expire_span.avg_span_log`` by replacing it with a fresh empty
    list. When the span module has a ``max_span_log`` attribute, the function
    reads it and resets it to 0.

    If at least one span sample was collected across all layers, it logs:
      * ``"adapt_span/avg"``: the mean of all collected samples;
      * ``"adapt_span/max"``: the largest ``max_span_log`` seen (0 if no layer
        had one).
    If no samples were collected, nothing is logged. The per-layer buffers
    are still reset, and the max value is discarded.

    Args:
        args: Not used by this function.
        model: Either a wrapper exposing the real model as ``model.module``
            (presumably a DataParallel/DistributedDataParallel-style
            container; confirm against the caller) or a model with a
            ``layers`` attribute directly.
        logger: Object providing ``log(key, value)``.
        stat_train: Not used by this function.
    """
    # Unwrap a ``.module``-style container when present; otherwise use the
    # model as-is. Behaves exactly like ``model.module`` for wrapped models.
    core = getattr(model, "module", model)

    span_samples = []
    max_span = 0
    for layer in core.layers:
        if not layer.args.expire_span:
            continue
        expire_span = layer.attn.attn.expire_span
        # Take the samples gathered since the last flush and start a new buffer.
        span_samples.extend(expire_span.avg_span_log)
        expire_span.avg_span_log = []
        # Not every expire-span module tracks a running max.
        if hasattr(expire_span, "max_span_log"):
            max_span = max(max_span, expire_span.max_span_log)
            expire_span.max_span_log = 0

    if span_samples:
        # sum/len (rather than statistics.mean) so non-float samples such as
        # tensors also work, matching the original computation.
        logger.log("adapt_span/avg", sum(span_samples) / len(span_samples))
        logger.log("adapt_span/max", max_span)