# collate_wrapper_tbsm() — extracted from tbsm_data_pytorch.py


def collate_wrapper_tbsm(list_of_tuples):
    """Collate a batch of (categorical, dense, target) samples into the
    per-timestep X, lS_o, lS_i, T inputs used by the dlrm-style network.

    Each tuple in *list_of_tuples* holds:
      - categorical feature ids, shape (num_cat_fea, ts_len), integers
      - dense features, shape (num_den_fea, ts_len), floats
      - a scalar target

    Returns:
      X:    list of ts_len dense tensors, each of shape (batch_size, 1)
      lS_o: list of ts_len lists (one entry per categorical feature) of
            offset tensors, each an arange over the batch
      lS_i: list of ts_len lists of index tensors, each (batch_size,)
      T:    target tensor of shape (batch_size, 1)
    """
    cats, dens, targets = zip(*list_of_tuples)
    all_cat = torch.tensor(cats, dtype=torch.long)   # (batch, num_cat_fea, ts_len)
    all_int = torch.tensor(dens, dtype=torch.float)  # (batch, num_den_fea, ts_len)

    batch_size, num_cat_fea, ts_len = all_cat.shape
    num_den_fea = all_int.shape[1]

    # Flatten the dense time series per sample.
    # NOTE(review): the X slices below read columns 0..ts_len-1 of this
    # flattened view, which covers all dense values only when
    # num_den_fea == 1 — confirm against callers.
    all_int = torch.reshape(all_int, (batch_size, num_den_fea * ts_len))

    # The offset tensor is identical for every feature and every timestep
    # (one row per batch sample), so build it once and share it.
    offsets = torch.arange(batch_size)

    X = []
    lS_i = []
    lS_o = []

    # Transform data into the per-timestep form used in the dlrm nn.
    for j in range(ts_len):
        lS_i.append([all_cat[:, i, j] for i in range(num_cat_fea)])
        lS_o.append([offsets] * num_cat_fea)
        X.append(all_int[:, j].view(-1, 1))

    T = torch.tensor(targets, dtype=torch.float32).view(-1, 1)

    return X, lS_o, lS_i, T