fn get_impurity()

in gbdt-rs/src/decision_tree.rs [1566:1680]


    /// Searches one feature for the split threshold with the lowest combined
    /// impurity over the samples of the current node.
    ///
    /// The samples are obtained from `cache.sort_with_bool_vec` as
    /// `(sample_index, feature_value)` pairs. Leading pairs whose feature value
    /// equals `VALUE_TYPE_UNKNOWN` form the "unknown" group; the remaining pairs
    /// are scanned left to right, and every boundary between two neighbouring
    /// values that are not `almost_equal` is evaluated as a candidate split.
    /// (This presumes the sort places unknown values first and orders the rest
    /// by feature value -- confirm against `sort_with_bool_vec`.)
    ///
    /// # Arguments
    ///
    /// * `train_data` - indices into `cache.cache_value` of the node's samples;
    ///   used to compute the node totals and the expected sorted length.
    /// * `feature_index` - feature to evaluate, forwarded to the sort.
    /// * `value` - out-parameter: reset to `VALUE_TYPE_UNKNOWN`, then set to the
    ///   midpoint of the two neighbouring feature values at the best split.
    /// * `impurity` - out-parameter: reset to `f64::MAX`, then set to the lowest
    ///   impurity found. Stays at `f64::MAX` when no candidate split exists.
    /// * `cache` - per-sample statistics (`s`, `ss`, `c`) and the sorter.
    /// * `impurity_cache` - holds the node totals `sum_s` / `sum_ss` / `sum_c`
    ///   and the `bool_vec` passed to the sort.
    /// * `sub_cache` - forwarded unchanged to the sort.
    ///
    /// # Returns
    ///
    /// The sorted `(sample_index, feature_value)` pairs, also when every sample
    /// has an unknown feature value (in which case no split is evaluated).
    fn get_impurity(
        train_data: &[usize],
        feature_index: usize,
        value: &mut ValueType,
        impurity: &mut f64,
        cache: &TrainingCache,
        impurity_cache: &mut ImpurityCache,
        sub_cache: &SubCache,
    ) -> Vec<(usize, ValueType)> {
        // Impurity of one group: `sum_sq - sum^2 / count`, clamped at zero.
        // Groups whose count is not above 1.0 contribute nothing.
        // (Presumably `sum`, `sum_sq` and `count` are the weighted target sum,
        // weighted squared-target sum and total weight -- TODO confirm against
        // `CacheValue`.)
        // The explicit `<` comparison is deliberate: unlike `f64::max`, it lets
        // a NaN result pass through unchanged.
        fn group_impurity(sum: f64, sum_sq: f64, count: f64) -> f64 {
            if count > 1.0 {
                let spread = sum_sq - sum * sum / count;
                if spread < 0.0 {
                    0.0
                } else {
                    spread
                }
            } else {
                0.0
            }
        }

        *impurity = std::f64::MAX;
        *value = VALUE_TYPE_UNKNOWN;

        // Sort the samples with the feature value
        let sorted_data = cache.sort_with_bool_vec(
            feature_index,
            false,
            &impurity_cache.bool_vec,
            train_data.len(),
            sub_cache,
        );

        // Accumulate the statistics of the leading run of unknown-valued samples.
        let mut unknown: usize = 0;
        let mut unknown_s: f64 = 0.0;
        let mut unknown_ss: f64 = 0.0;
        let mut unknown_c: f64 = 0.0;
        for &(index, _) in sorted_data
            .iter()
            .take_while(|&&(_, feature_value)| feature_value == VALUE_TYPE_UNKNOWN)
        {
            let cv: &CacheValue = &cache.cache_value[index];
            unknown_s += cv.s;
            unknown_ss += cv.ss;
            unknown_c += cv.c;
            unknown += 1;
        }

        // Nothing to split on: every sample is unknown (this also covers an
        // empty result, so the tail slice below is never out of range).
        if unknown == sorted_data.len() {
            return sorted_data;
        }

        // The unknown group is the same for every candidate split.
        let fitness0 = group_impurity(unknown_s, unknown_ss, unknown_c);

        // Node totals over all of `train_data`.
        // NOTE(review): `cached` is only read here, never set; unless a caller
        // sets it, the totals are recomputed on every call -- confirm intent.
        if !impurity_cache.cached {
            impurity_cache.sum_s = 0.0;
            impurity_cache.sum_ss = 0.0;
            impurity_cache.sum_c = 0.0;
            for index in train_data.iter() {
                let cv: &CacheValue = &cache.cache_value[*index];
                impurity_cache.sum_s += cv.s;
                impurity_cache.sum_ss += cv.ss;
                impurity_cache.sum_c += cv.c;
            }
        }

        // Left side starts empty; right side starts with every known-valued
        // sample (node totals minus the unknown group).
        let mut ls: f64 = 0.0;
        let mut lss: f64 = 0.0;
        let mut lc: f64 = 0.0;
        let mut rs: f64 = impurity_cache.sum_s - unknown_s;
        let mut rss: f64 = impurity_cache.sum_ss - unknown_ss;
        let mut rc: f64 = impurity_cache.sum_c - unknown_c;

        // Walk neighbouring pairs of known-valued samples, moving the first
        // sample of each pair from the right side to the left side.
        for pair in sorted_data[unknown..].windows(2) {
            let (index, f1) = pair[0];
            let (_, f2) = pair[1];
            let cv: &CacheValue = &cache.cache_value[index];

            ls += cv.s;
            lss += cv.ss;
            lc += cv.c;

            rs -= cv.s;
            rss -= cv.ss;
            rc -= cv.c;

            // No threshold can separate two (almost) equal feature values.
            if almost_equal(f1, f2) {
                continue;
            }

            let fitness: f64 =
                fitness0 + group_impurity(ls, lss, lc) + group_impurity(rs, rss, rc);

            // Strict comparison: the earliest of several equally good splits wins.
            if *impurity > fitness {
                *impurity = fitness;
                *value = (f1 + f2) / 2.0;
            }
        }

        sorted_data
    }