in gbdt-rs/src/decision_tree.rs [1566:1680]
/// Finds the lowest-impurity split threshold on one feature for the samples in `train_data`.
///
/// The samples are sorted by their value for `feature_index`. The scan below assumes that
/// samples whose value is `VALUE_TYPE_UNKNOWN` form a leading run of the sorted order.
/// Those unknown-valued samples form their own group. For every pair of adjacent known
/// values that are not `almost_equal`, the midpoint is evaluated as a candidate threshold.
///
/// A candidate's impurity is the sum of the clamped `ss - s * s / c` terms of three groups:
/// - the unknown-valued samples,
/// - the known samples at or below the threshold,
/// - the known samples above it.
///
/// This reads as a squared-error impurity, assuming `s`, `ss` and `c` in `CacheValue` are
/// the per-sample target sum, sum of squares and weight. TODO confirm.
///
/// # Parameters
/// - `train_data`: indices of the samples being split.
/// - `feature_index`: the feature whose thresholds are evaluated.
/// - `value`: out-parameter. Receives the best threshold, or `VALUE_TYPE_UNKNOWN` if no
///   threshold was found.
/// - `impurity`: out-parameter. Receives the best impurity, or `f64::MAX` if no threshold
///   was found.
/// - `cache`: per-sample statistics and the sorting helper.
/// - `impurity_cache`: membership mask used for sorting, plus the totals of `s`/`ss`/`c`
///   over `train_data`. The totals are computed on the first call and reused afterwards.
/// - `sub_cache`: forwarded unchanged to `TrainingCache::sort_with_bool_vec`.
///
/// # Returns
/// The `(sample index, feature value)` pairs in the sorted order used for the scan.
fn get_impurity(
    train_data: &[usize],
    feature_index: usize,
    value: &mut ValueType,
    impurity: &mut f64,
    cache: &TrainingCache,
    impurity_cache: &mut ImpurityCache,
    sub_cache: &SubCache,
) -> Vec<(usize, ValueType)> {
    // Clamped `ss - s^2 / c` for one group. Groups with `c <= 1` contribute nothing, and
    // small negative results (rounding error) are clamped to zero.
    let group_term = |s: f64, ss: f64, c: f64| -> f64 {
        if c > 1.0 {
            let term = ss - s * s / c;
            if term < 0.0 {
                0.0
            } else {
                term
            }
        } else {
            0.0
        }
    };

    *impurity = f64::MAX;
    *value = VALUE_TYPE_UNKNOWN;

    // Sort the samples by their value for this feature.
    let sorted_data = cache.sort_with_bool_vec(
        feature_index,
        false,
        &impurity_cache.bool_vec,
        train_data.len(),
        sub_cache,
    );

    // Accumulate the statistics of the leading run of unknown-valued samples.
    let mut unknown: usize = 0;
    let (mut unknown_s, mut unknown_ss, mut unknown_c) = (0.0_f64, 0.0_f64, 0.0_f64);
    for &(index, _) in sorted_data
        .iter()
        .take_while(|&&(_, feature_value)| feature_value == VALUE_TYPE_UNKNOWN)
    {
        let cv: &CacheValue = &cache.cache_value[index];
        unknown_s += cv.s;
        unknown_ss += cv.ss;
        unknown_c += cv.c;
        unknown += 1;
    }

    // No known values (this also covers an empty node): no threshold can be placed.
    if unknown == sorted_data.len() {
        return sorted_data;
    }
    let fitness0 = group_term(unknown_s, unknown_ss, unknown_c);

    // The totals over `train_data` do not depend on `feature_index`, so compute them once.
    // The flag is now set; previously it was never set, so the totals were recomputed on
    // every call.
    // NOTE(review): this assumes `impurity_cache` is only ever used with this same
    // `train_data`. Confirm against the caller.
    if !impurity_cache.cached {
        impurity_cache.sum_s = 0.0;
        impurity_cache.sum_ss = 0.0;
        impurity_cache.sum_c = 0.0;
        for &index in train_data.iter() {
            let cv: &CacheValue = &cache.cache_value[index];
            impurity_cache.sum_s += cv.s;
            impurity_cache.sum_ss += cv.ss;
            impurity_cache.sum_c += cv.c;
        }
        impurity_cache.cached = true;
    }

    // Every known-valued sample starts on the right-hand side of the threshold.
    let (mut ls, mut lss, mut lc) = (0.0_f64, 0.0_f64, 0.0_f64);
    let mut rs = impurity_cache.sum_s - unknown_s;
    let mut rss = impurity_cache.sum_ss - unknown_ss;
    let mut rc = impurity_cache.sum_c - unknown_c;

    // Move samples one at a time from right to left and evaluate the gap after each one.
    for pair in sorted_data[unknown..].windows(2) {
        let (index, f1) = pair[0];
        let (_, f2) = pair[1];
        let cv: &CacheValue = &cache.cache_value[index];
        ls += cv.s;
        lss += cv.ss;
        lc += cv.c;
        rs -= cv.s;
        rss -= cv.ss;
        rc -= cv.c;
        // A threshold cannot separate (almost) equal values.
        if almost_equal(f1, f2) {
            continue;
        }
        let fitness = fitness0 + group_term(ls, lss, lc) + group_term(rs, rss, rc);
        // Strict comparison keeps the first (lowest) threshold among ties.
        if *impurity > fitness {
            *impurity = fitness;
            *value = (f1 + f2) / 2.0;
        }
    }
    sorted_data
}