in src/toolkits/nearest_neighbors/ball_tree_neighbors.cpp [71:414]
void ball_tree_neighbors::train(const sframe& X,
const std::vector<flexible_type>& ref_labels,
const std::vector<dist_component_type>& composite_distance_params,
const std::map<std::string, flexible_type>& opts) {
logprogress_stream << "Starting ball tree nearest neighbors model training." << std::endl;
timer t;
double start_time = t.current_time();
// Validate the inputs.
init_options(opts);
validate_distance_components(composite_distance_params, X);
populate_distance_for_summary_struct(composite_distance_params);
// Create the ml_data object for the reference data.
initialize_model_data(X, ref_labels);
// Initialize the distance components. NOTE: this needs data to be initialized
// first because the row slicers need the column indices to be sorted.
initialize_distances();
ASSERT_FALSE(composite_distances.empty());
dist_component c = composite_distances[0];
if (metadata->num_dimensions() > 100) {
logprogress_stream << "\nWARNING: The computational advantage of the "
<< "ball tree tends to diminish as the number of variables grows. With more "
<< "than 100 variables, the ball tree may not be optimal for this dataset."
<< std::endl;
}
// Figure out leaf size if the user didn't set it
size_t leaf_size = (size_t)options.value("leaf_size");
if (leaf_size == 0) {
leaf_size = std::max((size_t)1000, (size_t)ceil((double)num_examples / 2048)); // max tree depth of 12
options.set_option("leaf_size", leaf_size);
}
size_t min_leaves = ceil((double)num_examples / leaf_size);
tree_depth = ceil(log2(min_leaves)) + 1;
size_t num_leaves = std::max(size_t(1), size_t(std::pow(2, tree_depth - 1)));
size_t num_nodes = 2 * num_leaves - 1;
if (tree_depth > 12) {
logprogress_stream << "\nWARNING: The ball tree is very large. Consider "
<< "increasing the leaf size to create a smaller tree and improve "
<< "performance." << std::endl;
}
// Initialize tree and loop objects
if (is_dense) {
pivots.resize(num_nodes); // pivot observations
} else {
pivots_sp.resize(num_nodes);
}
node_radii.resize(num_nodes); // distance from pivot to furthest node member
std::vector<double> first_child_radius (num_nodes, 0); // distance from the pivot to the first child observation
std::vector<double> median_dist (num_nodes); // median of distances to the first child point
double middle_dist; // the temporary container for the median distance in a node
membership.resize(num_examples); // point membership in nodes
std::vector<double> pivot_dist(num_examples); // distance from each point to its pivot (at the lowest tree level)
std::vector<double> first_child_dist(num_examples); // distance from each point to the first child (at the lowest tree level)
// Set the radius and membership to 0 to start
for (size_t i = 0; i < num_nodes; ++i) {
node_radii[i] = 0;
}
for (size_t i = 0; i < num_examples; ++i) {
membership[i] = 0;
}
size_t num_variables = metadata->num_dimensions();
// Declare loop variables
DenseVector p(num_variables); // dense pivot observation
DenseVector x(num_variables); // dense generic observation
SparseVector x_sp(num_variables); // sparse pivot observation
SparseVector p_sp(num_variables); // sparse query observation
size_t a; // generic row index for a point
size_t idx_node; // index of the current node
size_t idx_node_start; // index of the first node on a level
size_t idx_node_end; // index of the last node on a level
size_t num_level_nodes; // number of nodes on a level
// Switch for maintaining balance in the nodes. If a point is exactly on the
// median this toggle indicates which child node to assign it to.
bool first_child_median_flag = true;
// Choose the first pivot
// NOTE: for now this will be the first row of the reference data, but this
// should probably be chosen randomly.
if (is_dense) {
mld_ref.get_iterator().fill_observation(x);
pivots[0] = x;
} else {
mld_ref.get_iterator().fill_observation(x_sp);
pivots_sp[0] = x_sp;
}
table_printer table( {{"Tree level", 0}, {"Elapsed Time", 0} });
table.print_header();
// The main loop over levels of the tree
// NOTE: the second-to-last tree level creates the leaves, so the loop should
// end at tree_depth - 1.
for (size_t tree_level = 0; tree_level < (tree_depth - 1); ++tree_level) {
if (cppipc::must_cancel()) {
log_and_throw("Toolkit cancelled by user.");
}
// Get the node indices for nodes on the current level
idx_node_start = std::pow(2, tree_level) - 1;
idx_node_end = std::pow(2, (tree_level + 1)) - 2;
num_level_nodes = idx_node_end - idx_node_start + 1;
// First pass over the data
for (auto it = mld_ref.get_iterator(); !it.done(); ++it) {
// Get the required data
a = it.row_index();
idx_node = membership[a];
if (is_dense) {
p = pivots[idx_node];
it.fill_observation(x);
pivot_dist[a] = c.distance->distance(x, p);
// find the largest distance to the pivot and index of the point
if (pivot_dist[a] >= node_radii[idx_node]) {
node_radii[idx_node] = pivot_dist[a];
pivots[2 * idx_node + 1] = x;
}
} else { // data is not dense
p_sp = pivots_sp[idx_node];
it.fill_observation(x_sp);
pivot_dist[a] = c.distance->distance(x_sp, p_sp);
// find the largest distance to the pivot and index of the point
if (pivot_dist[a] >= node_radii[idx_node]) {
node_radii[idx_node] = pivot_dist[a];
pivots_sp[2 * idx_node + 1] = x_sp;
}
}
}
// Create vector of vectors to store the first child distances contiguously
// for each node.
std::vector<std::vector<double>> node_dists(num_level_nodes);
// Second pass over the data
for (auto it = mld_ref.get_iterator(); !it.done(); ++it) {
// Get the required data
a = it.row_index();
idx_node = membership[a];
if (is_dense) {
p = pivots[2 * idx_node + 1];
it.fill_observation(x);
// Find all of the distances to the first child and pick the second child
// as the point furthest away.
first_child_dist[a] = c.distance->distance(x, p);
if (first_child_dist[a] >= first_child_radius[idx_node]) {
first_child_radius[idx_node] = first_child_dist[a];
pivots[2 * idx_node + 2] = x;
}
} else { // data is not dense
p_sp = pivots_sp[2 * idx_node + 1];
it.fill_observation(x_sp);
// Find all of the distances to the first child and pick the second child
// as the point furthest away.
first_child_dist[a] = c.distance->distance(x_sp, p_sp);
if (first_child_dist[a] >= first_child_radius[idx_node]) {
first_child_radius[idx_node] = first_child_dist[a];
pivots_sp[2 * idx_node + 2] = x_sp;
}
}
// Keep the first child distances compiled by node for median computation
node_dists[idx_node - idx_node_start].push_back(first_child_dist[a]);
}
// Find the median first child distance for each node
for (size_t j = 0; j < num_level_nodes; ++j) {
if (node_dists[j].size() > 1) {
std::nth_element(node_dists[j].begin(),
node_dists[j].begin() + node_dists[j].size()/2,
node_dists[j].end());
middle_dist = node_dists[j][node_dists[j].size()/2];
// if there are an even number of elements get the median of the middle two
if (node_dists[j].size() % 2 == 0) {
std::nth_element(node_dists[j].begin(),
node_dists[j].begin() + node_dists[j].size()/2 - 1,
node_dists[j].end());
middle_dist = (middle_dist + node_dists[j][node_dists[j].size()/2 - 1]) / 2;
}
median_dist[j + idx_node_start] = middle_dist;
} else {
// set median distance to -1 so that singletons always go to second child
median_dist[j + idx_node_start] = -1;
}
}
// Third pass over the data
// - assign each point to a child
// - careful about maintaining balance here
for (size_t b = 0; b < num_examples; ++b) {
idx_node = membership[b];
if (first_child_dist[b] < median_dist[idx_node]) {
membership[b] = 2 * idx_node + 1;
} else if (first_child_dist[b] > median_dist[idx_node]) {
membership[b] = 2 * idx_node + 2;
} else { // the point is exactly on the median
if (first_child_median_flag) {
membership[b] = 2 * idx_node + 1;
first_child_median_flag = false;
} else {
membership[b] = 2 * idx_node + 2;
first_child_median_flag = true;
}
}
}
table.print_row(tree_level, progress_time());
} // end loop over tree levels
// Find the radii for each of the leaf nodes
for (auto it = mld_ref.get_iterator(); !it.done(); ++it) {
// Get the required data
a = it.row_index();
idx_node = membership[a];
if (is_dense) {
// Find the largest distance to the pivot and index of that point
p = pivots[idx_node];
it.fill_observation(x);
pivot_dist[a] = c.distance->distance(x, p);
} else { // data is not dense
// Find the largest distance to the pivot and index of that point
p_sp = pivots_sp[idx_node];
it.fill_observation(x_sp);
pivot_dist[a] = c.distance->distance(x_sp, p_sp);
}
if (pivot_dist[a] >= node_radii[idx_node]) {
node_radii[idx_node] = pivot_dist[a];
}
}
table.print_row(tree_depth - 1, progress_time());
// Group the reference data by leaf node ID
// convert the reference labels to an SArray.
std::shared_ptr<sarray<flexible_type>> sa_ref_labels(new sarray<flexible_type>);
sa_ref_labels->open_for_write();
flex_type_enum ref_label_type = reference_labels[0].get_type();
sa_ref_labels->set_type(ref_label_type);
turi::copy(ref_labels.begin(), ref_labels.end(), *sa_ref_labels);
sa_ref_labels->close();
// convert membership into a shared pointer to an sarray
std::shared_ptr<sarray<flexible_type>> member_column(new sarray<flexible_type>);
member_column->open_for_write();
member_column->set_type(flex_type_enum::INTEGER);
turi::copy(membership.begin(), membership.end(), *member_column);
member_column->close();
// add the membership sarray as a column to the reference data and group
sframe sf_refs = X.add_column(sa_ref_labels, "__nearest_neighbors_ref_label");
sf_refs = sf_refs.add_column(member_column, "__nearest_neighbors_membership");
sf_refs = turi::group(sf_refs, "__nearest_neighbors_membership");
// extract the grouped membership vector and remove from the dataset.
auto member_reader = sf_refs.select_column("__nearest_neighbors_membership")->get_reader();
std::vector<flexible_type> temp(num_examples);
member_reader->read_rows(0, num_examples, temp);
std::copy(temp.begin(), temp.end(), membership.begin());
size_t idx_member_column = sf_refs.column_index("__nearest_neighbors_membership");
sf_refs = sf_refs.remove_column(idx_member_column);
// extract the map of grouped row indices from the dataset.
auto label_reader = sf_refs.select_column("__nearest_neighbors_ref_label")->get_reader();
std::vector<flexible_type> temp2(num_examples);
label_reader->read_rows(0, num_examples, temp2);
// this modifies the model's stored reference labels, *not* the vector passed to this function.
std::copy(temp2.begin(), temp2.end(), reference_labels.begin());
size_t idx_label_column = sf_refs.column_index("__nearest_neighbors_ref_label");
sf_refs = sf_refs.remove_column(idx_label_column);
// Re-make the ML data with the row-permuted data for storage in the model
mld_ref = v2::ml_data(metadata);
mld_ref.fill(sf_refs);
add_or_update_state({ {"method", "ball_tree"},
{"tree_depth", tree_depth},
{"leaf_size", leaf_size},
{"training_time", t.current_time() - start_time} });
table.print_footer();
} // end the create function