in visualization/plotting.py [0:0]
def parse_transformer_out(world_size, tag, fpath, itr_scale=1):
    """Parse a multi-rank transformer training log into a DataFrame.

    The log is read line by line. Each line of interest is split on ``'|'``;
    the first field must start with ``"<rank>:"`` and the second field must
    carry the epoch number as its second-to-last space-separated token.

    * Lines containing ``train_wall``: the last space-separated token of the
      last field is read as the wall time. For every (rank, epoch) the
      largest positive value seen is kept.
    * Lines containing ``valid_nll_loss`` (and not ``train_wall``): the
      second-to-last token of the last-but-one, last-but-two and
      last-but-three fields is read as the iteration count, perplexity and
      NLL loss respectively, and appended to that rank's series.

    Epoch 1 is skipped in both cases. Row ``i`` of a rank's ``time`` column
    holds the wall time recorded for epoch ``i + 2`` (``0.0`` when none was
    logged).

    Args:
        world_size: Number of ranks; rank ids in the log must be valid
            indices into a list of this length.
        tag: Value substituted for ``{tag}`` in ``fpath``.
        fpath: Path template of the log file, formatted with ``tag=tag``.
        itr_scale: Unused by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A ``pandas.DataFrame`` with columns ``itr<r>``, ``ppl<r>``,
        ``nll<r>`` and ``time<r>`` for every rank ``r`` that logged at least
        one validation line, truncated to the shortest such rank, followed by
        ``itr``, ``ppl``, ``nll`` and ``time`` holding the row-wise mean
        across ranks. The head of the frame is also printed.

    Raises:
        ValueError: If no rank produced a validation line, or if a numeric
            field of a validation line cannot be parsed.
        IndexError: If a line of interest has fewer fields than expected or
            names a rank outside ``range(world_size)``.
    """
    log_path = fpath.format(tag=tag)

    itr_list = [[] for _ in range(world_size)]
    ppl_list = [[] for _ in range(world_size)]
    nll_list = [[] for _ in range(world_size)]
    # Per rank: epoch slot (ep - 2) -> largest train_wall seen. Kept sparse
    # so the number of epochs in the log is not capped by a fixed buffer.
    wall_by_slot = [{} for _ in range(world_size)]

    with open(log_path, 'r') as f:
        for raw in f:
            if 'train_wall' in raw:
                fields = raw.split('|')
                rank = int(fields[0].split(' ')[0].replace(':', ''))
                try:
                    ep = int(fields[1].split(' ')[-2])
                except (ValueError, IndexError):
                    # Not a per-epoch summary line; ignore it.
                    continue
                if ep == 1:
                    continue  # skip first epoch
                wall = float(fields[-1].split(' ')[-1].replace('\n', ''))
                slot = ep - 2
                # Missing slots count as 0.0, so only positive times are
                # ever recorded and the maximum per epoch wins.
                if wall > wall_by_slot[rank].get(slot, 0.):
                    wall_by_slot[rank][slot] = wall
            elif 'valid_nll_loss' in raw:
                fields = raw.split('|')
                rank = int(fields[0].split(' ')[0].replace(':', ''))
                ep = int(fields[1].split(' ')[-2])
                if ep == 1:
                    continue  # skip first epoch
                itr_list[rank].append(int(fields[-2].split(' ')[-2]))
                ppl_list[rank].append(float(fields[-3].split(' ')[-2]))
                nll_list[rank].append(float(fields[-4].split(' ')[-2]))

    # Only ranks that reported validation results contribute columns.
    active_ranks = [r for r in range(world_size) if itr_list[r]]
    if not active_ranks:
        raise ValueError(
            'no validation entries past epoch 1 found in ' + log_path)
    # Truncate every rank to the shortest series so all columns line up.
    itr_len = min(len(itr_list[r]) for r in active_ranks)

    columns = {}
    itr_columns, ppl_columns, nll_columns, time_columns = [], [], [], []
    for r in active_ranks:
        rtag = 'itr' + str(r)
        columns[rtag] = itr_list[r][:itr_len]
        itr_columns.append(rtag)
        rtag = 'ppl' + str(r)
        columns[rtag] = ppl_list[r][:itr_len]
        ppl_columns.append(rtag)
        rtag = 'nll' + str(r)
        columns[rtag] = nll_list[r][:itr_len]
        nll_columns.append(rtag)
        rtag = 'time' + str(r)
        # Always exactly itr_len entries, whatever the number of epochs.
        columns[rtag] = [wall_by_slot[r].get(i, 0.) for i in range(itr_len)]
        time_columns.append(rtag)

    pdf = pd.DataFrame(columns)
    pdf['itr'] = pdf[itr_columns].mean(axis=1)
    pdf['ppl'] = pdf[ppl_columns].mean(axis=1)
    pdf['nll'] = pdf[nll_columns].mean(axis=1)
    pdf['time'] = pdf[time_columns].mean(axis=1)
    print(pdf.head())
    return pdf