void ball_tree_neighbors::train()

in src/toolkits/nearest_neighbors/ball_tree_neighbors.cpp [71:414]


void ball_tree_neighbors::train(const sframe& X,
                                const std::vector<flexible_type>& ref_labels,
                                const std::vector<dist_component_type>& composite_distance_params,
                                const std::map<std::string, flexible_type>& opts) {

  logprogress_stream << "Starting ball tree nearest neighbors model training." << std::endl;

  timer t;
  double start_time = t.current_time();

  // Validate the inputs.
  init_options(opts);
  validate_distance_components(composite_distance_params, X);

  populate_distance_for_summary_struct(composite_distance_params);

  // Create the ml_data object for the reference data.
  initialize_model_data(X, ref_labels);

  // Initialize the distance components. NOTE: this needs data to be initialized
  // first because the row slicers need the column indices to be sorted.
  initialize_distances();

  ASSERT_FALSE(composite_distances.empty());
  dist_component c = composite_distances[0];


  if (metadata->num_dimensions() > 100) {
    logprogress_stream << "\nWARNING: The computational advantage of the "
      << "ball tree tends to diminish as the number of variables grows. With more "
      << "than 100 variables, the ball tree may not be optimal for this dataset."
      << std::endl;
  }


  // Figure out leaf size if the user didn't set it
  size_t leaf_size = (size_t)options.value("leaf_size");

  if (leaf_size == 0) {
    leaf_size = std::max((size_t)1000, (size_t)ceil((double)num_examples / 2048)); // max tree depth of 12
    options.set_option("leaf_size", leaf_size);
  }

  size_t min_leaves = ceil((double)num_examples / leaf_size);
  tree_depth = ceil(log2(min_leaves)) + 1;
  size_t num_leaves = std::max(size_t(1), size_t(std::pow(2, tree_depth - 1)));
  size_t num_nodes = 2 * num_leaves - 1;

  if (tree_depth > 12) {
    logprogress_stream << "\nWARNING: The ball tree is very large. Consider "
      << "increasing the leaf size to create a smaller tree and improve "
      << "performance." << std::endl;
  }


  // Initialize tree and loop objects
  if (is_dense) {
    pivots.resize(num_nodes);     // pivot observations
  } else {
    pivots_sp.resize(num_nodes);
  }

  node_radii.resize(num_nodes);                      // distance from pivot to furthest node member
  std::vector<double> first_child_radius (num_nodes, 0);  // distance from the pivot to the first child observation
  std::vector<double> median_dist (num_nodes);   // median of distances to the first child point
  double middle_dist;                            // the temporary container for the median distance in a node
  membership.resize(num_examples);               // point membership in nodes
  std::vector<double> pivot_dist(num_examples);  // distance from each point to its pivot (at the lowest tree level)
  std::vector<double> first_child_dist(num_examples);  // distance from each point to the first child (at the lowest tree level)


  // Set the radius and membership to 0 to start
  for (size_t i = 0; i < num_nodes; ++i) {
    node_radii[i] = 0;
  }

  for (size_t i = 0; i < num_examples; ++i) {
    membership[i] = 0;
  }

  size_t num_variables = metadata->num_dimensions();

  // Declare loop variables
  DenseVector p(num_variables);      // dense pivot observation
  DenseVector x(num_variables);      // dense generic observation
  SparseVector x_sp(num_variables);  // sparse pivot observation
  SparseVector p_sp(num_variables);  // sparse query observation

  size_t a;                          // generic row index for a point
  size_t idx_node;                   // index of the current node
  size_t idx_node_start;             // index of the first node on a level
  size_t idx_node_end;               // index of the last node on a level
  size_t num_level_nodes;            // number of nodes on a level

  // Switch for maintaining balance in the nodes. If a point is exactly on the
  // median this toggle indicates which child node to assign it to.
  bool first_child_median_flag = true;


  // Choose the first pivot
  // NOTE: for now this will be the first row of the reference data, but this
  // should probably be chosen randomly.

  if (is_dense) {
    mld_ref.get_iterator().fill_observation(x);
    pivots[0] = x;
  } else {
    mld_ref.get_iterator().fill_observation(x_sp);
    pivots_sp[0] = x_sp;
  }


  table_printer table( {{"Tree level", 0}, {"Elapsed Time", 0} });
  table.print_header();

  // The main loop over levels of the tree
  // NOTE: the second-to-last tree level creates the leaves, so the loop should
  // end at tree_depth - 1.
  for (size_t tree_level = 0; tree_level < (tree_depth - 1); ++tree_level) {

    if (cppipc::must_cancel()) {
      log_and_throw("Toolkit cancelled by user.");
    }

    // Get the node indices for nodes on the current level
    idx_node_start = std::pow(2, tree_level) - 1;
    idx_node_end = std::pow(2, (tree_level + 1)) - 2;
    num_level_nodes = idx_node_end - idx_node_start + 1;


    // First pass over the data
    for (auto it = mld_ref.get_iterator(); !it.done(); ++it) {

      // Get the required data
      a = it.row_index();
      idx_node = membership[a];

      if (is_dense) {
        p = pivots[idx_node];
        it.fill_observation(x);
        pivot_dist[a] = c.distance->distance(x, p);

        // find the largest distance to the pivot and index of the point
        if (pivot_dist[a] >= node_radii[idx_node]) {
          node_radii[idx_node] = pivot_dist[a];
          pivots[2 * idx_node + 1] = x;
        }

      } else {  // data is not dense
        p_sp = pivots_sp[idx_node];
        it.fill_observation(x_sp);
        pivot_dist[a] = c.distance->distance(x_sp, p_sp);

        // find the largest distance to the pivot and index of the point
        if (pivot_dist[a] >= node_radii[idx_node]) {
          node_radii[idx_node] = pivot_dist[a];
          pivots_sp[2 * idx_node + 1] = x_sp;
        }
      }
    }


    // Create vector of vectors to store the first child distances contiguously
    // for each node.
    std::vector<std::vector<double>> node_dists(num_level_nodes);


    // Second pass over the data
    for (auto it = mld_ref.get_iterator(); !it.done(); ++it) {

      // Get the required data
      a = it.row_index();
      idx_node = membership[a];

      if (is_dense) {
        p = pivots[2 * idx_node + 1];
        it.fill_observation(x);

        // Find all of the distances to the first child and pick the second child
        // as the point furthest away.
        first_child_dist[a] = c.distance->distance(x, p);

        if (first_child_dist[a] >= first_child_radius[idx_node]) {
          first_child_radius[idx_node] = first_child_dist[a];
          pivots[2 * idx_node + 2] = x;
        }

      } else { // data is not dense
        p_sp = pivots_sp[2 * idx_node + 1];
        it.fill_observation(x_sp);

        // Find all of the distances to the first child and pick the second child
        // as the point furthest away.
        first_child_dist[a] = c.distance->distance(x_sp, p_sp);

        if (first_child_dist[a] >= first_child_radius[idx_node]) {
          first_child_radius[idx_node] = first_child_dist[a];
          pivots_sp[2 * idx_node + 2] = x_sp;
        }
      }

      // Keep the first child distances compiled by node for median computation
      node_dists[idx_node - idx_node_start].push_back(first_child_dist[a]);
    }


    // Find the median first child distance for each node
    for (size_t j = 0; j < num_level_nodes; ++j) {
      if (node_dists[j].size() > 1) {

        std::nth_element(node_dists[j].begin(),
                         node_dists[j].begin() + node_dists[j].size()/2,
                         node_dists[j].end());
        middle_dist = node_dists[j][node_dists[j].size()/2];

        // if there are an even number of elements get the median of the middle two
        if (node_dists[j].size() % 2 == 0) {
          std::nth_element(node_dists[j].begin(),
                           node_dists[j].begin() + node_dists[j].size()/2 - 1,
                           node_dists[j].end());
          middle_dist = (middle_dist + node_dists[j][node_dists[j].size()/2 - 1]) / 2;
        }

        median_dist[j + idx_node_start] = middle_dist;

      } else {
        // set median distance to -1 so that singletons always go to second child
        median_dist[j + idx_node_start] = -1;
      }
    }


    // Third pass over the data
    // - assign each point to a child
    // - careful about maintaining balance here
    for (size_t b = 0; b < num_examples; ++b) {
      idx_node = membership[b];
      if (first_child_dist[b] < median_dist[idx_node]) {
        membership[b] = 2 * idx_node + 1;

      } else if (first_child_dist[b] > median_dist[idx_node]) {
        membership[b] = 2 * idx_node + 2;

      } else {  // the point is exactly on the median
          if (first_child_median_flag) {
            membership[b] = 2 * idx_node + 1;
            first_child_median_flag = false;

          } else {
            membership[b] = 2 * idx_node + 2;
            first_child_median_flag = true;
          }
      }
    }


    table.print_row(tree_level, progress_time());

  } // end loop over tree levels


  // Find the radii for each of the leaf nodes
  for (auto it = mld_ref.get_iterator(); !it.done(); ++it) {

    // Get the required data
    a = it.row_index();
    idx_node = membership[a];

    if (is_dense) {
      // Find the largest distance to the pivot and index of that point
      p = pivots[idx_node];
      it.fill_observation(x);
      pivot_dist[a] = c.distance->distance(x, p);

    } else { // data is not dense
      // Find the largest distance to the pivot and index of that point
      p_sp = pivots_sp[idx_node];
      it.fill_observation(x_sp);
      pivot_dist[a] = c.distance->distance(x_sp, p_sp);
    }

    if (pivot_dist[a] >= node_radii[idx_node]) {
      node_radii[idx_node] = pivot_dist[a];
    }
  }

  table.print_row(tree_depth - 1, progress_time());



  // Group the reference data by leaf node ID

  // convert the reference labels to an SArray.
  std::shared_ptr<sarray<flexible_type>> sa_ref_labels(new sarray<flexible_type>);
  sa_ref_labels->open_for_write();
  flex_type_enum ref_label_type = reference_labels[0].get_type();
  sa_ref_labels->set_type(ref_label_type);
  turi::copy(ref_labels.begin(), ref_labels.end(), *sa_ref_labels);
  sa_ref_labels->close();

  // convert membership into a shared pointer to an sarray
  std::shared_ptr<sarray<flexible_type>> member_column(new sarray<flexible_type>);
  member_column->open_for_write();
  member_column->set_type(flex_type_enum::INTEGER);
  turi::copy(membership.begin(), membership.end(), *member_column);
  member_column->close();

  // add the membership sarray as a column to the reference data and group
  sframe sf_refs = X.add_column(sa_ref_labels, "__nearest_neighbors_ref_label");
  sf_refs = sf_refs.add_column(member_column, "__nearest_neighbors_membership");
  sf_refs = turi::group(sf_refs, "__nearest_neighbors_membership");

  // extract the grouped membership vector and remove from the dataset.
  auto member_reader = sf_refs.select_column("__nearest_neighbors_membership")->get_reader();
  std::vector<flexible_type> temp(num_examples);
  member_reader->read_rows(0, num_examples, temp);
  std::copy(temp.begin(), temp.end(), membership.begin());

  size_t idx_member_column = sf_refs.column_index("__nearest_neighbors_membership");
  sf_refs = sf_refs.remove_column(idx_member_column);

  // extract the map of grouped row indices from the dataset.
  auto label_reader = sf_refs.select_column("__nearest_neighbors_ref_label")->get_reader();
  std::vector<flexible_type> temp2(num_examples);
  label_reader->read_rows(0, num_examples, temp2);

  // this modifies the model's stored reference labels, *not* the vector passed to this function.
  std::copy(temp2.begin(), temp2.end(), reference_labels.begin());

  size_t idx_label_column = sf_refs.column_index("__nearest_neighbors_ref_label");
  sf_refs = sf_refs.remove_column(idx_label_column);


  // Re-make the ML data with the row-permuted data for storage in the model
  mld_ref = v2::ml_data(metadata);
  mld_ref.fill(sf_refs);


  add_or_update_state({ {"method", "ball_tree"},
                        {"tree_depth", tree_depth},
                        {"leaf_size", leaf_size},
                        {"training_time", t.current_time() - start_time} });
  table.print_footer();
}  // end the create function