def iterate_train_data()

in tbsm_pytorch.py [0:0]
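
Runs the training loop for epoch k: per batch, it performs the forward pass, loss computation, backward pass and an Adagrad step; it periodically prints timing, loss and accuracy, runs validation, and checkpoints the model whenever validation accuracy improves. With isMainTraining set to False it instead does a short preliminary run over roughly 5% of the data and returns the resulting validation accuracy.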


def iterate_train_data(args, train_ld, val_ld, tbsm, k, use_gpu, device, writer, losses, accuracies, isMainTraining):
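    # isMainTraining selects between the full training loop (progress logging,
    # periodic validation, best-model checkpointing) and a short preliminary
    # run that validates once and returns the validation accuracy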
    # select number of batches
    if isMainTraining:
        nbatches = len(train_ld) if args.num_batches == 0 else args.num_batches
    else:
        nbatches = len(train_ld)

    # specify the optimizer algorithm
    optimizer = torch.optim.Adagrad(tbsm.parameters(), lr=args.learning_rate)
    total_time = 0
    total_loss = 0
    total_accu = 0
    total_iter = 0
    total_samp = 0
    max_gA_test = 0  # best validation accuracy seen so far

    for j, (X, lS_o, lS_i, T) in enumerate(train_ld):
        if j >= nbatches:
            break
        t1 = time_wrap(use_gpu, device)
        batchSize = X[0].shape[0]
        # forward pass
        Z = tbsm(*data_wrap(X, lS_o, lS_i, use_gpu, device))

        # loss
        E = loss_fn_wrap(Z, T, use_gpu, device)

        # move loss, predictions and targets to host memory as numpy arrays
        L = E.detach().cpu().numpy()
        z = Z.detach().cpu().numpy()
        t = T.detach().cpu().numpy()
        # count correct predictions: round both sides to the nearest integer
        A = np.sum((np.round(z, 0) == np.round(t, 0)).astype(np.uint8))

        optimizer.zero_grad()

        # backward pass (retain_graph=True keeps the autograd graph alive,
        # though nothing reuses it after this step)
        E.backward(retain_graph=True)

        # weights update
        optimizer.step()

        t2 = time_wrap(use_gpu, device)
        total_time += t2 - t1
        total_loss += (L * batchSize)
        total_accu += A
        total_iter += 1
        total_samp += batchSize

        print_tl = ((j + 1) % args.print_freq == 0) or (j + 1 == nbatches)
        # print time, loss and accuracy
        if print_tl and isMainTraining:
            gT = 1000.0 * total_time / total_iter if args.print_time else -1
            total_time = 0

            gL = total_loss / total_samp
            total_loss = 0

            gA = total_accu / total_samp
            total_accu = 0

            str_run_type = "inference" if args.inference_only else "training"
            print(
                "Finished {} it {}/{} of epoch {}, ".format(
                    str_run_type, j + 1, nbatches, k
                )
                + "{:.2f} ms/it, loss {:.8f}, accuracy {:3.3f} %".format(
                    gT, gL, gA * 100
                )
            )
            total_iter = 0
            total_samp = 0

        if isMainTraining:
            should_test = (
                (args.test_freq > 0 and (j + 1) % args.test_freq == 0)
                or j + 1 == nbatches
            )
        else:
            # preliminary run: validate once, after roughly 5% of the data
            should_test = (j == min(int(0.05 * len(train_ld)), len(train_ld) - 1))

        # validation run
        if should_test:
            total_accu_test, total_samp_test, total_loss_val = iterate_val_data(
                val_ld, tbsm, use_gpu, device)

            gA_test = total_accu_test / total_samp_test
            # the preliminary run only needs the validation accuracy
            if not isMainTraining:
                break

            gL_test = total_loss_val / total_samp_test

            print("At epoch {:d} validation accuracy is {:3.3f} %".
                format(k, gA_test * 100))

            if args.enable_summary and isMainTraining:
                # gL and gA hold the averages from the most recent print
                # interval above
                writer.add_scalars(
                    'train and val loss',
                    {'train_loss': gL, 'val_loss': gL_test},
                    k * len(train_ld) + j,
                )
                writer.add_scalars(
                    'train and val accuracy',
                    {'train_acc': gA * 100, 'val_acc': gA_test * 100},
                    k * len(train_ld) + j,
                )
                losses = np.append(losses, np.array([[j, gL, gL_test]]), axis=0)
                accuracies = np.append(
                    accuracies, np.array([[j, gA * 100, gA_test * 100]]), axis=0
                )

            # save model if best so far
            if gA_test > max_gA_test and isMainTraining:
                print("Saving current model...")
                max_gA_test = gA_test
                model_ = tbsm
                torch.save(
                    {
                        "model_state_dict": model_.state_dict(),
                        # "opt_state_dict": optimizer.state_dict(),
                    },
                    args.save_model,
                )
    if not isMainTraining:
        # preliminary run: report the validation accuracy measured above
        return gA_test
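
The loop relies on helpers defined elsewhere in tbsm_pytorch.py: time_wrap, data_wrap, loss_fn_wrap and iterate_val_data. For orientation, a minimal sketch of what the timing and device-transfer wrappers might look like, inferred from their call sites above; the actual implementations in the repository may differ:

import time
import torch

def time_wrap(use_gpu, device):
    # include pending GPU work in the timing by synchronizing first
    if use_gpu:
        torch.cuda.synchronize(device)
    return time.time()

def data_wrap(X, lS_o, lS_i, use_gpu, device):
    # move dense inputs and sparse offsets/indices to the target device
    if use_gpu:
        return ([x.to(device) for x in X],
                [S_o.to(device) for S_o in lS_o],
                [S_i.to(device) for S_i in lS_i])
    return X, lS_o, lS_i

And a sketch of how the function might be driven across epochs; the setup names (train_ld, val_ld, writer, args fields) are assumptions based on the signature, not the repository's actual driver code:

import numpy as np

losses = np.empty((0, 3))      # columns: iteration, train loss, val loss
accuracies = np.empty((0, 3))  # columns: iteration, train acc, val acc
for k in range(args.nepochs):  # args.nepochs is assumed here
    iterate_train_data(args, train_ld, val_ld, tbsm, k, use_gpu, device,
                       writer, losses, accuracies, isMainTraining=True)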