in tbsm_data_pytorch.py [0:0]
def truncate_and_save(self, out_file, do_shuffle, t, users, items, cats, times, y):
    """Trim unused pre-allocated rows, optionally shuffle, and save as ``.npz``.

    The input arrays may have been pre-allocated with ``self.total_out`` rows
    of which only the first ``t`` were filled (for example when some users had
    too short a history to generate samples); the unused tail is dropped.

    Args:
        out_file: Destination path (str or path-like). As with
            ``np.savez_compressed``, ``".npz"`` is appended when missing.
        do_shuffle: If true, one random row permutation (drawn from the global
            ``np.random`` state) is applied to every array so rows stay aligned.
        t: Number of leading rows that were actually filled.
        users, items, cats, times: 2-D arrays indexed ``[row, step]``; each is
            copied into a block of width ``self.ts_length + 1`` — presumably
            shaped ``(self.total_out, self.ts_length + 1)``; TODO confirm.
        y: Per-row values (presumably labels) indexed by row along axis 0;
            saved unchanged apart from trimming/shuffling.

    Output file arrays:
        X_cat: int32, shape ``(3, N, self.ts_length + 1)`` — users, items, cats.
        X_int: float64, shape ``(1, N, self.ts_length + 1)`` — times.
        y: the trimmed (and possibly shuffled) ``y``.

    An existing output file is never overwritten; in that case nothing is
    written (a requested shuffle still consumes random state beforehand).
    """
    import os  # local import: only needed to normalise a path-like out_file

    # Truncate: if fewer than total_out rows were generated (e.g. some users
    # had too short a history), drop the unused tail of the pre-allocated arrays.
    if t < self.total_out:
        users = users[:t, :]
        items = items[:t, :]
        cats = cats[:t, :]
        times = times[:t, :]
        y = y[:t]

    if do_shuffle:
        # A single permutation shared by all arrays keeps rows aligned.
        indices = np.random.permutation(len(y))
        users = users[indices]
        items = items[indices]
        cats = cats[indices]
        times = times[indices]
        y = y[indices]

    N = len(y)
    width = self.ts_length + 1
    X_cat = np.zeros((3, N, width), dtype="i4")  # 4-byte int
    # np.float was an alias of the builtin float (float64) and was removed in
    # NumPy 1.24, so spell the dtype explicitly.
    X_int = np.zeros((1, N, width), dtype=np.float64)
    X_cat[0, :, :] = users
    X_cat[1, :, :] = items
    X_cat[2, :, :] = cats
    X_int[0, :, :] = times

    # np.savez_compressed appends ".npz" to names lacking it, so test the path
    # that will actually be written; otherwise the "don't overwrite" guard
    # never triggers for suffix-less names.
    target = os.fspath(out_file)
    if not target.endswith(".npz"):
        target += ".npz"
    if not path.exists(target):
        np.savez_compressed(
            target,
            X_cat=X_cat,
            X_int=X_int,
            y=y,
        )