def train()

in align/train.py
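
Trains the WordAligner word-pair classifier on a parallel corpus with gold alignments: it buckets word pairs by alignment label and oversamples to balance the classes, optionally augments the hand-crafted features with CRISS embedding similarities, optimizes a cross-entropy loss with Adam, and saves the trained model under configs.save_path.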


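# Assumed module-level context from align/train.py (not shown in this excerpt):
# standard imports -- os, collections, torch, torch.nn as nn, tqdm.tqdm -- plus
# the repo-local helpers setup_configs, collect_bitext_stats, CRISSWrapper,
# BitextAlignmentDataset, WordAligner, extract_scores, and a cosine-similarity
# helper named `cos`.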
def train(configs, logging_steps=50000):
    setup_configs(configs)
    os.makedirs(configs.save_path, exist_ok=True)
    torch.save(configs, configs.save_path + '/configs.pt')
    info = collect_bitext_stats(
        configs.bitext_path, configs.align_path, configs.save_path, 
        configs.src_lang, configs.trg_lang, configs.reversed
    )
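    # Optional CRISS sentence encoder: when enabled, it supplies the two extra
    # per-pair similarity features (cosine and inner product) that account for
    # the "+2" input dimensions of WordAligner below.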
    if configs.use_criss:
        criss = CRISSWrapper(device=configs.device)
    else:
        criss = None
    dataset = BitextAlignmentDataset(configs.bitext_path, configs.align_path)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn
    )
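    # Pairwise word-alignment classifier: 5 hand-crafted features per word
    # pair, plus 2 CRISS similarity features when use_criss is set.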
    aligner = WordAligner(5 + (2 if configs.use_criss else 0), configs.hiddens, 3, 5).to(configs.device)
    optimizer = torch.optim.Adam(aligner.parameters(), lr=.0005)
    for epoch in range(configs.epochs):
        total_loss = total_cnt = 0
        bar = tqdm(dataloader)
        for idx, batch in enumerate(bar):
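            # batch_size is 1, so each batch holds a single sentence pair:
            # source/target token lists plus a matrix of per-pair alignment labels.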
            (ss, ts), edges = batch[0]
            if criss is not None:
                semb = criss.embed(ss, langcode=configs.src_lang)
                temb = criss.embed(ts, langcode=configs.trg_lang)
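                # Pairwise similarity features over all (source, target) word pairs.
                # `cos` is assumed to be a module-level helper in align/train.py,
                # e.g. nn.CosineSimilarity(dim=-1).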
                cos_matrix = cos(semb.unsqueeze(1), temb.unsqueeze(0)).unsqueeze(-1).unsqueeze(-1)
                ip_matrix = (semb.unsqueeze(1) * temb.unsqueeze(0)).sum(-1).unsqueeze(-1).unsqueeze(-1)
                feat_matrix = torch.cat((cos_matrix, ip_matrix), dim=-1)
            # adding contextualized embeddings here
            training_sets = collections.defaultdict(list)
            criss_features = collections.defaultdict(list)
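            # Bucket every word pair by its alignment label so the label classes
            # can be rebalanced below.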
            for i, sw in enumerate(ss):
                for j, tw in enumerate(ts):
                    label = edges[i, j]
                    training_sets[label].append((sw, tw))
                    if criss is not None:
                        criss_features[label].append(feat_matrix[i, j])
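            # Oversample each label bucket to roughly the size of the largest one,
            # so the cross-entropy loss sees an approximately balanced class mix.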
            max_len = max(len(training_sets[k]) for k in training_sets)
            training_set = list()
            criss_feats = list()
            targets = list()
            for key in training_sets:
                training_set += training_sets[key] * (max_len // len(training_sets[key]))
                criss_feats += criss_features[key] * (max_len // len(training_sets[key]))
                targets += [key] * len(training_sets[key]) * (max_len // len(training_sets[key]))
            targets = torch.tensor(targets).long().to(configs.device)
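            # Featurize each pair (corpus statistics from `info`, plus the optional
            # CRISS features) and score it with the aligner.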
            scores = extract_scores(training_set, criss_feats, aligner, info, configs)
            optimizer.zero_grad()
            loss = nn.CrossEntropyLoss()(scores, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * len(batch)
            total_cnt += len(batch)
            bar.set_description(f'loss={total_loss / total_cnt:.5f}')
            if (idx + 1) % logging_steps == 0:
                print(f'Epoch {epoch}, step {idx+1}, loss = {total_loss / total_cnt:.5f}', flush=True)
    torch.save(aligner.state_dict(), configs.save_path + '/model.pt')
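
A minimal invocation sketch. The attribute names mirror exactly what train() reads from its configs argument; the concrete values, the SimpleNamespace container, and the language codes are hypothetical, and setup_configs() may expect a different configs type.

import types
import torch

configs = types.SimpleNamespace(
    bitext_path='data/train.bitext',   # hypothetical corpus paths
    align_path='data/train.align',
    save_path='checkpoints/aligner',
    src_lang='en_XX',                  # hypothetical CRISS/mBART language codes
    trg_lang='de_DE',
    reversed=False,
    use_criss=True,
    hiddens=8,                         # hypothetical hidden-layer size
    epochs=1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
train(configs)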