def train_test()

in src/weakly_sup.py


def train_test(configs, logging_steps=50000):
    setup_configs(configs)
    os.makedirs(configs.save_path, exist_ok=True)
    torch.save(configs, configs.save_path + '/configs.pt')
    # prepare feature extractor
    info = collect_bitext_stats(
        configs.bitext_path, configs.align_path, configs.save_path,
        configs.src_lang, configs.trg_lang, configs.reversed)
    # dataset
    train_lexicon = load_lexicon(configs.tuning_set)
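    # add simplified-script variants of the tuning lexicon so both forms are treated as positives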
    sim_train_lexicon = {(to_simplified(x[0]), to_simplified(x[1])) for x in train_lexicon}
    all_train_lexicon = train_lexicon.union(sim_train_lexicon)
    test_lexicon = load_lexicon(configs.test_set)
    pos_training_set, neg_training_set, test_set = extract_dataset(
        train_lexicon, test_lexicon, info[2], configs
    )
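    # oversample positive pairs so they roughly balance the negatives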
    training_set_modifier = max(1, len(neg_training_set) // len(pos_training_set))
    training_set = pos_training_set * training_set_modifier + neg_training_set
    print(f'Positive training set is repeated {training_set_modifier} times due to data imbalance.')
    # model and optimizers
    criss = CRISSWrapper(device=configs.device)
    lexicon_inducer = LexiconInducer(7, configs.hiddens, 1, 5).to(configs.device)
    optimizer = torch.optim.Adam(lexicon_inducer.parameters(), lr=.0005)
    # train model
    for epoch in range(configs.epochs):
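        # if a checkpoint for this epoch already exists, load it and skip retraining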
        model_path = configs.save_path + f'/{epoch}.model.pt'
        if os.path.exists(model_path):
            lexicon_inducer.load_state_dict(torch.load(model_path))
            continue
        random.shuffle(training_set)
        bar = tqdm(range(0, len(training_set), configs.batch_size))
        total_loss = total_cnt = 0
        for i, sid in enumerate(bar):
            batch = training_set[sid:sid+configs.batch_size]
            probs = extract_probs(batch, criss, lexicon_inducer, info, configs)
            targets = torch.tensor(
                [1 if tuple(x) in all_train_lexicon else 0 for x in batch]).float().to(configs.device)
            optimizer.zero_grad()
            loss = nn.BCELoss()(probs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * len(batch)
            total_cnt += len(batch)
            bar.set_description(f'loss={total_loss / total_cnt:.5f}')
            if (i + 1) % logging_steps == 0:
                print(f'Epoch {epoch}, step {i+1}, loss = {total_loss / total_cnt:.5f}', flush=True)
                torch.save(lexicon_inducer.state_dict(), configs.save_path + f'/{epoch}.{i+1}.model.pt')
        print(f'Epoch {epoch}, loss = {total_loss / total_cnt:.5f}', flush=True)
    torch.save(lexicon_inducer.state_dict(), configs.save_path + '/model.pt')
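    # tune the decision threshold and number of candidates on the training sets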
    best_threshold, best_n_cand = get_optimal_parameters(
        pos_training_set, neg_training_set, train_lexicon, criss,
        lexicon_inducer, info, configs,
    )
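    # induce the test lexicon with the tuned parameters and evaluate it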
    induced_test_lexicon, test_eval = get_test_lexicon(
        test_set, test_lexicon, criss, lexicon_inducer, info, configs, best_threshold, best_n_cand
    )
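    # write the induced lexicon as tab-separated entries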
    with open(configs.save_path + '/induced.weaklysup.dict', 'w') as fout:
        for item in induced_test_lexicon:
            fout.write('\t'.join([str(x) for x in item]) + '\n')
    return induced_test_lexicon, test_eval
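
A minimal sketch of how train_test might be invoked, assuming the repo root is on PYTHONPATH. The exact config fields and their defaults are defined elsewhere in src/weakly_sup.py (setup_configs fills in whatever is missing); every path, size, and language code below is a placeholder, not a value from the repository.

from types import SimpleNamespace

from src.weakly_sup import train_test  # assumed module path; adjust to the actual layout

configs = SimpleNamespace(
    bitext_path='data/bitext.txt',       # parallel corpus (placeholder)
    align_path='data/align.txt',         # word alignments for the bitext (placeholder)
    tuning_set='data/tune.dict',         # seed lexicon used for weak supervision (placeholder)
    test_set='data/test.dict',           # gold lexicon used for evaluation (placeholder)
    save_path='checkpoints/weakly_sup',  # output directory (placeholder)
    src_lang='de', trg_lang='en',        # illustrative language pair
    reversed=False,
    device='cuda',
    hiddens=64,                          # hidden size(s) for LexiconInducer; exact format depends on its definition
    epochs=5,
    batch_size=128,
)

induced_lexicon, test_eval = train_test(configs)

The returned induced_lexicon is the list of induced test entries (also written to save_path/induced.weaklysup.dict), and test_eval holds the evaluation computed by get_test_lexicon.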