def _train_model()

in clearbox/training.py [0:0]


def _train_model(task: _TrainModelTask) -> _TrainModelResult:
  """Worker function for training the model, basic parallelization unit.

  Splits `task.df` into training and validation rows using `task.valid_mask`,
  fits a fresh model obtained from `task.model_builder` on the training rows,
  and evaluates every metric in `task.metrics` on both splits.

  Args:
    task: Named tuple of the parameters describing the model to train. See the
      docstrings of `_TrainModelTask` class for more info about each field.

  Returns:
    `_TrainModelResult` named tuple encapsulating metrics and feature importance
    scores. `metric_row` holds one `train_<metric.name>` and one
    `valid_<metric.name>` entry per metric in `task.metrics`.

  Raises:
    AssertionError: If the model exposes no weights after fitting.
  """
  # Rows selected by `valid_mask` are held out; the complement is used to fit.
  # NOTE(review): presumably `valid_mask` is a boolean mask aligned with
  # `task.df` (it must support `~`) -- confirm against `_TrainModelTask`.
  valid_df = task.df[task.valid_mask]
  train_df = task.df[~task.valid_mask]

  train_x = train_df[task.features].values
  train_y = train_df[task.target_col].values
  train_query_arr = train_df[task.query_col].values

  # Validation targets are never consumed (metrics are computed against
  # `rank_col`), so only the validation features and queries are extracted.
  valid_x = valid_df[task.features].values
  valid_query_arr = valid_df[task.query_col].values

  model = task.model_builder.new()
  model.fit(train_x, train_y, train_query_arr)

  train_score = model.predict(train_x)
  valid_score = model.predict(valid_x)
  train_rank_y = train_df[task.rank_col].values
  valid_rank_y = valid_df[task.rank_col].values

  metric_row = {}
  for metric in task.metrics:
    metric_row[f'train_{metric.name}'] = metric.compute(
        train_query_arr, train_score, train_rank_y
    )
    metric_row[f'valid_{metric.name}'] = metric.compute(
        valid_query_arr, valid_score, valid_rank_y
    )

  # Explicit check instead of a bare `assert`, which is removed under
  # `python -O` and would let a result without weights reach the caller.
  weights = model.weights
  if weights is None:
    raise AssertionError('model.weights is None after fitting')
  return _TrainModelResult(
      weights=weights,
      metric_row=metric_row,
  )