in tfx_addons/sampling/example/sampler_utils.py [0:0]
def _input_fn(file_pattern: List[Text],
data_accessor: DataAccessor,
tf_transform_output: tft.TFTransformOutput,
batch_size: int = 200) -> tf.data.Dataset:
"""Generates features and label for tuning/training.
Args:
file_pattern: List of paths or patterns of input tfrecord files.
data_accessor: DataAccessor for converting input to RecordBatch.
tf_transform_output: A TFTransformOutput.
batch_size: representing the number of consecutive elements of returned
dataset to combine in a single batch
Returns:
A dataset that contains (features, indices) tuple where features is a
dictionary of Tensors, and indices is a single Tensor of label indices.
"""
return data_accessor.tf_dataset_factory(
file_pattern,
dataset_options.TensorFlowDatasetOptions(
batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)),
tf_transform_output.transformed_metadata.schema)