Expr differentiate()

in src/core/symbolic.cpp [913:968]


/// Symbolically differentiates `e` with respect to `sym`.
///
/// Supported shapes of `e`:
///   - any expression that does not contain `sym`  -> Expr(0)
///   - the symbol itself                           -> Expr(1)
///   - binary functions: Op::add, Op::multiply, Op::divide
///   - unary functions:  Op::negate, Op::reciprocal
///
/// For binary operators, operands that do not contain `sym` are treated as
/// constants so that no zero-valued derivative terms are emitted into the
/// result expression.
///
/// @param e    Expression to differentiate.
/// @param sym  Symbol to differentiate with respect to.
/// @return The derivative expression. For any unsupported expression the
///         trailing ASSERT(0) fires; Expr(0) is returned after it as a
///         fallback value.
Expr differentiate(Expr e, Symbol sym) {
  // d/dx of anything independent of x is 0.
  if (!e.contains(sym)) {
    return Expr(0);
  }

  // d/dx x = 1.
  if (e == Expr(sym)) {
    return Expr(1);
  }

  if (e.type() == Expr::Type::function) {
    if (e.args().size() == 2) {
      const auto& a = e.args().at(0);
      const auto& b = e.args().at(1);
      if (e.op() == Op::add) {
        // Sum rule: (a + b)' = a' + b'; constant operands contribute nothing.
        if (a.contains(sym) && !b.contains(sym)) {
          return differentiate(a, sym);
        } else if (b.contains(sym) && !a.contains(sym)) {
          return differentiate(b, sym);
        } else {
          ASSERT(a.contains(sym) && b.contains(sym));
          return differentiate(a, sym) + differentiate(b, sym);
        }
      } else if (e.op() == Op::multiply) {
        // Product rule: (a * b)' = a' * b + b' * a.
        if (a.contains(sym) && !b.contains(sym)) {
          return differentiate(a, sym) * b;
        } else if (b.contains(sym) && !a.contains(sym)) {
          return differentiate(b, sym) * a;
        } else {
          ASSERT(a.contains(sym) && b.contains(sym));
          return differentiate(a, sym) * b + differentiate(b, sym) * a;
        }
      } else if (e.op() == Op::divide) {
        // Quotient rule: (a / b)' = (a' * b - a * b') / (b * b).
        if (a.contains(sym) && !b.contains(sym)) {
          // Constant denominator: a' / b.
          return differentiate(a, sym) / b;
        } else if (b.contains(sym) && !a.contains(sym)) {
          // Constant numerator: the a' term vanishes, leaving the *negative*
          // term -(a * b') / (b * b). The sign must match the general
          // quotient-rule branch below.
          return -(a * differentiate(b, sym)) / (b * b);
        } else {
          ASSERT(a.contains(sym) && b.contains(sym));
          return (differentiate(a, sym) * b - a * differentiate(b, sym)) /
                 (b * b);
        }
      }
    } else if (e.args().size() == 1) {
      const auto& arg = e.args().at(0);
      if (e.op() == Op::negate) {
        // (-u)' = -u'.
        return -differentiate(arg, sym);
      } else if (e.op() == Op::reciprocal) {
        // (1 / u)' = -u' / (u * u).
        return -differentiate(arg, sym) / (arg * arg);
      }
    }
  }

  // Any expression kind / operator not handled above is unsupported.
  ASSERT(0) << "Cannot differentiate " << e.dump() << " with respect to "
            << sym.name();
  return Expr(0);
}