in tbsm_data_pytorch.py [0:0]
def collate_wrapper_tbsm(list_of_tuples):
    """Collate dataset samples into per-time-step DLRM-style inputs.

    Args:
        list_of_tuples: batch of samples. Element 0 of each sample holds the
            categorical features indexed as [feature][time step], element 1
            the dense features indexed as [feature][time step], and element 2
            the target. All samples must share the same dimensions so they
            can be stacked into one tensor.

    Returns:
        A tuple ``(X, lS_o, lS_i, T)`` where, for every time step ``j``:

        * ``X[j]`` is a float tensor of shape (batch, 1) with the dense input,
        * ``lS_o[j]`` is a list with one int64 offsets tensor
          ``[0, 1, ..., batch - 1]`` per categorical feature,
        * ``lS_i[j]`` is a list with one int64 index tensor of shape (batch,)
          per categorical feature,

        and ``T`` is a float32 tensor of targets with a trailing dimension
        of size 1.
    """
    # Transpose [(cat, dense, target), ...] into (cats, denses, targets).
    data = list(zip(*list_of_tuples))
    all_cat = torch.tensor(data[0], dtype=torch.long)
    all_int = torch.tensor(data[1], dtype=torch.float)

    batch_size, num_cat_fea, ts_len = all_cat.shape[:3]
    num_den_fea = all_int.shape[1]

    # Flatten (feature, time) into one axis; column d * ts_len + t holds
    # dense feature d at time step t.
    # NOTE(review): the loop below only reads columns 0..ts_len-1, i.e. dense
    # feature 0. With more than one dense feature the others are ignored --
    # confirm this is intended before relying on num_den_fea > 1.
    all_int = torch.reshape(all_int, (batch_size, num_den_fea * ts_len))

    X = []
    lS_i = []
    lS_o = []
    # Transform data into the form used in the dlrm nn: one entry per time
    # step, each holding per-feature index/offset tensors.
    for j in range(ts_len):
        indices_j = [all_cat[:, i, j] for i in range(num_cat_fea)]
        # One lookup per sample, so the offsets are simply 0..batch-1. A fresh
        # tensor is created per feature so no two entries alias each other.
        offsets_j = [
            torch.arange(batch_size, dtype=torch.long) for _ in indices_j
        ]
        lS_i.append(indices_j)
        lS_o.append(offsets_j)
        X.append(all_int[:, j].view(-1, 1))

    T = torch.tensor(data[2], dtype=torch.float32).view(-1, 1)
    return X, lS_o, lS_i, T