Expr Expr::simplify()

in src/core/symbolic.cpp [318:433]


Expr Expr::simplify() const {
  if (type() != Expr::Type::function) {
    return *this;
  }
  auto sorted_args = args();
  for (auto& arg : sorted_args) {
    arg = arg.simplify();
  }
  switch (op()) {
    case Op::add: {
      auto lhs = sorted_args.at(0);
      auto rhs = sorted_args.at(1);
      if (lhs.type() == Expr::Type::value) {
        if (rhs.type() == Expr::Type::value) {
          return Expr(lhs.value() + rhs.value());
        }
        if (lhs.value() == 0) {
          return rhs;
        }
        if (rhs.op() == Op::add) {
          auto rlhs = rhs.args().at(0);
          auto rrhs = rhs.args().at(1);
          if (rlhs.type() == Expr::Type::value) {
            return (Expr(rlhs.value() + lhs.value()) + rrhs).simplify();
          }
        }
      }
      if (rhs.type() == Expr::Type::value) {
        if (rhs.value() == 0) {
          return lhs;
        }
      }
      return Expr(op(), sorted_args);
    }
    case Op::multiply: {
      auto lhs = sorted_args.at(0);
      auto rhs = sorted_args.at(1);
      if (lhs.type() == Expr::Type::value) {
        if (rhs.type() == Expr::Type::value) {
          return Expr(lhs.value() * rhs.value());
        }
        if (lhs.value() == 1) {
          return rhs;
        }
        if (lhs.value() == 0) {
          return Expr(0);
        }
      }
      if (rhs.type() == Expr::Type::value) {
        if (rhs.value() == 0) {
          return Expr(0);
        }
        if (rhs.value() == 1) {
          return lhs;
        }
      }
      return Expr(op(), sorted_args);
    }
    case Op::divide: {
      auto lhs = sorted_args.at(0);
      auto rhs = sorted_args.at(1);
      if (lhs.type() == Expr::Type::value) {
        if (rhs.type() == Expr::Type::value) {
          if (lhs.value() % rhs.value() == 0) {
            return Expr(lhs.value() / rhs.value());
          }
        }
      }
      if (rhs.type() == Expr::Type::value && rhs.value() == 1) {
        return lhs;
      }
      return Expr(op(), sorted_args);
    }
    case Op::max: {
      auto lhs = sorted_args.at(0);
      auto rhs = sorted_args.at(1);
      if (lhs.type() == Expr::Type::value) {
        if (rhs.type() == Expr::Type::value) {
          return Expr(std::max(lhs.value(), rhs.value()));
        }
        if (lhs.value() == std::numeric_limits<decltype(lhs.value())>::min()) {
          return rhs;
        }
      }
      if (lhs == rhs) {
        return lhs;
      }
      return Expr(op(), sorted_args);
    }
    case Op::negate: {
      const auto& arg = sorted_args.at(0);
      if (arg.type() == Expr::Type::value) {
        return Expr(-arg.value());
      }
      if (arg.type() == Expr::Type::function && arg.op() == Op::negate) {
        return arg.args().at(0).simplify();
      }
      if (arg.type() == Expr::Type::function && arg.op() == Op::add) {
        return (-arg.args().at(0) - arg.args().at(1)).simplify();
      }
      return Expr(op(), sorted_args);
    }
    case Op::size: {
      const auto& arg = sorted_args.at(0);
      if (arg.type() == Expr::Type::value) {
        return Expr(arg.value());
      }
      return Expr(op(), sorted_args);
    }
    default: {
      return Expr(op(), sorted_args);
    }
  };
  ASSERT(0);
  return *this;
}