def train_tbsm()
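Top-level training entry point: optionally picks the best of five random seeds, builds or reloads the TBSM model, runs the main epoch loop over the training data, and writes summary metrics and profiler output when enabled.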

in tbsm_pytorch.py

# imports assumed by this snippet; set_seed, get_tbsm, and iterate_train_data
# are helpers defined elsewhere in tbsm_pytorch.py
from os import path

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

import tbsm_data_pytorch as tp  # the repo's data pipeline, referenced as tp

def train_tbsm(args, use_gpu):
    # prepare the data
    train_ld, _ = tp.make_tbsm_data_and_loader(args, "train")
    val_ld, _ = tp.make_tbsm_data_and_loader(args, "val")

    # setup initial values
    isMainTraining = False
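    # TensorBoard writer; with no log_dir argument, event files go under ./runs/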
    writer = SummaryWriter()
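    # zero-row arrays of shape (0, 3), grown with loss/accuracy triples as
    # training runs (the exact column layout is set by iterate_train_data)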
    losses = np.empty((0, 3), np.float32)
    accuracies = np.empty((0, 3), np.float32)

    # Selects the best seed out of 5. Adagrad sometimes gets stuck early;
    # this appears to happen at random, depending on the initial weight
    # values, and is independent of the chosen model (N-inner, dot, etc.).
    # This procedure reduces the probability of that happening.
    def select(args):

        seeds = np.random.randint(2, 10000, size=5)
        if args.debug_mode:
            print(seeds)
        best_index = 0
        max_val_accuracy = 0.0
        # testpoint marks roughly the first 5% of the training batches,
        # capped at the last batch index
        testpoint = min(int(0.05 * len(train_ld)), len(train_ld) - 1)
        print("testpoint, total batches:", testpoint, len(train_ld))

        for i, seed in enumerate(seeds):

            set_seed(seed, use_gpu)
            tbsm, device = get_tbsm(args, use_gpu)

            # probe run with isMainTraining=False; iterate_train_data returns
            # the validation accuracy (gA_test) reached with this seed
            gA_test = iterate_train_data(args, train_ld, val_ld, tbsm, 0, use_gpu,
                                         device, writer, losses, accuracies,
                                         isMainTraining)

            if args.debug_mode:
                print("select: ", i, seed, gA_test, max_val_accuracy)
            if gA_test > max_val_accuracy:
                best_index = i
                max_val_accuracy = gA_test

        return seeds[best_index]

    # select best seed if needed
    if args.no_select_seed or path.exists(args.save_model):
        seed = args.numpy_rand_seed
    else:
        print("Choosing best seed...")
        seed = select(args)
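    # set_seed is defined elsewhere in tbsm_pytorch.py; presumably it seeds
    # numpy and torch (and CUDA when use_gpu is set) for reproducibility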
    set_seed(seed, use_gpu)
    print("selected seed:", seed)

    # create or load TBSM
    tbsm, device = get_tbsm(args, use_gpu)
    if args.debug_mode:
        print("initial parameters (weights and bias):")
        for name, param in tbsm.named_parameters():
            print(name)
            print(param.detach().cpu().numpy())

    # main training loop
    isMainTraining = True
    print("time/loss/accuracy (if enabled):")
    # the profiler context is a no-op when enabled=False
    with torch.autograd.profiler.profile(
        enabled=args.enable_profiling, use_cuda=use_gpu
    ) as prof:
        for k in range(args.nepochs):
            iterate_train_data(args, train_ld, val_ld, tbsm, k, use_gpu, device,
                               writer, losses, accuracies, isMainTraining)

    # collect metrics and other statistics about the run
    if args.enable_summary:
        with open('summary.npy', 'wb') as acc_loss:
            np.save(acc_loss, losses)
            np.save(acc_loss, accuracies)
    # close the TensorBoard writer whether or not summaries were saved
    writer.close()

    # debug prints
    if args.debug_mode:
        print("final parameters (weights and bias):")
        for name, param in tbsm.named_parameters():
            print(name)
            print(param.detach().cpu().numpy())

    # profiling
    if args.enable_profiling:
        with open("tbsm_pytorch.prof", "w") as prof_f:
            prof_f.write(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="self_cpu_time_total"
                )
            )
        prof.export_chrome_trace("./tbsm_pytorch.json")

    return
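
For context, a minimal sketch of how train_tbsm might be driven from the command line. Every flag name below is an assumption inferred from the args.* attributes the function reads; the repo's actual parser also defines the data-pipeline options consumed by tp.make_tbsm_data_and_loader.

# hypothetical entry point; the flags are inferred from args.* usage above,
# not copied from the repo's real argument parser
import argparse

import torch


def main():
    parser = argparse.ArgumentParser(description="train TBSM (sketch)")
    parser.add_argument("--nepochs", type=int, default=1)
    parser.add_argument("--numpy-rand-seed", type=int, default=123)
    parser.add_argument("--no-select-seed", action="store_true")
    parser.add_argument("--save-model", type=str, default="")
    parser.add_argument("--use-gpu", action="store_true")
    parser.add_argument("--debug-mode", action="store_true")
    parser.add_argument("--enable-summary", action="store_true")
    parser.add_argument("--enable-profiling", action="store_true")
    args = parser.parse_args()

    # argparse maps --numpy-rand-seed to args.numpy_rand_seed, and so on;
    # fall back to CPU if a GPU was requested but CUDA is unavailable
    use_gpu = args.use_gpu and torch.cuda.is_available()
    train_tbsm(args, use_gpu)


if __name__ == "__main__":
    main()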