def __init__()

in tfx/components/trainer/component.py [0:0]


  def __init__(
      self,
      examples: Optional[types.BaseChannel] = None,
      transformed_examples: Optional[types.BaseChannel] = None,
      transform_graph: Optional[types.BaseChannel] = None,
      schema: Optional[types.BaseChannel] = None,
      base_model: Optional[types.BaseChannel] = None,
      hyperparameters: Optional[types.BaseChannel] = None,
      module_file: Optional[Union[str, data_types.RuntimeParameter]] = None,
      run_fn: Optional[Union[str, data_types.RuntimeParameter]] = None,
      train_args: Optional[Union[trainer_pb2.TrainArgs,
                                 data_types.RuntimeParameter]] = None,
      eval_args: Optional[Union[trainer_pb2.EvalArgs,
                                data_types.RuntimeParameter]] = None,
      custom_config: Optional[Union[Dict[str, Any],
                                    data_types.RuntimeParameter]] = None,
      custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
    """Construct a Trainer component.

    Args:
      examples: A [BaseChannel][tfx.v1.types.BaseChannel] of type [`standard_artifacts.Examples`][tfx.v1.types.standard_artifacts.Examples],
        serving as the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated (no compatibility guarantee). Please set
        'examples' instead.
      transform_graph: An optional [BaseChannel][tfx.v1.types.BaseChannel] of type
        [`standard_artifacts.TransformGraph`][tfx.v1.types.standard_artifacts.TransformGraph],
        serving as the input transform graph if present.
      schema:  An optional [BaseChannel][tfx.v1.types.BaseChannel] of type
        [`standard_artifacts.Schema`][tfx.v1.types.standard_artifacts.Schema],
        serving as the schema of training and eval data. Schema is optional when

          1. transform_graph is provided which contains schema.
          2. user module bypasses the usage of schema, e.g., hardcoded.
      base_model: A [BaseChannel][tfx.v1.types.BaseChannel] of type `Model`, containing model that will be
        used for training. This can be used for warmstart, transfer learning or
        model ensembling.
      hyperparameters: A [BaseChannel][tfx.v1.types.BaseChannel] of type
        [`standard_artifacts.HyperParameters`][tfx.v1.types.standard_artifacts.HyperParameters],
        serving as the hyperparameters for training module. Tuner's output best
        hyperparameters can be fed into this.
      module_file: A path to python module file containing UDF model definition.
        The `module_file` must implement a function named `run_fn` at its top
        level with function signature:
        ```python
        def run_fn(trainer.fn_args_utils.FnArgs)
        ```
        and the trained model must be saved to `FnArgs.serving_model_dir` when
        this function is executed.

        Exactly one of `module_file` or `run_fn` must be supplied if Trainer
        uses GenericExecutor (default). Use of a [RuntimeParameter][tfx.v1.dsl.experimental.RuntimeParameter] for this
        argument is experimental.
      run_fn:  A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file' or
        'run_fn' must be supplied if Trainer uses GenericExecutor (default). Use
        of a [RuntimeParameter][tfx.v1.dsl.experimental.RuntimeParameter] for this argument is experimental.
      train_args: A proto.TrainArgs instance, containing args used for training.
        Currently only splits and num_steps are available. Default behavior
        (when splits is empty) is train on `train` split.
      eval_args: A proto.EvalArgs instance, containing args used for evaluation.
        Currently only splits and num_steps are available. Default behavior
        (when splits is empty) is evaluate on `eval` split.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module. Unless it is a RuntimeParameter,
        it is serialized to a JSON string before being stored on the spec.
      custom_executor_spec: Optional custom executor spec. Deprecated (no
        compatibility guarantee), please customize component directly.

    Raises:
      ValueError:
        - When both or neither of `module_file` and `run_fn` is supplied.
        - When both or neither of `examples` and `transformed_examples`
            is supplied.
        - When `transformed_examples` is supplied but `transform_graph`
            is not supplied.
    """
    # Both-or-neither check: equal truthiness means either both were given or
    # neither was, which are the two invalid combinations.
    if bool(module_file) == bool(run_fn):
      raise ValueError(
          "Exactly one of 'module_file' or 'run_fn' must be supplied.")

    if bool(examples) == bool(transformed_examples):
      raise ValueError(
          "Exactly one of 'examples' or 'transformed_examples' must be "
          "supplied.")

    if transformed_examples and not transform_graph:
      raise ValueError("If 'transformed_examples' is supplied, "
                       "'transform_graph' must be supplied too.")

    if custom_executor_spec:
      logging.warning(
          "`custom_executor_spec` is deprecated. Please customize component directly."
      )
    if transformed_examples:
      logging.warning(
          "`transformed_examples` is deprecated. Please use `examples` instead."
      )
    # The checks above guarantee exactly one of the two is set; fold the
    # deprecated alias into the single `examples` input of the spec.
    examples = examples or transformed_examples
    model = types.Channel(type=standard_artifacts.Model)
    model_run = types.Channel(type=standard_artifacts.ModelRun)
    # RuntimeParameters are resolved later, so they are passed through as-is;
    # any other value (including None) is stored as a JSON string.
    if isinstance(custom_config, data_types.RuntimeParameter):
      serialized_custom_config = custom_config
    else:
      serialized_custom_config = json_utils.dumps(custom_config)
    spec = standard_component_specs.TrainerSpec(
        examples=examples,
        transform_graph=transform_graph,
        schema=schema,
        base_model=base_model,
        hyperparameters=hyperparameters,
        train_args=train_args or trainer_pb2.TrainArgs(),
        eval_args=eval_args or trainer_pb2.EvalArgs(),
        module_file=module_file,
        run_fn=run_fn,
        custom_config=serialized_custom_config,
        model=model,
        model_run=model_run)
    super().__init__(spec=spec, custom_executor_spec=custom_executor_spec)

    if udf_utils.should_package_user_modules():
      # In this case, the `MODULE_PATH_KEY` execution property will be injected
      # as a reference to the given user module file after packaging, at which
      # point the `MODULE_FILE_KEY` execution property will be removed.
      udf_utils.add_user_module_dependency(
          self, standard_component_specs.MODULE_FILE_KEY,
          standard_component_specs.MODULE_PATH_KEY)