def train_qa_retriever_joint_epoch()

in longform-qa/lfqa_utils.py [0:0]

Runs one training epoch of the QA retriever jointly over several datasets: at every joint step, one batch is drawn from each dataset's DataLoader, a separate gradient update is performed per batch, and a running loss average is printed every args.print_freq steps.
# imports required by this function
import functools
from time import time

from torch.utils.data import DataLoader, RandomSampler


def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0):
    model.train()
    # batches are tokenized and moved to the GPU by make_qa_retriever_batch, defined elsewhere in lfqa_utils.py
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    # build one shuffled DataLoader per dataset and zip them into a joint iterator
    train_samplers = [RandomSampler(dataset) for dataset in dataset_list]
    data_loaders = [
        DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
        for dataset, train_sampler in zip(dataset_list, train_samplers)
    ]
    iterators = [iter(dloader) for dloader in data_loaders]
    # zip stops at the shortest loader, so the joint epoch length is set by the smallest dataset
    joint_iter = zip(*iterators)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batches in enumerate(joint_iter):
        # one batch drawn from every dataset at each joint step
        for batch in batches:
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
            # gradient step: backpropagate, update parameters and learning-rate schedule, then reset gradients
            loss.backward()
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            # accumulate loss for the periodic progress print below
            loc_loss += loss.item()
            loc_steps += 1
        # every print_freq joint steps, report epoch, progress, average loss since last print, and elapsed time
        if step % args.print_freq == 0:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0
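
A minimal driver sketch for this function, using only the argument names it actually reads (max_length, batch_size, checkpoint_batch_size, print_freq). The retriever model, tokenizer, and dataset_list are placeholders for whatever the surrounding training script builds with the other helpers in lfqa_utils.py, and the learning rate, warmup, and epoch counts are illustrative assumptions rather than the project's settings:

import argparse

import torch
from transformers import get_linear_schedule_with_warmup

# hypothetical wiring, not part of lfqa_utils.py: `model`, `tokenizer`, and `dataset_list`
# come from the rest of the script (an embedder whose forward takes
# q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size and returns a scalar loss,
# plus the QA datasets the collate function expects)
args = argparse.Namespace(
    max_length=128,
    batch_size=8,
    checkpoint_batch_size=4,
    print_freq=100,
    learning_rate=2e-4,
    num_epochs=4,
)

model = model.to("cuda:0")
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
# the joint iterator stops at the smallest dataset, so size the schedule accordingly
steps_per_epoch = min(len(d) for d in dataset_list) // args.batch_size
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=steps_per_epoch * args.num_epochs
)

for epoch in range(args.num_epochs):
    train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=epoch)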