fn canonicalize_plus_inner()

in amzn-smt-eager-arithmetic/src/canonicalize.rs [59:165]


    /// Canonicalizes the operands of an addition.
    ///
    /// The operands are processed in three steps:
    /// 1. Operands that are themselves `Plus` applications are flattened one level, so their
    ///    arguments are treated as operands of this sum.
    /// 2. Operands recognized by `int_term` as integer constants are removed from the operand
    ///    list and accumulated into two running totals, selected by the boolean that `int_term`
    ///    returns alongside the value.
    /// 3. The remaining operands are passed to `group_same_vars`, which merges repeated
    ///    occurrences of the same variable into a single term.
    ///
    /// Returns `(terms, c_pos, c_neg)`: the non-constant terms, the total of the constants for
    /// which `int_term` reported `true`, and the total of those for which it reported `false`.
    /// `terms` may be empty, e.g. when every operand is a constant or all variables cancel.
    fn canonicalize_plus_inner(
        args: impl IntoIterator<Item = Term>,
    ) -> (SmallVec<[Term; 2]>, Numeral, Numeral) {
        use ArithOp::*;
        use Op::*;

        /// Merges occurrences of the same variable into one term per variable.
        ///
        /// Recognized shapes are `v`, `(- v)`, `(* c v)` and `(* (- c) v)` where `c` is a
        /// numeral constant. Every other term is passed through unchanged, ahead of the grouped
        /// variable terms. Variables whose net coefficient is zero are dropped.
        fn group_same_vars(terms: impl IntoIterator<Item = Term>) -> SmallVec<[Term; 2]> {
            // Per variable: (sum of positive coefficients, sum of negated coefficients).
            // The two tallies are kept apart and only ever subtracted smaller-from-larger below.
            let mut map: BTreeMap<HashOrdered<_>, (Numeral, Numeral)> = BTreeMap::new();
            let mut res = smallvec![];

            for t in terms {
                match t {
                    // Bare variable: coefficient +1.
                    Term::Variable(v) => {
                        let (pos, _) = map.entry(HashOrdered(v)).or_default();
                        *pos += Numeral::from(1u8);
                    }
                    Term::OtherOp(op) => match op.as_ref() {
                        Arith(Neg(t)) => {
                            if let Term::Variable(v) = t {
                                // Negated variable: coefficient -1.
                                let (_, neg) = map.entry(HashOrdered(v.clone())).or_default();
                                *neg += Numeral::from(1u8);
                            } else {
                                res.push(Neg(t.clone()).into());
                            }
                        }
                        Arith(Mul(args)) => match args.as_slice() {
                            // `(* c v)`: coefficient +c.
                            [Term::Constant(c), Term::Variable(v)] => match c.as_ref() {
                                Constant::Numeral(a) => {
                                    let (pos, _) = map.entry(HashOrdered(v.clone())).or_default();
                                    *pos += a;
                                }
                                _ => res.push(op.into()),
                            },
                            // `(* (- c) v)`: coefficient -c.
                            [Term::OtherOp(c), Term::Variable(v)] => match c.as_ref() {
                                Arith(Neg(Term::Constant(c))) => match c.as_ref() {
                                    Constant::Numeral(a) => {
                                        let (_, neg) =
                                            map.entry(HashOrdered(v.clone())).or_default();
                                        *neg += a;
                                    }
                                    _ => res.push(op.into()),
                                },
                                _ => res.push(op.into()),
                            },
                            _ => res.push(op.into()),
                        },
                        _ => res.push(op.into()),
                    },
                    _ => res.push(t),
                }
            }

            use std::cmp::Ordering::*;
            res.extend(
                map.into_iter()
                    .filter_map(|(HashOrdered(v), (pos, neg))| {
                        let term: Term = match pos.cmp(&neg) {
                            // Net coefficient is negative: emit `(- v)` or `(* (- k) v)`.
                            Less => {
                                let x = neg - pos;
                                if x.is_one() {
                                    Canonicalizer::negate(v.into())
                                } else {
                                    let neg_coeff =
                                        Canonicalizer::negate(Constant::Numeral(x).into());
                                    Mul([neg_coeff, v.into()].into()).into()
                                }
                            }
                            // Net coefficient is positive: emit `v` or `(* k v)`.
                            Greater => {
                                let x = pos - neg;
                                if x.is_one() {
                                    v.into()
                                } else {
                                    Mul([Constant::Numeral(x).into(), v.into()].into()).into()
                                }
                            }
                            // Net coefficient is zero: the variable contributes nothing to the
                            // sum, so it must not be emitted (emitting `v` would turn `x - x`
                            // into `x`).
                            Equal => return None,
                        };
                        Some(term)
                    }),
            );
            res
        }

        let mut c_pos = Numeral::zero();
        let mut c_neg = Numeral::zero();
        let args = args
            .into_iter()
            // flatten nested plus operators (one level only)
            .flat_map(|arg| {
                if let Term::OtherOp(op) = &arg {
                    if let Arith(Plus(args)) = op.as_ref() {
                        return Either::Left(args.clone().into_iter());
                    }
                }
                Either::Right(iter::once(arg))
            })
            // pull out constants, accumulating them into `c_pos` / `c_neg` as a side effect
            .filter(|arg| {
                if let Some((x, pos)) = int_term(arg) {
                    if pos {
                        c_pos += x;
                    } else {
                        c_neg += x;
                    }
                    false
                } else {
                    true
                }
            });
        // The iterator is lazy: the constant totals are only complete once `group_same_vars`
        // has consumed `args`, which happens before `c_pos` / `c_neg` are read here.
        (group_same_vars(args), c_pos, c_neg)
    }