in src/weakly_sup.py [0:0]
import os
import random

import torch
from torch import nn
from tqdm import tqdm

# Project helpers used below (setup_configs, collect_bitext_stats, load_lexicon,
# to_simplified, extract_dataset, CRISSWrapper, LexiconInducer, extract_probs,
# get_optimal_parameters, get_test_lexicon) are assumed to be defined or
# imported earlier in this file.


def train_test(configs, logging_steps=50000):
    setup_configs(configs)
    os.system(f'mkdir -p {configs.save_path}')
    torch.save(configs, configs.save_path + '/configs.pt')
    # prepare feature extractor: alignment/co-occurrence statistics collected from the bitext
    info = collect_bitext_stats(
        configs.bitext_path, configs.align_path, configs.save_path,
        configs.src_lang, configs.trg_lang, configs.reversed)
    # dataset: seed (tuning) lexicon plus a simplified-form copy of each entry,
    # so either variant of a pair counts as a positive label
    train_lexicon = load_lexicon(configs.tuning_set)
    sim_train_lexicon = {(to_simplified(x[0]), to_simplified(x[1])) for x in train_lexicon}
    all_train_lexicon = train_lexicon.union(sim_train_lexicon)
    test_lexicon = load_lexicon(configs.test_set)
    pos_training_set, neg_training_set, test_set = extract_dataset(
        train_lexicon, test_lexicon, info[2], configs
    )
    # oversample the positive pairs so the two classes are roughly balanced
    training_set_modifier = max(1, len(neg_training_set) // len(pos_training_set))
    training_set = pos_training_set * training_set_modifier + neg_training_set
    print(f'Positive training set is repeated {training_set_modifier} times due to data imbalance.')
    # model and optimizers
    criss = CRISSWrapper(device=configs.device)  # wrapper around the pretrained CRISS model
    lexicon_inducer = LexiconInducer(7, configs.hiddens, 1, 5).to(configs.device)
    optimizer = torch.optim.Adam(lexicon_inducer.parameters(), lr=.0005)
    # train model
    for epoch in range(configs.epochs):
        model_path = configs.save_path + f'/{epoch}.model.pt'
        if os.path.exists(model_path):
            # a checkpoint for this epoch already exists: load it and skip re-training
            lexicon_inducer.load_state_dict(torch.load(model_path))
            continue
        random.shuffle(training_set)
        bar = tqdm(range(0, len(training_set), configs.batch_size))
        total_loss = total_cnt = 0
        for i, sid in enumerate(bar):
            batch = training_set[sid:sid + configs.batch_size]
            probs = extract_probs(batch, criss, lexicon_inducer, info, configs)
            # a pair is labeled positive iff it appears in the (augmented) training lexicon
            targets = torch.tensor(
                [1 if tuple(x) in all_train_lexicon else 0 for x in batch]
            ).float().to(configs.device)
            optimizer.zero_grad()
            loss = nn.BCELoss()(probs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * len(batch)
            total_cnt += len(batch)
            bar.set_description(f'loss={total_loss / total_cnt:.5f}')
            if (i + 1) % logging_steps == 0:
                print(f'Epoch {epoch}, step {i + 1}, loss = {total_loss / total_cnt:.5f}', flush=True)
                torch.save(lexicon_inducer.state_dict(), configs.save_path + f'/{epoch}.{i + 1}.model.pt')
        print(f'Epoch {epoch}, loss = {total_loss / total_cnt:.5f}', flush=True)
        # save the per-epoch checkpoint (checked above when resuming) and the latest model
        torch.save(lexicon_inducer.state_dict(), model_path)
        torch.save(lexicon_inducer.state_dict(), configs.save_path + '/model.pt')
    # tune the decision threshold and number of candidates on the training data
    best_threshold, best_n_cand = get_optimal_parameters(
        pos_training_set, neg_training_set, train_lexicon, criss,
        lexicon_inducer, info, configs,
    )
    # induce the test lexicon with the tuned parameters and evaluate it
    induced_test_lexicon, test_eval = get_test_lexicon(
        test_set, test_lexicon, criss, lexicon_inducer, info, configs,
        best_threshold, best_n_cand,
    )
    # write the induced lexicon as tab-separated entries
    with open(configs.save_path + '/induced.weaklysup.dict', 'w') as fout:
        for item in induced_test_lexicon:
            fout.write('\t'.join([str(x) for x in item]) + '\n')
    return induced_test_lexicon, test_eval
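For reference, a minimal sketch of how this entry point might be invoked. The field names match what train_test reads above, but every value below is an illustrative placeholder (the real script presumably builds configs from its own argument parser, and setup_configs may fill in further fields).

# Hypothetical usage sketch -- paths and hyperparameters are placeholders,
# not the repository's actual defaults.
from types import SimpleNamespace

configs = SimpleNamespace(
    bitext_path='data/bitext.zh-en',        # parallel corpus
    align_path='data/bitext.zh-en.align',   # word alignments for the bitext
    tuning_set='data/tune.zh-en.dict',      # seed lexicon used for training
    test_set='data/test.zh-en.dict',        # held-out lexicon for evaluation
    save_path='checkpoints/zh-en',
    src_lang='zh', trg_lang='en', reversed=False,
    device='cuda',
    hiddens=[8, 8],                         # hidden sizes for LexiconInducer (placeholder)
    epochs=5, batch_size=128,
)

induced_lexicon, scores = train_test(configs)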