# train/compute/python/workloads/pytorch/split_table_batched_embeddings_ops.py
def get_data(self, config, device):
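    """Build the (indices, offsets, per_sample_weights) input tensors for the
    split-table-batched embedding op, either by loading pre-generated tensors
    from files or by generating per-table requests on the fly.
    """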
logger.debug(f"data generator config: {config}")
# batch size * pooling_factor
num_tables = config["args"][0]["value"]
if num_tables > 1:
rows = genericList_to_list(config["args"][1])
pooling_factors = genericList_to_list(config["args"][4])
else:
rows = [config["args"][1]["value"]]
pooling_factors = [config["args"][4]["value"]]
batch_size = config["args"][3]["value"]
weighted = config["args"][5]["value"]
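    # Per-table request buffers used when inputs are generated on the fly.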
    indices_list = []
    offsets_list = []
    per_sample_weights_list = []
    offset_start = 0
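    # Optional env var that overrides the distribution parameter passed to
    # generate_requests; defaults to 1 when unset.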
    distribution = os.getenv("split_embedding_distribution")
    if distribution is None:
        distribution = 1
    logger.debug(f"distribution = {distribution}")
    target_device = torch.device(device)
    indices_file = None
    offsets_file = None
    weights_file = None
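    # Pre-generated input tensors can be supplied either through the op config
    # or through environment variables.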
if ("indices_tensor" in config["args"][4]) and (
"offsets_tensor" in config["args"][4]
):
indices_file = config["args"][4]["indices_tensor"]
offsets_file = config["args"][4]["offsets_tensor"]
if weighted and "weights_tensor" in config["args"][4]:
weights_file = config["args"][4]["weights_tensor"]
else:
indices_file = os.getenv("split_embedding_indices")
offsets_file = os.getenv("split_embedding_offsets")
if weighted:
weights_file = os.getenv("split_embedding_weights")
logger.debug(f"indices_file: {indices_file}, offsets_file: {offsets_file}")
    if indices_file is not None and offsets_file is not None:
        indices_tensor = torch.load(indices_file, map_location=target_device)
        offsets_tensor = torch.load(offsets_file, map_location=target_device)
        per_sample_weights_tensor = None
        if weights_file:
            per_sample_weights_tensor = torch.load(
                weights_file, map_location=target_device
            )
    else:
        for i in range(num_tables):
            indices, offsets, per_sample_weights = generate_requests(
                batch_size,
                pooling_factors[i],
                rows[i],
                offset_start,
                float(distribution),
                weighted,
            )
            indices_list.append(indices)
            offsets_list.append(offsets)
            # Advance offset_start to the last element of the current offsets
            # so the next table's offsets continue where this one ended.
            offset_start = offsets[-1].item()
            if weighted:
                per_sample_weights_list.append(per_sample_weights)
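        # Concatenate the per-table requests into flat tensors covering all tables.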
        indices_tensor = torch.cat(indices_list)
        offsets_tensor = torch.cat(offsets_list)
        # Per-sample weights are only concatenated for weighted workloads.
        per_sample_weights_tensor = (
            torch.cat(per_sample_weights_list) if weighted else None
        )
logger.debug(f"indices: {indices_tensor.shape}")
logger.debug(f"offsets: {offsets_tensor.shape}")
if per_sample_weights_tensor is not None:
logger.debug(
f"per_sample_weights: {per_sample_weights_tensor.shape}, {per_sample_weights_tensor}"
)
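    # Return the positional args expected by the op (indices, offsets,
    # per_sample_weights) plus an empty kwargs dict.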
    return (
        [
            indices_tensor.to(target_device),
            offsets_tensor.to(target_device),
            # Guard on the tensor itself so a weighted config without a
            # weights file does not fail here with an AttributeError.
            per_sample_weights_tensor.to(target_device)
            if per_sample_weights_tensor is not None
            else None,
        ],
        {},
    )
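
# Illustrative usage sketch: a benchmark driver would typically call get_data
# and feed the returned positional args into a SplitTableBatchedEmbeddingBagsCodegen
# forward pass, roughly:
#
#     input_args, input_kwargs = data_generator.get_data(op_config, "cuda")
#     indices, offsets, per_sample_weights = input_args
#     out = emb_module(indices, offsets, per_sample_weights=per_sample_weights)
#
# `data_generator`, `op_config`, and `emb_module` are hypothetical names used
# only for illustration.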