def make_tbsm_data_and_loader()

in tbsm_data_pytorch.py [0:0]


def make_tbsm_data_and_loader(args, mode):
    """Build the TBSM dataset for one data split and wrap it in a DataLoader.

    Args:
        args: Run configuration object. Only the attributes needed for the
            selected split are read (e.g. ``raw_train_file`` /
            ``pro_train_file`` / ``num_train_pts`` / ``mini_batch_size`` for
            ``"train"``), plus the shared ``datatype``, ``ts_length``,
            ``points_per_user``, ``numpy_rand_seed`` and
            ``arch_embedding_size`` forwarded to ``TBSMDataset``.
        mode: ``"train"`` or ``"val"``; any other value selects the test
            split.

    Returns:
        A ``(loader, num_samples)`` tuple, where ``num_samples`` is
        ``len(dataset)``.
    """

    def _split_settings():
        # Returns (raw file, processed file, num points, batch size, shuffle).
        # Attributes are read only inside the matching branch, so splits the
        # caller never configured are never touched.
        if mode == "train":
            return (
                args.raw_train_file,
                args.pro_train_file,
                args.num_train_pts,
                args.mini_batch_size,
                True,
            )
        if mode == "val":
            # NOTE(review): validation reuses the *training* raw file and only
            # the processed file differs -- presumably the val points are
            # carved out of the training data; confirm against TBSMDataset.
            return (
                args.raw_train_file,
                args.pro_val_file,
                args.num_val_pts,
                25000,
                True,
            )
        # Fallback: everything that is not "train"/"val" is treated as test.
        # NOTE(review): num points is fixed at 1 here; presumably TBSMDataset
        # ignores it in test mode -- TODO confirm.
        return (
            args.raw_test_file,
            args.pro_test_file,
            1,
            25000,
            False,
        )

    raw_file, processed_file, num_points, batch_size, shuffle = _split_settings()

    dataset = TBSMDataset(
        args.datatype,
        mode,
        args.ts_length,
        args.points_per_user,
        args.numpy_rand_seed,
        raw_file,
        processed_file,
        args.arch_embedding_size,
        num_points,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=0,  # data is loaded in the main process
        collate_fn=collate_wrapper_tbsm,
        shuffle=shuffle,
    )

    return data_loader, len(dataset)