in models/expire_span.py [0:0]
def log(args, model, logger, stat_train):
    """Flush the per-layer span statistics and report them through ``logger``.

    Every layer in ``model.module.layers`` whose ``args.expire_span`` flag is
    truthy has its ``attn.attn.expire_span.avg_span_log`` entries collected
    and then reset to an empty list.  When that sub-module also carries a
    ``max_span_log`` attribute, the running maximum is updated from it and
    the attribute is reset to ``0``.

    If at least one average-span entry was collected, the mean of all
    entries is logged under ``"adapt_span/avg"`` and the running maximum
    under ``"adapt_span/max"``; otherwise nothing is logged.

    Args:
        args: Not used by this function.
        model: Object exposing ``module.layers``, an iterable of layers.
        logger: Object exposing ``log(key, value)``.
        stat_train: Not used by this function.
    """
    collected_spans = []
    largest_span = 0

    for layer in model.module.layers:
        # Layers without the expire-span flag contribute nothing and are
        # left untouched.
        if not layer.args.expire_span:
            continue

        span_module = layer.attn.attn.expire_span

        # Drain the accumulated samples, then start a fresh log so the next
        # call only sees values recorded after this point.
        drained = span_module.avg_span_log
        span_module.avg_span_log = []
        collected_spans.extend(drained)

        # The maximum is optional; only sub-modules that track it take part.
        if hasattr(span_module, "max_span_log"):
            largest_span = max(largest_span, span_module.max_span_log)
            span_module.max_span_log = 0

    # Without any samples there is no meaningful average to report.
    if not collected_spans:
        return

    mean_span = sum(collected_spans) / len(collected_spans)
    logger.log("adapt_span/avg", mean_span)
    logger.log("adapt_span/max", largest_span)