def parse_transformer_out()

in visualization/plotting.py [0:0]


def parse_transformer_out(world_size, tag, fpath, itr_scale=1):
    """Parse a multi-rank training log into a per-rank metrics DataFrame.

    The log read is ``fpath.format(tag=tag)``. Every relevant line is split
    on ``'|'``: the first space-separated token of the first field (with
    ``':'`` removed) is the worker rank, and the second-to-last token of the
    second field is the epoch number. Lines from epoch 1 are skipped.

    * Lines containing ``train_wall`` supply a wall-clock time (last token of
      the last field). For each (rank, epoch) the largest value seen is kept.
      Such lines whose epoch cannot be parsed are ignored.
    * Lines containing ``valid_nll_loss`` supply an iteration count,
      perplexity and NLL, read from the second-to-last token of the
      2nd-, 3rd- and 4th-from-last fields respectively.

    Args:
        world_size: Number of ranks. Ranks parsed from the log are used as
            list indices, so they are expected to lie in ``range(world_size)``.
        tag: Value substituted for ``{tag}`` in ``fpath``.
        fpath: Log-path template containing a ``{tag}`` placeholder.
        itr_scale: Currently unused; accepted for interface compatibility.

    Returns:
        A DataFrame with columns ``itr<r>``, ``ppl<r>``, ``nll<r>`` and
        ``time<r>`` for every rank ``r`` that logged at least one validation
        line. All columns are truncated to the shortest such rank. It also has
        ``itr``, ``ppl``, ``nll`` and ``time`` columns holding the row-wise
        mean over those ranks. Row ``i`` pairs the ``i``-th validation record
        with the train time recorded for epoch ``i + 2`` (0.0 if none).
        The head of the frame is printed before returning.

    Raises:
        ValueError: If no rank logged any ``valid_nll_loss`` line.
    """
    log_path = fpath.format(tag=tag)
    itr_list = [[] for _ in range(world_size)]
    ppl_list = [[] for _ in range(world_size)]
    nll_list = [[] for _ in range(world_size)]
    # Per-rank train time indexed by (epoch - 2). Start with 100 zeroed slots
    # and grow on demand so runs longer than 101 epochs do not IndexError.
    time_list = [[0.] * 100 for _ in range(world_size)]

    with open(log_path, 'r') as f:
        for raw_line in f:
            if 'train_wall' in raw_line:
                fields = raw_line.split('|')
                rank = int(fields[0].split(' ')[0].replace(':', ''))
                try:
                    ep = int(fields[1].split(' ')[-2])
                except (ValueError, IndexError):
                    # No parsable epoch field on this line; ignore it.
                    continue
                if ep == 1:
                    continue  # skip first epoch
                wall = float(fields[-1].split(' ')[-1].replace('\n', ''))
                slots = time_list[rank]
                idx = ep - 2
                if idx >= len(slots):
                    slots.extend([0.] * (idx + 1 - len(slots)))
                # Keep the largest train_wall value seen for this epoch.
                if wall > slots[idx]:
                    slots[idx] = wall
            elif 'valid_nll_loss' in raw_line:
                fields = raw_line.split('|')
                rank = int(fields[0].split(' ')[0].replace(':', ''))
                ep = int(fields[1].split(' ')[-2])
                if ep == 1:
                    continue  # skip first epoch
                itr = int(fields[-2].split(' ')[-2])
                ppl = float(fields[-3].split(' ')[-2])
                nll = float(fields[-4].split(' ')[-2])
                itr_list[rank].append(itr)
                ppl_list[rank].append(ppl)
                nll_list[rank].append(nll)

    lengths = [len(itrs) for itrs in itr_list if itrs]
    if not lengths:
        raise ValueError(
            'no valid_nll_loss records found in {!r}'.format(log_path))
    itr_len = min(lengths)

    per_metric = (
        ('itr', itr_list),
        ('ppl', ppl_list),
        ('nll', nll_list),
        ('time', time_list),
    )
    pdf = pd.DataFrame()
    columns = {name: [] for name, _ in per_metric}
    for r in range(world_size):
        if not itr_list[r]:
            continue  # rank never logged a validation line
        for name, values in per_metric:
            column = values[r][:itr_len]
            # Zero-fill (like the unseen-epoch slots) so every column has
            # exactly itr_len rows.
            column = column + [0.] * (itr_len - len(column))
            col_name = name + str(r)
            pdf[col_name] = column
            columns[name].append(col_name)

    for name, _ in per_metric:
        pdf[name] = pdf[columns[name]].mean(axis=1)
    print(pdf.head())
    return pdf