# retrieval_train.py
import logging
import time

import torch
import torch.optim as optim

# Project-local helpers (TrainEnvironment, create_model, load_model,
# load_embeddings, save_model, train, validate) are assumed to be
# imported from elsewhere in this repo.
def train_model(opt_):
    env = TrainEnvironment(opt_)
    dictionary = env.dict
    if opt_.load_checkpoint:
        net, dictionary = load_model(opt_.load_checkpoint, opt_)
        env = TrainEnvironment(opt_, dictionary)
        env.dict = dictionary
    else:
        net = create_model(opt_, dictionary["words"])
        if opt_.embeddings and opt_.embeddings != "None":
            load_embeddings(opt_, dictionary["words"], net)
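    # Report how many parameters the model holds and how many of them
    # will actually receive gradient updates.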
    paramnum = 0
    trainable = 0
    for name, parameter in net.named_parameters():
        if parameter.requires_grad:
            trainable += parameter.numel()
        paramnum += parameter.numel()
    print("TRAINABLE", paramnum, trainable)
    if opt_.cuda:
        net = torch.nn.DataParallel(net)
        net = net.cuda()
    if opt_.optimizer == "adamax":
        lr = opt_.learning_rate or 0.002
        named_params_to_optimize = filter(
            lambda p: p[1].requires_grad, net.named_parameters()
        )
        params_to_optimize = (p[1] for p in named_params_to_optimize)
        optimizer = optim.Adamax(params_to_optimize, lr=lr)
        if opt_.epoch_start != 0:
            # Resuming mid-run: restore the optimizer state saved alongside
            # the model. This branch only exists for Adamax; SGD restarts
            # with fresh optimizer state.
            saved_params = torch.load(
                opt_.load_checkpoint, map_location=lambda storage, loc: storage
            )
            optimizer.load_state_dict(saved_params["optim_dict"])
    else:
        lr = opt_.learning_rate or 0.01
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, net.parameters()), lr=lr
        )
    start_time = time.time()
    best_loss = float("+inf")
    # Initialize best_loss_epoch up front so the early-stopping check below
    # cannot hit an unbound variable if the loss never improves (e.g. if
    # validation returns NaN).
    best_loss_epoch = opt_.epoch_start
    test_data_shuffled = env.build_valid_dataloader(True)
    test_data_not_shuffled = env.build_valid_dataloader(False)
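    # Validate once before any training so later epochs can be compared
    # against an untrained baseline.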
    with torch.no_grad():
        validate(
            0,
            net,
            test_data_shuffled,
            nb_candidates=opt_.hits_at_nb_cands,
            shuffled_str="shuffled",
        )
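    # Main loop: train one epoch at a time, then validate and checkpoint
    # whenever the monitored loss improves.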
    train_data = None
    for epoch in range(opt_.epoch_start, opt_.num_epochs):
        if train_data is None or opt_.dataset_name == "reddit":
            # The Reddit dataloader is rebuilt every epoch (its builder takes
            # the epoch number, presumably to read a different shard each
            # time); other datasets reuse a single dataloader.
            train_data = env.build_train_dataloader(epoch)
        train(epoch, start_time, net, optimizer, opt_, train_data)
        with torch.no_grad():
            # We compute the loss for both the shuffled and not-shuffled
            # cases; however, the loss that decides whether the model has
            # improved is the same one used for training.
            loss_shuffled = validate(
                epoch,
                net,
                test_data_shuffled,
                nb_candidates=opt_.hits_at_nb_cands,
                shuffled_str="shuffled",
            )
            loss_not_shuffled = validate(
                epoch,
                net,
                test_data_not_shuffled,
                nb_candidates=opt_.hits_at_nb_cands,
                shuffled_str="not-shuffled",
            )
            if opt_.no_shuffle:
                loss = loss_not_shuffled
            else:
                loss = loss_shuffled
            if loss < best_loss:
                best_loss = loss
                best_loss_epoch = epoch
                logging.info(f"New best loss, saving model to {opt_.model_file}")
                save_model(opt_.model_file, net, dictionary, optimizer)
        # Stop if it's been too many epochs since the loss has decreased.
        if opt_.stop_crit_num_epochs != -1:
            if epoch - best_loss_epoch >= opt_.stop_crit_num_epochs:
                break
    return net, dictionary
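

# A minimal driver sketch (not part of the original file) showing how
# train_model() consumes an options namespace. The flag names below simply
# mirror the attributes read above (load_checkpoint, embeddings, cuda,
# optimizer, learning_rate, epoch_start, num_epochs, hits_at_nb_cands,
# no_shuffle, stop_crit_num_epochs, model_file, dataset_name); the real
# repo defines its own argument parser, so treat these flags and defaults
# as assumptions for illustration only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--load-checkpoint", default=None)
    parser.add_argument("--embeddings", default=None)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--optimizer", default="adamax")
    parser.add_argument("--learning-rate", type=float, default=None)
    parser.add_argument("--epoch-start", type=int, default=0)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--hits-at-nb-cands", type=int, default=100)
    parser.add_argument("--no-shuffle", action="store_true")
    parser.add_argument("--stop-crit-num-epochs", type=int, default=-1)
    parser.add_argument("--model-file", default="model.pt")
    parser.add_argument("--dataset-name", default="empchat")
    opt = parser.parse_args()
    train_model(opt)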