in src/core/symbolic.cpp [913:968]
// Returns the symbolic derivative of `e` with respect to `sym`.
//
// Handled forms:
//   * an expression that does not contain `sym`      -> 0
//   * the symbol itself                              -> 1
//   * binary functions:  add, multiply, divide       (sum/product/quotient rule)
//   * unary functions:   negate, reciprocal
//
// For binary operations the operand that does not contain `sym` is treated as
// a constant, so no derivative term is built for it; this keeps the resulting
// expression free of explicit "+ 0" / "* 0" terms.
//
// Any other expression that contains `sym` is reported through ASSERT and the
// function then returns Expr(0) (only reachable if ASSERT does not abort —
// NOTE(review): confirm the ASSERT macro's behavior in release builds).
Expr differentiate(Expr e, Symbol sym) {
  // d/dx of anything independent of x is zero.
  if (!e.contains(sym)) {
    return Expr(0);
  }
  // d/dx x = 1.
  if (e == Expr(sym)) {
    return Expr(1);
  }
  if (e.type() == Expr::Type::function) {
    if (e.args().size() == 2) {
      const auto& a = e.args().at(0);
      const auto& b = e.args().at(1);
      // Evaluate the dependence of each operand once; every branch below
      // dispatches on these two flags.
      const bool a_dep = a.contains(sym);
      const bool b_dep = b.contains(sym);
      if (e.op() == Op::add) {
        // (a + b)' = a' + b'
        if (a_dep && !b_dep) {
          return differentiate(a, sym);
        } else if (b_dep && !a_dep) {
          return differentiate(b, sym);
        } else {
          ASSERT(a_dep && b_dep);
          return differentiate(a, sym) + differentiate(b, sym);
        }
      } else if (e.op() == Op::multiply) {
        // (a * b)' = a' * b + b' * a
        if (a_dep && !b_dep) {
          return differentiate(a, sym) * b;
        } else if (b_dep && !a_dep) {
          return differentiate(b, sym) * a;
        } else {
          ASSERT(a_dep && b_dep);
          return differentiate(a, sym) * b + differentiate(b, sym) * a;
        }
      } else if (e.op() == Op::divide) {
        // (a / b)' = (a' * b - a * b') / b^2
        if (a_dep && !b_dep) {
          // Constant denominator: a' / b.
          return differentiate(a, sym) / b;
        } else if (b_dep && !a_dep) {
          // Constant numerator: the a' * b term vanishes, leaving
          // -a * b' / b^2. The leading minus sign is required.
          return -(a * differentiate(b, sym)) / (b * b);
        } else {
          ASSERT(a_dep && b_dep);
          return (differentiate(a, sym) * b - a * differentiate(b, sym)) /
                 (b * b);
        }
      }
    } else if (e.args().size() == 1) {
      const auto& arg = e.args().at(0);
      if (e.op() == Op::negate) {
        // (-u)' = -u'
        return -differentiate(arg, sym);
      } else if (e.op() == Op::reciprocal) {
        // (1 / u)' = -u' / u^2
        return -differentiate(arg, sym) / (arg * arg);
      }
    }
  }
  // Unsupported expression kind, operator, or arity.
  ASSERT(0) << "Cannot differentiate " << e.dump() << " with respect to "
            << sym.name();
  return Expr(0);
}