# Excerpt: get_data()
# from train/compute/python/workloads/pytorch/split_table_batched_embeddings_ops.py

    def get_data(self, config, device):
        """Build the input tensors for a split-table-batched embedding lookup.

        Produces ``[indices, offsets, per_sample_weights]`` either by loading
        pre-saved tensors (paths taken from the config, falling back to
        environment variables) or by synthesizing per-table requests via
        ``generate_requests``.

        Args:
            config: data-generator config dict. Positional layout of
                ``config["args"]``: [0]=num_tables, [1]=rows, [3]=batch_size,
                [4]=pooling factors (and, optionally, saved tensor paths under
                ``indices_tensor``/``offsets_tensor``/``weights_tensor``),
                [5]=weighted flag.
            device: device string/spec understood by ``torch.device``.

        Returns:
            Tuple ``(args, kwargs)`` where ``args`` is
            ``[indices_tensor, offsets_tensor, per_sample_weights_or_None]``,
            all moved to the target device, and ``kwargs`` is an empty dict.
        """
        logger.debug(f"data generator config: {config}")
        # batch size * pooling_factor
        num_tables = config["args"][0]["value"]
        if num_tables > 1:
            # Multiple tables: rows and pooling factors come as generic lists.
            rows = genericList_to_list(config["args"][1])
            pooling_factors = genericList_to_list(config["args"][4])
        else:
            rows = [config["args"][1]["value"]]
            pooling_factors = [config["args"][4]["value"]]
        batch_size = config["args"][3]["value"]
        weighted = config["args"][5]["value"]

        indices_list = []
        offsets_list = []
        per_sample_weights_list = []
        offset_start = 0
        # Optional env override for the synthetic-indices distribution;
        # defaults to 1 (semantics defined by generate_requests — not visible
        # here; presumably a distribution parameter such as alpha).
        distribution = os.getenv("split_embedding_distribution")
        if distribution is None:
            distribution = 1
        logger.debug(f"distribution = {distribution}")

        target_device = torch.device(device)

        # Resolve tensor file paths: prefer paths embedded in the config,
        # fall back to environment variables.
        indices_file = None
        offsets_file = None
        weights_file = None
        if ("indices_tensor" in config["args"][4]) and (
            "offsets_tensor" in config["args"][4]
        ):
            indices_file = config["args"][4]["indices_tensor"]
            offsets_file = config["args"][4]["offsets_tensor"]
            if weighted and "weights_tensor" in config["args"][4]:
                weights_file = config["args"][4]["weights_tensor"]
        else:
            indices_file = os.getenv("split_embedding_indices")
            offsets_file = os.getenv("split_embedding_offsets")
            if weighted:
                weights_file = os.getenv("split_embedding_weights")

        logger.debug(f"indices_file: {indices_file}, offsets_file: {offsets_file}")
        if indices_file is not None and offsets_file is not None:
            # Load pre-saved tensors directly onto the target device.
            indices_tensor = torch.load(indices_file, map_location=target_device)
            offsets_tensor = torch.load(offsets_file, map_location=target_device)
            # May legitimately stay None when weighted but no weights file
            # was provided; the return below must tolerate that.
            per_sample_weights_tensor = None
            if weights_file:
                per_sample_weights_tensor = torch.load(
                    weights_file, map_location=target_device
                )
        else:
            # Synthesize one request per table, chaining offsets so each
            # table's offsets continue where the previous table ended.
            for i in range(num_tables):
                indices, offsets, per_sample_weights = generate_requests(
                    batch_size,
                    pooling_factors[i],
                    rows[i],
                    offset_start,
                    float(distribution),
                    weighted,
                )
                indices_list.append(indices)
                offsets_list.append(offsets)
                # update to the offset_start to the last element of current offset
                offset_start = offsets[-1].item()
                if weighted:
                    per_sample_weights_list.append(per_sample_weights)

            indices_tensor = torch.cat(indices_list)
            offsets_tensor = torch.cat(offsets_list)

            # check for per sample weights
            per_sample_weights_tensor = (
                torch.cat(per_sample_weights_list) if weighted else None
            )

        logger.debug(f"indices: {indices_tensor.shape}")
        logger.debug(f"offsets: {offsets_tensor.shape}")
        if per_sample_weights_tensor is not None:
            logger.debug(
                f"per_sample_weights: {per_sample_weights_tensor.shape}, {per_sample_weights_tensor}"
            )

        return (
            [
                indices_tensor.to(target_device),
                offsets_tensor.to(target_device),
                # BUGFIX: guard on the tensor, not the `weighted` flag. In the
                # file-loading branch, weighted=True with no weights file left
                # per_sample_weights_tensor as None, and None.to(...) raised
                # AttributeError.
                per_sample_weights_tensor.to(target_device)
                if per_sample_weights_tensor is not None
                else None,
            ],
            {},
        )