in tbsm_data_pytorch.py [0:0]
def make_tbsm_data_and_loader(args, mode, *, eval_batch_size=25000):
    """Build a TBSMDataset for ``mode`` and wrap it in a DataLoader.

    Args:
        args: parsed command-line namespace. This function reads
            ``raw_train_file``, ``raw_test_file``, ``pro_train_file``,
            ``pro_val_file``, ``pro_test_file``, ``num_train_pts``,
            ``num_val_pts``, ``mini_batch_size``, ``datatype``,
            ``ts_length``, ``points_per_user``, ``numpy_rand_seed`` and
            ``arch_embedding_size``.
        mode: ``"train"`` or ``"val"``; any other value is treated as the
            test split.
        eval_batch_size: batch size used for the ``"val"`` and test loaders
            (the training loader uses ``args.mini_batch_size``). Defaults to
            25000, the value previously hard-coded here.

    Returns:
        A ``(loader, num_samples)`` tuple, where ``loader`` is a
        ``torch.utils.data.DataLoader`` over the dataset and ``num_samples``
        is ``len(dataset)``.
    """
    if mode == "train":
        raw = args.raw_train_file
        proc = args.pro_train_file
        numpts = args.num_train_pts
        batchsize = args.mini_batch_size
        doshuffle = True
    elif mode == "val":
        # The validation set is built from the *training* raw file; only the
        # processed-output path differs. Presumably TBSMDataset carves the val
        # split out of the raw training data -- NOTE(review): confirm there.
        raw = args.raw_train_file
        proc = args.pro_val_file
        numpts = args.num_val_pts
        batchsize = eval_batch_size
        doshuffle = True
    else:
        # Fallback: every mode other than "train"/"val" uses the test files.
        raw = args.raw_test_file
        proc = args.pro_test_file
        # NOTE(review): how numpts=1 is interpreted for the test split is
        # decided inside TBSMDataset -- confirm before changing.
        numpts = 1
        batchsize = eval_batch_size
        # Iterate test samples in dataset order (no shuffling).
        doshuffle = False

    data = TBSMDataset(
        args.datatype,
        mode,
        args.ts_length,
        args.points_per_user,
        args.numpy_rand_seed,
        raw,
        proc,
        args.arch_embedding_size,
        numpts,
    )

    # num_workers=0: batches are loaded in the main process.
    loader = torch.utils.data.DataLoader(
        data,
        batch_size=batchsize,
        num_workers=0,
        collate_fn=collate_wrapper_tbsm,
        shuffle=doshuffle,
    )
    return loader, len(data)