in torchrec/datasets/random.py [0:0]
def _generate_batch(self) -> Batch:
    """Build one random ``Batch`` of sparse ids, dense features and labels.

    Sparse ids are drawn with ``torch.randint`` below ``self.hash_size``
    when ``self.hash_sizes`` is ``None``; otherwise uniform ``[0, 1)``
    samples are multiplied by ``self.max_values`` and cast to a 64-bit
    integer tensor. Offsets are evenly spaced every
    ``self.ids_per_feature`` ids. Every random draw uses
    ``self.generator``.

    Returns:
        A ``Batch`` holding ``(batch_size, num_dense)`` normally
        distributed dense features, the ``KeyedJaggedTensor`` of sparse
        features, and ``batch_size`` labels drawn from ``{0, 1}``.
    """
    if self.hash_sizes is None:
        # pyre-ignore[28]
        values = torch.randint(
            high=self.hash_size,
            size=(self._num_ids_in_batch,),
            generator=self.generator,
        )
    else:
        # Scale uniform [0, 1) samples by ``max_values`` and truncate to ints.
        # NOTE(review): with the default float dtype, very large bounds may
        # presumably round up to the bound itself -- confirm whether ids
        # must stay strictly below ``max_values``.
        values = (
            torch.rand(
                self._num_ids_in_batch,
                generator=self.generator,
            )
            * none_throws(self.max_values)
        ).type(torch.LongTensor)

    # Offsets 0, k, 2k, ... up to ``_num_ids_in_batch`` (k = ids_per_feature),
    # built directly as a tensor instead of via an intermediate Python list.
    offsets = torch.arange(
        0,
        self._num_ids_in_batch + 1,
        self.ids_per_feature,
        dtype=torch.int32,
    )
    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=self.keys,
        values=values,
        offsets=offsets,
    )

    # Draw order (values, dense, labels) is kept so that a seeded generator
    # produces the same stream as before.
    dense_features = torch.randn(
        self.batch_size,
        self.num_dense,
        generator=self.generator,
    )
    # pyre-ignore[28]
    labels = torch.randint(
        low=0,
        high=2,
        size=(self.batch_size,),
        generator=self.generator,
    )
    return Batch(
        dense_features=dense_features,
        sparse_features=sparse_features,
        labels=labels,
    )