in tfx/components/trainer/component.py [0:0]
def __init__(
    self,
    examples: Optional[types.BaseChannel] = None,
    transformed_examples: Optional[types.BaseChannel] = None,
    transform_graph: Optional[types.BaseChannel] = None,
    schema: Optional[types.BaseChannel] = None,
    base_model: Optional[types.BaseChannel] = None,
    hyperparameters: Optional[types.BaseChannel] = None,
    module_file: Optional[Union[str, data_types.RuntimeParameter]] = None,
    run_fn: Optional[Union[str, data_types.RuntimeParameter]] = None,
    train_args: Optional[Union[trainer_pb2.TrainArgs,
                               data_types.RuntimeParameter]] = None,
    eval_args: Optional[Union[trainer_pb2.EvalArgs,
                              data_types.RuntimeParameter]] = None,
    custom_config: Optional[Union[Dict[str, Any],
                                  data_types.RuntimeParameter]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
    """Construct a Trainer component.

    Args:
      examples: A [BaseChannel][tfx.v1.types.BaseChannel] of type
        [`standard_artifacts.Examples`][tfx.v1.types.standard_artifacts.Examples],
        serving as the source of examples used in training (required). May be
        raw or transformed.
      transformed_examples: Deprecated (no compatibility guarantee). Please set
        'examples' instead.
      transform_graph: An optional [BaseChannel][tfx.v1.types.BaseChannel] of
        type
        [`standard_artifacts.TransformGraph`][tfx.v1.types.standard_artifacts.TransformGraph],
        serving as the input transform graph if present. Required when the
        deprecated `transformed_examples` is supplied.
      schema: An optional [BaseChannel][tfx.v1.types.BaseChannel] of type
        [`standard_artifacts.Schema`][tfx.v1.types.standard_artifacts.Schema],
        serving as the schema of training and eval data. Schema is optional
        when either:

        1. `transform_graph` is provided, since it contains a schema; or
        2. the user module bypasses the usage of schema (e.g., it is
           hardcoded).
      base_model: An optional [BaseChannel][tfx.v1.types.BaseChannel] of type
        [`standard_artifacts.Model`][tfx.v1.types.standard_artifacts.Model],
        containing a model that will be used for training. This can be used
        for warm-starting, transfer learning or model ensembling.
      hyperparameters: An optional [BaseChannel][tfx.v1.types.BaseChannel] of
        type
        [`standard_artifacts.HyperParameters`][tfx.v1.types.standard_artifacts.HyperParameters],
        serving as the hyperparameters for the training module. The Tuner's
        best-hyperparameters output can be fed into this.
      module_file: A path to a python module file containing the UDF model
        definition. The `module_file` must implement a function named `run_fn`
        at its top level with function signature:
        ```python
        def run_fn(trainer.fn_args_utils.FnArgs)
        ```
        and the trained model must be saved to `FnArgs.serving_model_dir` when
        this function is executed.
        Exactly one of `module_file` or `run_fn` must be supplied; this is
        checked at construction time regardless of the executor in use. Use of
        a [RuntimeParameter][tfx.v1.dsl.experimental.RuntimeParameter] for
        this argument is experimental.
      run_fn: A python path to the UDF model definition function for the
        generic trainer. See `module_file` for details. Exactly one of
        `module_file` or `run_fn` must be supplied. Use of a
        [RuntimeParameter][tfx.v1.dsl.experimental.RuntimeParameter] for this
        argument is experimental.
      train_args: A proto.TrainArgs instance, containing args used for
        training. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is to train on the `train` split. An
        empty `TrainArgs` is used when this is not supplied.
      eval_args: A proto.EvalArgs instance, containing args used for
        evaluation. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is to evaluate on the `eval` split. An
        empty `EvalArgs` is used when this is not supplied.
      custom_config: A dict which contains additional training job parameters
        that will be passed into the user module. A plain dict (or None) is
        serialized with `json_utils.dumps`; a RuntimeParameter is passed
        through unchanged.
      custom_executor_spec: Optional custom executor spec. Deprecated (no
        compatibility guarantee), please customize component directly.

    Raises:
      ValueError:
        - When both or neither of `module_file` and `run_fn` is supplied.
        - When both or neither of `examples` and `transformed_examples` is
          supplied.
        - When `transformed_examples` is supplied but `transform_graph` is not
          supplied.
    """
    # `module_file` and `run_fn` are two alternative ways of pointing at the
    # user training code, so they are mutually exclusive and one is required.
    if bool(module_file) == bool(run_fn):
        raise ValueError(
            "Exactly one of 'module_file', or 'run_fn' must be supplied.")
    # `transformed_examples` is a deprecated alias for `examples`; accept one
    # or the other, but not both and not neither.
    if bool(examples) == bool(transformed_examples):
        raise ValueError(
            "Exactly one of 'examples' or 'transformed_examples' must be "
            "supplied.")
    if transformed_examples and not transform_graph:
        raise ValueError("If 'transformed_examples' is supplied, "
                         "'transform_graph' must be supplied too.")
    if custom_executor_spec:
        logging.warning(
            "`custom_executor_spec` is deprecated. Please customize component directly."
        )
    if transformed_examples:
        logging.warning(
            "`transformed_examples` is deprecated. Please use `examples` instead."
        )
    # Fold the deprecated alias into the single `examples` input; the checks
    # above guarantee exactly one of the two is set.
    examples = examples or transformed_examples
    # Output channels for the trained model and the model run.
    model = types.Channel(type=standard_artifacts.Model)
    model_run = types.Channel(type=standard_artifacts.ModelRun)
    # RuntimeParameters are passed through as-is; any other custom_config
    # value (including None) is serialized to a string for the spec.
    if isinstance(custom_config, data_types.RuntimeParameter):
        serialized_custom_config = custom_config
    else:
        serialized_custom_config = json_utils.dumps(custom_config)
    spec = standard_component_specs.TrainerSpec(
        examples=examples,
        transform_graph=transform_graph,
        schema=schema,
        base_model=base_model,
        hyperparameters=hyperparameters,
        # Empty protos fall back to the default `train` / `eval` splits.
        train_args=train_args or trainer_pb2.TrainArgs(),
        eval_args=eval_args or trainer_pb2.EvalArgs(),
        module_file=module_file,
        run_fn=run_fn,
        custom_config=serialized_custom_config,
        model=model,
        model_run=model_run)
    super().__init__(spec=spec, custom_executor_spec=custom_executor_spec)
    if udf_utils.should_package_user_modules():
        # In this case, the `MODULE_PATH_KEY` execution property will be injected
        # as a reference to the given user module file after packaging, at which
        # point the `MODULE_FILE_KEY` execution property will be removed.
        udf_utils.add_user_module_dependency(
            self, standard_component_specs.MODULE_FILE_KEY,
            standard_component_specs.MODULE_PATH_KEY)