in torch/csrc/jit/passes/guard_elimination.cpp [283:454]
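// removableGuard(n) answers: is a prim::Guard placed on n's output
// redundant? It is when n's output shape is fully determined by its
// already-guarded tensor inputs and its constant arguments, so the
// elimination pass in this file (which calls this on the node feeding a
// guard) can erase that guard. Schematically:
//   %a : Tensor = prim::Guard(%x)
//   %b : Tensor = prim::Guard(%y)
//   %c : Tensor = aten::mul(%a, %b)
//   %d : Tensor = prim::Guard(%c)
// the guard producing %d can go: %c's shape follows from the guarded
// shapes of %a and %b. The helper checkInputs(n, except, allow_numbers)
// is defined earlier in this file; as its uses below suggest, it accepts
// an input that is produced by a non-summarized prim::Guard, is a
// prim::Constant, is a scalar Number when the bool flag is true, or has
// its index listed in the exception set. isSummarized() flags a
// TensorType whose shape information is incomplete and thus cannot
// justify removing a guard.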
bool removableGuard(Node* n) {
const static auto no_exceptions = std::unordered_set<size_t>{};
switch (n->kind()) {
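// Pointwise and other shape-transparent ops: given guarded inputs, the
// output shape is fully inferable, and scalar Number arguments (e.g.
// alpha in aten::add) cannot change it, so Numbers are allowed here.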
case aten::add:
case aten::add_:
case aten::sub:
case aten::mul:
case aten::div:
case aten::t:
case aten::sigmoid:
case aten::sin:
case aten::cos:
case aten::tan:
case aten::sinh:
case aten::cosh:
case aten::tanh:
case aten::asin:
case aten::acos:
case aten::atan:
case aten::atan2:
case aten::floor:
case aten::fmod:
case aten::ceil:
case aten::trunc:
case aten::sqrt:
case aten::rsqrt:
case aten::remainder:
case aten::mm:
case aten::min:
case aten::max:
case aten::type_as:
case aten::ge:
case aten::gt:
case aten::lt:
case aten::le:
case aten::eq:
case aten::ne:
case aten::neg:
case prim::ConstantChunk:
case aten::size:
case aten::abs:
case aten::sign:
case aten::pow:
case aten::relu:
case aten::threshold:
case prim::AutogradAdd:
case prim::AutogradZero:
case aten::rand_like:
case aten::erf:
case aten::erfc:
case aten::exp:
case aten::expm1:
case aten::log:
case aten::log2:
case aten::log10:
case aten::frac:
case aten::lerp:
case aten::lgamma:
case aten::reciprocal:
case aten::addcmul:
case aten::where:
case aten::_cast_Float:
case aten::_cast_Long:
case aten::__and__:
case aten::__or__:
case aten::__xor__:
case aten::__lshift__:
case aten::__rshift__:
case aten::bitwise_not:
case aten::bitwise_and:
case aten::bitwise_or:
case aten::bitwise_xor:
return checkInputs(n, no_exceptions, true);
case aten::softmax:
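// softmax preserves the input shape, so the dim argument (input 1)
// is exempt from the input check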
return checkInputs(n, std::unordered_set<size_t>{1}, true);
case aten::multinomial:
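// the output shape depends on num_samples (input 1), which must be
// constant; replacement (input 2) and generator (input 3) do not
// affect it and are exempt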
return checkInputs(n, std::unordered_set<size_t>{2, 3}, false);
case aten::flatten:
case aten::argmax:
case aten::squeeze:
case aten::avg_pool2d:
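// no exceptions: every argument must be a guarded tensor or a constant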
return checkInputs(n, no_exceptions, false);
case aten::conv1d:
case aten::conv2d:
case aten::conv3d:
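// the optional bias (input 2) and groups (input 6) do not affect the
// output shape and are exempt; stride, padding, and dilation must be
// constants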
return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
case aten::slice:
return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
// check that the dimension argument is constant
n->input(1)->node()->kind() == prim::Constant &&
// check that the start offset is constant
n->input(2)->node()->kind() == prim::Constant &&
// check that the end offset is constant
n->input(3)->node()->kind() == prim::Constant &&
// check that the stride is constant
n->input(4)->node()->kind() == prim::Constant;
case aten::max_pool1d:
case aten::max_pool2d:
case aten::max_pool3d:
return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
// check that the kernel size is constant
n->input(1)->node()->kind() == prim::Constant &&
// check that the stride is constant
n->input(2)->node()->kind() == prim::Constant &&
// check that the padding is constant
n->input(3)->node()->kind() == prim::Constant &&
// check that the dilation is constant
n->input(4)->node()->kind() == prim::Constant &&
// check that the ceil_mode is constant
n->input(5)->node()->kind() == prim::Constant;
case aten::unsqueeze:
// check that the dimension argument is constant
return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
n->input(1)->node()->kind() == prim::Constant;
case aten::cat:
// check that the dimension argument is constant
return n->input(1)->node()->kind() == prim::Constant &&
n->input(0)->node()->kind() == prim::ListConstruct &&
// no extra nodes in between aten::cat and prim::ListConstruct
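// (presumably so no other node can mutate the list before aten::cat
// consumes it)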
n->prev() == n->input(0)->node() &&
// check the inputs to prim::ListConstruct (not aten::cat)
checkInputs(n->input(0)->node(), no_exceptions, false);
case aten::clamp:
// min (input 1) and max (input 2) do not affect the output shape
return checkInputs(n, std::unordered_set<size_t>{1, 2}, false);
case aten::_grad_sum_to_size:
// skip checking the size argument (input 1); it is validated below
if (checkInputs(n, std::unordered_set<size_t>{1}, false)) {
auto asize = n->input(1)->node();
if (asize->kind() == prim::Constant) {
return true;
} else if (asize->matches("aten::size(Tensor self) -> int[]")) {
// aten::size on a tensor whose concrete sizes are statically known
// is effectively a constant
if (asize->input()
->type()
->expectRef<TensorType>()
.sizes()
.concrete_sizes()) {
return true;
}
}
}
return false;
// this is checked by one of the tests in test_jit_fuser.py
case prim::ListUnpack: {
// check that the input is an aten::chunk with guarded/constant inputs
// (this pattern shows up in LSTM fusions)
auto chunk = n->input(0)->node();
if (chunk->kind() != aten::chunk) {
return false;
}
return checkInputs(chunk, no_exceptions, false);
}
// this is checked by one of the tests in test_jit_fuser.py
case aten::broadcast_tensors: {
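// the tensor list must be built by a prim::ListConstruct whose
// inputs are all guarded tensors or constants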
auto list_construct = n->input(0)->node();
if (list_construct->kind() != prim::ListConstruct) {
return false;
}
return checkInputs(list_construct, no_exceptions, false);
}
// after some optimizations we might end up with two Guards back-to-back,
// in which case we can remove the guard whose input is also a prim::Guard
case prim::Guard:
case prim::GradOf:
return true;
default:
GRAPH_DEBUG("cannot remove ", n->kind().toQualString());
return false;
}
}