def train()

in benchmarks/experimental/offload.py
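
Benchmarks a few training steps of a language model: draws batches from the configured dataloader, runs forward, loss, backward, gradient clipping and an optimizer step (optionally under a profiler), and returns words-per-second throughput together with the final batch loss.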


import math
import time

import torch

# log_number_of_parameters, _get_profiler_context and
# _get_profiler_record_context are helpers defined elsewhere in this file.


def train(model_config, model, benchmark_config, model_specs, args):
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    # model_config["optimizer"] holds a callable factory; instantiate it
    # with the model parameters.
    optimizer = optimizer(model.parameters())

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
        # Shift the batch by one token so the model predicts token t + 1
        # from tokens up to t.
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        # TODO(anj): Make this a flag for both "lm" and "seq" models.
        if i == 5:
            break

        if i == 1:
            # Start the epoch timer at the second batch so the first
            # (warm-up) batch is excluded from throughput measurement.
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        source, target = source.cuda(), target.cuda()

        if i > 0:
            # Token accounting likewise skips the warm-up batch.
            total_tokens += source.numel()

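        # One measured step: forward, loss, backward, gradient clipping and
        # optimizer update, each phase wrapped in its own profiler record so
        # it shows up as a named range in the exported trace.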
        with _get_profiler_context(args.use_profiler) as prof:
            optimizer.zero_grad()
            with _get_profiler_record_context("FW pass", args.use_profiler):
                output = model(source)
            with _get_profiler_record_context("Loss", args.use_profiler):
                loss = criterion(output.view(-1, vocab_size), target.view(-1))
            with _get_profiler_record_context("BW pass", args.use_profiler):
                loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
            with _get_profiler_record_context("Opt step", args.use_profiler):
                optimizer.step()

        total_loss += loss.item()
        log_interval = 1
        total_tokens_per_log_interval += source.numel()
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print(
                "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                    i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                )
            )
            total_tokens_per_log_interval = 0
            total_loss = 0
            start_time = time.time()
        if args.use_profiler:
            prof.export_chrome_trace("/tmp/offload_prof")

    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
        )
    return wps, loss.item()
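
The profiler context helpers used above are defined elsewhere in
benchmarks/experimental/offload.py and are not shown here. A minimal sketch of
what they might look like, assuming the torch.autograd.profiler API (the real
implementations may differ):


import contextlib

import torch


def _get_profiler_context(use_profiler=False):
    # Hand back a no-op context when profiling is off, so the training loop
    # can use a single `with` statement either way.
    if use_profiler:
        return torch.autograd.profiler.profile(use_cuda=True)
    return contextlib.nullcontext()


def _get_profiler_record_context(name, use_profiler=False):
    # Label one phase of the step ("FW pass", "Loss", "BW pass", "Opt step")
    # so it appears as a named range in the exported Chrome trace.
    if use_profiler:
        return torch.autograd.profiler.record_function(name)
    return contextlib.nullcontext()


With this shape, prof is a real profiler object only when args.use_profiler is
set, which matches the guard around prof.export_chrome_trace in the loop.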