in src/sagemaker_tensorflow/pipemode.py [0:0]
def __init__(self, channel, record_format='RecordIO',
             state_dir='/opt/ml/pipe_state', pipe_dir='/opt/ml/input/data',
             config_dir='/opt/ml/input/config', benchmark=False, benchmark_records_interval=0,
             max_corrupted_records_to_skip=0):
    """Create a Dataset for reading from a SageMaker PipeMode channel.

    Supports records encoded using either RecordIO, TFRecord, or new line text encoding.

    Args:
        channel: The name of the SageMaker channel.
        record_format: The record format to use. One of 'RecordIO', 'TFRecord', or 'TextLine'.
        state_dir: The directory where pipe index state is persisted. It is created if it does
            not already exist.
        pipe_dir: The directory to read SageMaker Channels from.
        config_dir: The directory containing the SageMaker input data config file
            ('inputdataconfig.json').
        benchmark: Controls whether to emit timing and throughput metrics after closing an Iterator created from
            this Dataset. If True, timing and throughput metrics will be emitted to stdout after an Iterator
            created from this Dataset is destroyed.
        benchmark_records_interval: Controls whether to emit timing and throughput metrics while records are being
            read from this Dataset. Defines the number of records per interval to emit timing and throughput
            metrics. If zero, no metrics will be emitted while records are being read from this Dataset.
            Metrics are emitted to stdout.
        max_corrupted_records_to_skip: the number of corrupted records encountered in sequence that it's ok to
            skip. Only applicable for record_format='TFRecord'.

    Raises:
        PipeModeDatasetException: If max_corrupted_records_to_skip is positive while record_format
            is not 'TFRecord'. (_validate_input_data_config may raise further errors; see that method.)
        OSError: If state_dir cannot be created, or exists but is not a directory.
        IOError: If inputdataconfig.json cannot be opened (OSError subclass on Python 3).
    """
    # Check pure argument constraints first so a misconfigured Dataset fails fast,
    # before any directory is created or the input data config is read.
    if max_corrupted_records_to_skip > 0 and record_format != 'TFRecord':
        raise PipeModeDatasetException("max_corrupted_records_to_skip can only be set for record_format='TFRecord'")

    # Equivalent of `mkdir -p` that stays compatible with Python 2 (no exist_ok).
    try:
        os.makedirs(state_dir)
    except OSError as e:
        # An already-existing directory is fine; a pre-existing *file* at that path,
        # or any other failure, is not and must not be silently ignored.
        if e.errno != errno.EEXIST or not os.path.isdir(state_dir):
            raise

    self.record_format = record_format
    self.channel = channel
    self.pipe_dir = pipe_dir
    self.state_dir = state_dir
    self.benchmark = benchmark
    self.benchmark_records_interval = benchmark_records_interval
    self.max_corrupted_records_to_skip = max_corrupted_records_to_skip

    with open(os.path.join(config_dir, 'inputdataconfig.json')) as f:
        self.input_data_config = json.load(f)
    # Relies on self.channel and self.input_data_config having been set above.
    self._validate_input_data_config()

    super(PipeModeDataset, self).__init__()