in tfx/components/example_gen/utils.py [0:0]
def generate_output_split_names(
    input_config: Union[example_gen_pb2.Input, Dict[str, Any]],
    output_config: Union[example_gen_pb2.Output, Dict[str, Any]]) -> List[str]:
  """Return output split names based on input and output config.

  Return the output split names if they are specified and the input contains
  exactly one split; otherwise the output splits are the same as the input
  splits.

  Args:
    input_config: example_gen_pb2.Input instance. If any field is provided as a
      RuntimeParameter, input_config should be constructed as a dict with the
      same field names as Input proto message.
    output_config: example_gen_pb2.Output instance. If any field is provided as
      a RuntimeParameter, output_config should be constructed as a dict with
      the same field names as Output proto message.

  Returns:
    List of split names.

  Raises:
    RuntimeError: if configs are not valid, including:
      - Missing field (absent or empty split name / pattern, absent or
        non-positive int hash buckets).
      - Duplicated split.
      - Output split is specified while the number of input splits is not
        exactly one.
      - No split could be determined from either config.
  """
  result = []

  # Convert proto to dict for easy sanity check. Otherwise we need to branch the
  # logic based on parameter types.
  if isinstance(output_config, example_gen_pb2.Output):
    output_config = json_format.MessageToDict(
        output_config,
        including_default_value_fields=True,
        preserving_proto_field_name=True)
  if isinstance(input_config, example_gen_pb2.Input):
    input_config = json_format.MessageToDict(
        input_config,
        including_default_value_fields=True,
        preserving_proto_field_name=True)

  if ('split_config' in output_config and
      'splits' in output_config['split_config']):
    if 'splits' not in input_config:
      raise RuntimeError(
          'ExampleGen instance specified output splits but no input split '
          'is specified.')
    if len(input_config['splits']) != 1:
      # If output is specified, then there should only be one input split.
      raise RuntimeError(
          'ExampleGen instance specified output splits but at the same time '
          'input has more than one split.')
    for split in output_config['split_config']['splits']:
      # Look keys up defensively: a hand-built dict config may omit a field,
      # which must surface as the documented RuntimeError, not a KeyError.
      # A non-int hash_buckets value (e.g. a placeholder object) is not
      # range-checked here; only an int that is <= 0 is rejected.
      hash_buckets = split.get('hash_buckets')
      if (not split.get('name') or 'hash_buckets' not in split or
          (isinstance(hash_buckets, int) and hash_buckets <= 0)):
        raise RuntimeError('Str-typed output split name and int-typed '
                           'hash buckets are required.')
      result.append(split['name'])
  elif 'splits' in input_config:
    # If output is not specified, it will have the same split as the input.
    for split in input_config['splits']:
      # .get() so that an omitted key is reported as RuntimeError as well.
      if not split.get('name') or not split.get('pattern'):
        raise RuntimeError('Str-typed input split name and pattern '
                           'are required.')
      result.append(split['name'])

  if not result:
    raise RuntimeError('ExampleGen splits are missing.')
  if len(result) != len(set(result)):
    raise RuntimeError('Duplicated split name {}.'.format(result))

  return result