def pd_dataframe_to_tf_dataset()

in tensorflow_decision_forests/keras/core.py


def pd_dataframe_to_tf_dataset(
    dataframe,
    label: Optional[str] = None,
    task: Optional[TaskType] = Task.CLASSIFICATION,
    max_num_classes: Optional[int] = 100,
    in_place: Optional[bool] = False,
    fix_feature_names: Optional[bool] = True,
    weight: Optional[str] = None,
    batch_size: Optional[int] = 1000) -> tf.data.Dataset:
  """Converts a Panda Dataframe into a TF Dataset compatible with Keras.

  Details:
    - Ensures columns have uniform types.
    - If "label" is provided, separate it as a second channel in the tf.Dataset
      (as expected by Keras).
    - If "weight" is provided, separate it as a third channel in the tf.Dataset
      (as expected by Keras).
    - If "task" is provided, ensure the correct dtype of the label. If the task
      a classification and the label a string, integerize the labels. In this
      case, the label values are extracted from the dataset and ordered
      lexicographically. Warning: This logic won't work as expected if the
      training and testing dataset contains different label values. In such
      case, it is preferable to convert the label to integers beforehand while
      making sure the same encoding is used for all the datasets. If "
    - Returns "tf.data.from_tensor_slices"

  Args:
    dataframe: Pandas dataframe containing a training or evaluation dataset.
    label: Name of the label column.
    task: Target task of the dataset.
    max_num_classes: Maximum number of classes for a classification task. A
      high number of unique values / classes might indicate that the problem
      is a regression or a ranking instead of a classification. Set to None to
      disable checking the number of classes.
    in_place: If false (default), the input `dataframe` will not be modified by
      `pd_dataframe_to_tf_dataset`. However, a copy of the dataset memory will
      be made. If true, the dataframe will be modified in place.
    fix_feature_names: Some feature names are not supported by the SavedModel
      signature. If `fix_feature_names=True` (default) the feature will be
      renamed and made compatible. If `fix_feature_names=False`, the feature
      name will not be changed, but exporting the model might fail (i.e.
      `model.save(...)`).
    weight: Optional name of a column in `dataframe` to use to weight the
      training.
    batch_size: Number of examples in each batch. The size of the batches has no
      impact on the TF-DF training algorithms. However, a small batch size can
      lead to a large overhead when loading the dataset. Defaults to 1000. If
      `batch_size` is set to `None`, no batching is applied. Note: TF-DF
      expects the dataset to be batched.

  Returns:
    A TensorFlow Dataset.
  """

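  # Unless `in_place=True`, work on a deep copy so the caller's dataframe is
  # never mutated.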
  if not in_place:
    dataframe = dataframe.copy(deep=True)

  if label is not None:

    if label not in dataframe.columns:
      raise ValueError(
          f"The label \"{label}\" is not a column of the dataframe.")

    if task == Task.CLASSIFICATION:

      classification_classes = list(dataframe[label].unique())
      if (max_num_classes is not None and
          len(classification_classes) > max_num_classes):
        raise ValueError(
            f"The number of unique classes ({len(classification_classes)}) "
            f"exceeds max_num_classes ({max_num_classes}). A high number of "
            "unique value / classes might indicate that the problem is a "
            "regression or a ranking instead of a classification. If this "
            "problem is effectively a classification problem, increase "
            "`max_num_classes`.")

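      # Integerize string labels: the unique values are sorted
      # lexicographically and each label is replaced by its index in that
      # ordering.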
      if dataframe[label].dtypes in [str, object]:
        classification_classes.sort()
        dataframe[label] = dataframe[label].map(classification_classes.index)

      elif dataframe[label].dtypes in [int, float]:
        if (dataframe[label] < 0).any():
          raise ValueError(
              "Negative integer classification label found. Make sure "
              "you label values are positive or stored as string.")

  if weight is not None:
    if weight not in dataframe.columns:
      raise ValueError(
          f"The weight \"{weight}\" is not a column of the dataframe.")

  if fix_feature_names:
    # Rename the features so they are compatible with SavedModel serving
    # signatures.
    rename_mapping = {}
    new_names = set()
    change_any_feature_name = False
    for column in dataframe:
      new_name = column
      for forbidden_character in _FORBIDDEN_FEATURE_CHARACTERS:
        if forbidden_character in new_name:
          change_any_feature_name = True
          new_name = new_name.replace(forbidden_character, "_")
      # Add a trailing "_" until there are no feature name collisions.
      while new_name in new_names:
        new_name += "_"
        change_any_feature_name = True

      rename_mapping[column] = new_name
      new_names.add(new_name)

    dataframe = dataframe.rename(columns=rename_mapping)
    if change_any_feature_name:
      tf_logging.warning(
          "Some of the feature names have been changed automatically to be "
          "compatible with SavedModels because fix_feature_names=True.")

  # Make sure that missing values for string columns are not represented as
  # float(NaN).
  for col in dataframe.columns:
    if dataframe[col].dtype in [str, object]:
      dataframe[col] = dataframe[col].fillna("")

  if label is not None:
    features_dataframe = dataframe.drop(label, axis=1)

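    # Keras consumes datasets yielding (features, label) or
    # (features, label, sample_weight) tuples.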
    if weight is not None:
      features_dataframe = features_dataframe.drop(weight, axis=1)
      output = (dict(features_dataframe), dataframe[label].values,
                dataframe[weight].values)
    else:
      output = (dict(features_dataframe), dataframe[label].values)

    tf_dataset = tf.data.Dataset.from_tensor_slices(output)

  else:
    if weight is not None:
      raise ValueError(
          "\"weight\" is only supported if the \"label\" is also provided")
    tf_dataset = tf.data.Dataset.from_tensor_slices(dict(dataframe))

  # The batch size does not impact the training of TF-DF.
  if batch_size is not None:
    tf_dataset = tf_dataset.batch(batch_size)

  # Seems to provide a small (measured as ~4% on a 32k rows dataset) speed-up.
  tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE)

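  # Record the task on the returned dataset so TF-DF can later check that it
  # matches the model's task.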
  setattr(tf_dataset, "_tfdf_task", task)
  return tf_dataset
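

A minimal usage sketch, assuming the function is reached through its usual
public alias `tfdf.keras.pd_dataframe_to_tf_dataset`; the toy dataframe and
its column names below are illustrative only:

import pandas as pd
import tensorflow_decision_forests as tfdf

# Toy training data; column and label names are illustrative.
df = pd.DataFrame({
    "feature_a": [1.0, 2.0, 3.0, 4.0],
    "feature_b": ["x", "y", "x", "y"],
    "species": ["cat", "dog", "cat", "dog"],
})

# String labels are integerized lexicographically: "cat" -> 0, "dog" -> 1.
# The resulting dataset yields batched (features, label) tuples.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(df, label="species")

model = tfdf.keras.RandomForestModel()
model.fit(train_ds)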