bool removableGuard(Node* n)

in torch/csrc/jit/passes/guard_elimination.cpp [283:454]


  // Decides, from the kind of node `n`, whether this pass may treat the guard
  // associated with `n` as removable. (Presumably `n` is the node producing the
  // value a guard checks -- confirm at the call sites.)
  //
  // Each `case` encodes the precondition for one family of operators:
  //  * most ops only require that their inputs pass `checkInputs`, optionally
  //    skipping some input positions (the "exception" sets below) and with the
  //    boolean third argument set per operator;
  //  * aten::slice, aten::max_pool{1,2,3}d and aten::unsqueeze additionally
  //    require that input 0's TensorType is not summarized and that certain
  //    scalar arguments are produced by prim::Constant nodes;
  //  * aten::cat, prim::ListUnpack, aten::broadcast_tensors and
  //    aten::_grad_sum_to_size look through to the node producing an input.
  // Any kind not listed is reported as not removable (and logged via
  // GRAPH_DEBUG).
  bool removableGuard(Node* n) {
    // Input positions that checkInputs should skip, per operator. Kept as
    // function-local statics so they are built once rather than allocated as
    // temporaries on every call.
    static const std::unordered_set<size_t> no_exceptions{};
    // aten::softmax: input 1 (presumably `dim` -- confirm against the schema).
    static const std::unordered_set<size_t> softmax_exceptions{1};
    // aten::multinomial: inputs 2 and 3 (presumably `replacement` and
    // `generator` -- confirm against the schema).
    static const std::unordered_set<size_t> multinomial_exceptions{2, 3};
    // aten::conv{1,2,3}d: inputs 2 and 6 (presumably the optional `bias` and
    // the integer `groups` -- confirm against the schema).
    static const std::unordered_set<size_t> conv_exceptions{2, 6};
    // aten::clamp: the second and third args do not affect shapes.
    static const std::unordered_set<size_t> clamp_exceptions{1, 2};
    // aten::_grad_sum_to_size: the size argument is validated separately below.
    static const std::unordered_set<size_t> grad_sum_to_size_exceptions{1};

    // True when input `i` of `n` is produced by a prim::Constant node.
    const auto inputIsConstant = [n](size_t i) {
      return n->input(i)->node()->kind() == prim::Constant;
    };
    // True when the TensorType of input 0 of `n` is summarized.
    const auto selfTypeIsSummarized = [n]() {
      return n->input(0)->type()->expectRef<TensorType>().isSummarized();
    };

    switch (n->kind()) {
      case aten::add:
      case aten::add_:
      case aten::sub:
      case aten::mul:
      case aten::div:
      case aten::t:
      case aten::sigmoid:
      case aten::sin:
      case aten::cos:
      case aten::tan:
      case aten::sinh:
      case aten::cosh:
      case aten::tanh:
      case aten::asin:
      case aten::acos:
      case aten::atan:
      case aten::atan2:
      case aten::floor:
      case aten::fmod:
      case aten::ceil:
      case aten::trunc:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::remainder:
      case aten::mm:
      case aten::min:
      case aten::max:
      case aten::type_as:
      case aten::ge:
      case aten::gt:
      case aten::lt:
      case aten::le:
      case aten::eq:
      case aten::ne:
      case aten::neg:
      case prim::ConstantChunk:
      case aten::size:
      case aten::abs:
      case aten::sign:
      case aten::pow:
      case aten::relu:
      case aten::threshold:
      case prim::AutogradAdd:
      case prim::AutogradZero:
      case aten::rand_like:
      case aten::erf:
      case aten::erfc:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log2:
      case aten::log10:
      case aten::frac:
      case aten::lerp:
      case aten::lgamma:
      case aten::reciprocal:
      case aten::addcmul:
      case aten::where:
      case aten::_cast_Float:
      case aten::_cast_Long:
      case aten::__and__:
      case aten::__or__:
      case aten::__xor__:
      case aten::__lshift__:
      case aten::__rshift__:
      case aten::bitwise_not:
      case aten::bitwise_and:
      case aten::bitwise_or:
      case aten::bitwise_xor:
        return checkInputs(n, no_exceptions, true);
      case aten::softmax:
        return checkInputs(n, softmax_exceptions, true);
      case aten::multinomial:
        return checkInputs(n, multinomial_exceptions, false);
      case aten::flatten:
      case aten::argmax:
      case aten::squeeze:
      case aten::avg_pool2d:
        return checkInputs(n, no_exceptions, false);
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
        return checkInputs(n, conv_exceptions, false);
      case aten::slice:
        return !selfTypeIsSummarized() &&
            inputIsConstant(1) && // the dimension argument
            inputIsConstant(2) && // the start offset
            inputIsConstant(3) && // the end offset
            inputIsConstant(4); // the stride
      case aten::max_pool1d:
      case aten::max_pool2d:
      case aten::max_pool3d:
        return !selfTypeIsSummarized() &&
            inputIsConstant(1) && // the kernel size
            inputIsConstant(2) && // the stride
            inputIsConstant(3) && // the padding
            inputIsConstant(4) && // the dilation
            inputIsConstant(5); // the ceil_mode
      case aten::unsqueeze:
        // the dimension argument must be constant
        return !selfTypeIsSummarized() && inputIsConstant(1);
      case aten::cat:
        // the dimension argument must be constant
        return inputIsConstant(1) &&
            n->input(0)->node()->kind() == prim::ListConstruct &&
            // no extra nodes in between aten::cat and prim::ListConstruct
            n->prev() == n->input(0)->node() &&
            // check the inputs to prim::ListConstruct (not aten::cat)
            checkInputs(n->input(0)->node(), no_exceptions, false);
      case aten::clamp:
        return checkInputs(n, clamp_exceptions, false);
      case aten::_grad_sum_to_size: {
        // skip checking the size argument here; it is validated below
        if (!checkInputs(n, grad_sum_to_size_exceptions, false)) {
          return false;
        }
        auto asize = n->input(1)->node();
        if (asize->kind() == prim::Constant) {
          return true;
        }
        // aten::size is effectively a constant when the sizes of its tensor
        // input are fully concrete
        if (asize->matches("aten::size(Tensor self) -> int[]") &&
            asize->input()
                ->type()
                ->expectRef<TensorType>()
                .sizes()
                .concrete_sizes()) {
          return true;
        }
        return false;
      }

      // this is checked by one of the tests in test_jit_fuser.py
      case prim::ListUnpack: {
        // used for LSTM fusions: the unpacked list must be produced by an
        // aten::chunk node whose own inputs pass checkInputs
        auto chunk = n->input(0)->node();
        return chunk->kind() == aten::chunk &&
            checkInputs(chunk, no_exceptions, false);
      }
      // this is checked by one of the tests in test_jit_fuser.py
      case aten::broadcast_tensors: {
        // the tensor list must come directly from a prim::ListConstruct whose
        // inputs pass checkInputs
        auto list_construct = n->input(0)->node();
        return list_construct->kind() == prim::ListConstruct &&
            checkInputs(list_construct, no_exceptions, false);
      }
      // after some optimizations we might end up with two Guards back-to-back,
      // in which case we can remove the one whose input is also prim::Guard
      case prim::Guard:
      case prim::GradOf:
        return true;
      default:
        GRAPH_DEBUG("cannot remove ", n->kind().toQualString());
        return false;
    }
  }