def collect_data_step()

in tensorflow_decision_forests/keras/core.py [0:0]


  def collect_data_step(self, data, is_training_example):
    """Collects a batch of examples, e.g. training or validation.

    The batch is unpacked and validated, the features are converted into a
    dictionary, their semantics are inferred (and, for validation, compared to
    the training semantics), and the inputs are normalized. The label and the
    optional weights / ranking group / uplift treatment are then attached to the
    normalized inputs. If the model is not trained yet, the result is handed to
    `tf_core.collect_training_examples` (no distribution strategy) or to
    `tf_core.collect_distributed_training_examples` (distribution strategy).

    Args:
      data: A `(features, label)` or `(features, label, weights)` sequence.
        After the optional pre-processing, the features must be a dictionary
        of tensors, a single tensor, or a list / tuple of tensors. The label
        and the weights must be `tf.Tensor`s.
      is_training_example: True if the batch contains training examples, False
        if it contains validation examples. Validation examples must be
        collected after the training examples.

    Returns:
      An empty dictionary: no metrics are computed while collecting examples.

    Raises:
      ValueError: If `data` is a dictionary (i.e. no label was provided) or
        does not have 2 or 3 items; if the features, label or weights have an
        unexpected type or shape; if training examples are collected twice;
        if validation examples are collected before the training examples or
        do not have the same semantics; or if a distribution strategy is used
        with a model that does not support it.
      Exception: If the task is not supported, or if the ranking group or the
        uplift treatment is not available in the input features.
    """

    if isinstance(data, dict):
      raise ValueError("No label received for training. If you used "
                       "`pd_dataframe_to_tf_dataset`, make sure to "
                       f"specify the `label` argument. data={data}")

    if len(data) == 2:
      train_x, train_y = data
      train_weights = None
    elif len(data) == 3:
      train_x, train_y, train_weights = data
    else:
      raise ValueError(f"Unexpected data shape {data}")

    if self._verbose >= 2:
      tf_logging.info(
          "%s tensor examples:\nFeatures: %s\nLabel: %s\nWeights: %s",
          "Training" if is_training_example else "Validation", train_x, train_y,
          train_weights)

    # The feature names are checked on the raw inputs, i.e. before the
    # pre-processing is applied.
    if isinstance(train_x, dict):
      _check_feature_names(
          train_x.keys(),
          self._advanced_arguments.fail_on_non_keras_compatible_feature_name)

    if self._preprocessing is not None:
      train_x = self._preprocessing(train_x)
      if self._verbose >= 2:
        tf_logging.info("Tensor example after pre-processing:\n%s", train_x)
      if isinstance(train_x, list) and self._features:
        tf_logging.warning(
            "Using \"features\" with a pre-processing stage returning a list "
            "is not recommended. Use a pre-processing stage that returns a "
            "dictionary instead.")

    # Convert the features into a dictionary of tensors.
    if isinstance(train_x, dict):
      # Native format
      pass
    elif isinstance(train_x, tf.Tensor):
      train_x = {train_x.name: train_x}
    elif isinstance(train_x, (list, tuple)):
      # Note: The name of a tensor (value.name) can change between the training
      # and the inference.
      train_x = {str(idx): value for idx, value in enumerate(train_x)}
    else:
      raise ValueError(
          "The training input tensor is expected to be a tensor, list of "
          f"tensors or a dictionary of tensors. Got {train_x} instead")

    # Check the labels
    if not isinstance(train_y, tf.Tensor):
      raise ValueError(
          f"The training label tensor is expected to be a tensor. Got {train_y}"
          " instead.")

    if len(train_y.shape) != 1:
      if self._verbose >= 2:
        tf_logging.info(
            "Squeeze label' shape from [batch_size, 1] to [batch_size]")
      train_y = tf.squeeze(train_y, axis=1)

    if len(train_y.shape) != 1:
      raise ValueError(
          "Labels can either be passed in as [batch_size, 1] or [batch_size]. "
          "Invalid shape %s." % train_y.shape)

    # Check the weights
    # NOTE(review): this flag is recomputed on every call, including for
    # validation batches. Confirm that validation datasets are expected to
    # carry weights whenever the training dataset does.
    self._weighted_training = train_weights is not None
    if self._weighted_training:
      if not isinstance(train_weights, tf.Tensor):
        raise ValueError(
            "The training weights tensor is expected to be a tensor. "
            f"Got {train_weights} instead.")

      if len(train_weights.shape) != 1:
        if self._verbose >= 2:
          tf_logging.info(
              "Squeeze weight' shape from [batch_size, 1] to [batch_size]")
        train_weights = tf.squeeze(train_weights, axis=1)

      if len(train_weights.shape) != 1:
        raise ValueError(
            "Weights can either be passed in as [batch_size, 1] or [batch_size]. "
            "Invalid shape %s." % train_weights.shape)

    # List the input features and their semantics.
    semantics = tf_core.infer_semantic(
        train_x, {feature.name: feature.semantic for feature in self._features},
        self._exclude_non_specified)

    # The ranking group and treatment are not part of the features unless
    # specified explicitly.
    #
    # `self._features` contains feature objects (exposing `.name`) while the
    # ranking group and the uplift treatment are column names. The membership
    # is therefore tested on the feature names, not on the objects themselves.
    specified_feature_names = {feature.name for feature in self._features}
    for special_input in (self._ranking_group, self._uplift_treatment):
      if (special_input is not None and
          special_input not in specified_feature_names and
          special_input in semantics):
        del semantics[special_input]

    if is_training_example:
      if self._semantics is not None:
        raise ValueError("The model is already trained")
      self._semantics = semantics
    else:
      self._has_validation_dataset = True
      if self._semantics is None:
        raise ValueError(
            "The validation should be collected after the training")

      if semantics != self._semantics:
        raise ValueError(
            "The validation dataset does not have the same "
            "semantic as the training dataset.\nTraining:\n{}\nValidation:\n{}"
            .format(self._semantics, semantics))

    semantic_inputs = tf_core.combine_tensors_and_semantics(train_x, semantics)

    normalized_semantic_inputs = tf_core.normalize_inputs(semantic_inputs)

    if self._verbose >= 2:
      tf_logging.info("Normalized tensor features:\n %s",
                      normalized_semantic_inputs)

    # The input keys are recorded before the weights, label, ranking group and
    # uplift treatment are added below, so they only cover the input features.
    if is_training_example:
      self._normalized_input_keys = sorted(normalized_semantic_inputs.keys())

    # Add the weights
    if self._weighted_training:
      normalized_semantic_inputs[_WEIGHTS] = tf_core.SemanticTensor(
          tensor=tf.cast(train_weights, tf_core.NormalizedNumericalType),
          semantic=tf_core.Semantic.NUMERICAL)

    # Add the semantic of the label.
    if self._task == Task.CLASSIFICATION:
      normalized_semantic_inputs[_LABEL] = tf_core.SemanticTensor(
          tensor=tf.cast(train_y, tf_core.NormalizedCategoricalIntType) +
          tf_core.CATEGORICAL_INTEGER_OFFSET,
          semantic=tf_core.Semantic.CATEGORICAL)

    elif self._task == Task.REGRESSION:
      normalized_semantic_inputs[_LABEL] = tf_core.SemanticTensor(
          tensor=tf.cast(train_y, tf_core.NormalizedNumericalType),
          semantic=tf_core.Semantic.NUMERICAL)

    elif self._task == Task.RANKING:
      normalized_semantic_inputs[_LABEL] = tf_core.SemanticTensor(
          tensor=tf.cast(train_y, tf_core.NormalizedNumericalType),
          semantic=tf_core.Semantic.NUMERICAL)

      # The ranking group is read from the (non-normalized) input features.
      assert self._ranking_group is not None
      if self._ranking_group not in train_x:
        raise Exception(
            "The ranking key feature \"{}\" is not available as an input "
            "feature.".format(self._ranking_group))
      normalized_semantic_inputs[_RANK_GROUP] = tf_core.SemanticTensor(
          tensor=tf.cast(train_x[self._ranking_group],
                         tf_core.NormalizedHashType),
          semantic=tf_core.Semantic.HASH)

    elif self._task == Task.CATEGORICAL_UPLIFT:
      normalized_semantic_inputs[_LABEL] = tf_core.SemanticTensor(
          tensor=tf.cast(train_y, tf_core.NormalizedCategoricalIntType) +
          tf_core.CATEGORICAL_INTEGER_OFFSET,
          semantic=tf_core.Semantic.CATEGORICAL)

      # The treatment is read from the (non-normalized) input features.
      assert self._uplift_treatment is not None
      if self._uplift_treatment not in train_x:
        raise Exception(
            "The uplift treatment key feature \"{}\" is not available as an input "
            "feature.".format(self._uplift_treatment))
      normalized_semantic_inputs[_UPLIFT_TREATMENT] = tf_core.SemanticTensor(
          tensor=tf.cast(train_x[self._uplift_treatment],
                         tf_core.NormalizedCategoricalIntType) +
          tf_core.CATEGORICAL_INTEGER_OFFSET,
          semantic=tf_core.Semantic.CATEGORICAL)

    else:
      raise Exception("Non supported task {}".format(self._task))

    if not self._is_trained:
      # Collects the training examples.

      distribution_config = tf_core.get_distribution_configuration(
          self.distribute_strategy)
      if distribution_config is None:
        # No distribution strategy. Collecting examples in memory.
        tf_core.collect_training_examples(
            normalized_semantic_inputs,
            self._training_model_id,
            collect_training_data=is_training_example)

      else:

        if not is_training_example:
          tf_logging.warning(
              "The validation dataset given to `fit` is not used to help "
              "training (e.g. early stopping) in the case of distributed "
              "training. If you want to use a validation dataset use "
              "non-distributed training or use `fit_from_file` instead.")

        # Each worker collects a part of the dataset.
        if not self.capabilities().support_partial_cache_dataset_format:
          raise ValueError(
              f"The model {type(self)} does not support training with a TF "
              "Distribution strategy (i.e. model.capabilities()."
              "support_partial_cache_dataset_format == False). If the dataset "
              "is small, simply remove "
              "the distribution strategy scope (i.e. `with strategy.scope():` "
              "around the model construction). If the dataset is large, use a "
              "distributed version of the model. For Example, use "
              "DistributedGradientBoostedTreesModel instead of "
              "GradientBoostedTreesModel.")

        tf_core.collect_distributed_training_examples(
            inputs=normalized_semantic_inputs,
            model_id=self._training_model_id,
            dataset_path=self._distributed_partial_dataset_cache_path())

    # No metrics are returned during the collection of training examples.
    return {}