in align/train.py [0:0]
import collections
import os

import torch
from torch import nn
from tqdm import tqdm

# setup_configs, collect_bitext_stats, CRISSWrapper, BitextAlignmentDataset,
# WordAligner, and extract_scores are assumed to be defined elsewhere in this
# module. `cos` is used below but never defined in this excerpt, so a
# module-level cosine-similarity helper is assumed:
cos = nn.CosineSimilarity(dim=-1)


def train(configs, logging_steps=50000):
    setup_configs(configs)
    os.makedirs(configs.save_path, exist_ok=True)
    torch.save(configs, os.path.join(configs.save_path, 'configs.pt'))
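    # Precomputed corpus-level bitext statistics, consumed by extract_scores below.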
info = collect_bitext_stats(
configs.bitext_path, configs.align_path, configs.save_path,
configs.src_lang, configs.trg_lang, configs.reversed
)
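    # Optional CRISS encoder providing cross-lingual contextualized embeddings,
    # which contribute two extra similarity features per word pair.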
if configs.use_criss:
criss = CRISSWrapper(device=configs.device)
else:
criss = None
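    # batch_size=1: each batch holds a single sentence pair together with its
    # gold alignment edges.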
dataset = BitextAlignmentDataset(configs.bitext_path, configs.align_path)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn)
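    # Per-pair input features: 5 statistics-based features, plus the 2 CRISS
    # similarity features when enabled; the 3 is presumably the number of
    # edge-label classes (cf. the CrossEntropyLoss targets below).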
aligner = WordAligner(5 + (2 if configs.use_criss else 0), configs.hiddens, 3, 5).to(configs.device)
optimizer = torch.optim.Adam(aligner.parameters(), lr=.0005)
for epoch in range(configs.epochs):
        total_loss = total_cnt = 0
bar = tqdm(dataloader)
for idx, batch in enumerate(bar):
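            # ss/ts: source/target token sequences; edges: matrix of gold
            # alignment labels over all (source, target) word pairs.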
(ss, ts), edges = batch[0]
if criss is not None:
semb = criss.embed(ss, langcode=configs.src_lang)
temb = criss.embed(ts, langcode=configs.trg_lang)
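                # Pairwise cosine-similarity and inner-product features over all
                # source/target word pairs, reshaped so they concatenate into
                # feat_matrix along the last dimension.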
cos_matrix = cos(semb.unsqueeze(1), temb.unsqueeze(0)).unsqueeze(-1).unsqueeze(-1)
ip_matrix = (semb.unsqueeze(1) * temb.unsqueeze(0)).sum(-1).unsqueeze(-1).unsqueeze(-1)
feat_matrix = torch.cat((cos_matrix, ip_matrix), dim=-1)
            # adding contextualized embeddings here
training_sets = collections.defaultdict(list)
criss_features = collections.defaultdict(list)
for i, sw in enumerate(ss):
for j, tw in enumerate(ts):
label = edges[i, j]
training_sets[label].append((sw, tw))
if criss is not None:
criss_features[label].append(feat_matrix[i, j])
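            # Oversample each label class to roughly the size of the largest
            # class, so the cross-entropy loss sees a balanced training set.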
            max_len = max(len(training_sets[k]) for k in training_sets)
            training_set = list()
            criss_feats = list()
            targets = list()
            for key in training_sets:
                factor = max_len // len(training_sets[key])
                training_set += training_sets[key] * factor
                criss_feats += criss_features[key] * factor
                targets += [key] * (len(training_sets[key]) * factor)
targets = torch.tensor(targets).long().to(configs.device)
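            # Score each sampled word pair with the aligner; extract_scores is
            # expected to return one row of class logits per pair.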
scores = extract_scores(training_set, criss_feats, aligner, info, configs)
optimizer.zero_grad()
loss = nn.CrossEntropyLoss()(scores, targets)
loss.backward()
optimizer.step()
total_loss += loss.item() * len(batch)
total_cnt += len(batch)
bar.set_description(f'loss={total_loss / total_cnt:.5f}')
if (idx + 1) % logging_steps == 0:
print(f'Epoch {epoch}, step {idx+1}, loss = {total_loss / total_cnt:.5f}', flush=True)
                torch.save(aligner.state_dict(), os.path.join(configs.save_path, 'model.pt'))
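
# A minimal invocation sketch. The attribute names are exactly those this
# function reads; the concrete values (paths, langcodes, sizes) are assumptions
# for illustration, and setup_configs may fill in further defaults.
if __name__ == '__main__':
    from types import SimpleNamespace
    configs = SimpleNamespace(
        bitext_path='data/train.en-fr.tsv',    # hypothetical path
        align_path='data/train.en-fr.align',   # hypothetical path
        save_path='checkpoints/en-fr',
        src_lang='en_XX', trg_lang='fr_XX',    # mBART/CRISS-style langcodes (assumed)
        reversed=False,
        use_criss=True,
        hiddens=128,                           # WordAligner hidden size (assumed)
        epochs=5,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )
    train(configs)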