void handleBlock()

in torch/csrc/jit/passes/autocast.cpp [230:451]


void handleBlock(Block* block, AutocastContext initial_state) {
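  // Nested `with autocast(...)` scopes that are currently open while walking
  // this block; pushed on prim::Enter and popped on prim::Exit below, so the
  // innermost scope determines the effective autocast state.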
  std::stack<AutocastScope> autocast_stack;

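  // Tri-state flag used to reject unsupported mixes of explicit autocast
  // scopes and function/method calls made while the local autocast state
  // differs from `initial_state`:
  //   nullopt -> neither pattern seen yet
  //   true    -> such a call was seen (autocast scopes are then disallowed)
  //   false   -> an autocast scope was seen (such calls are then disallowed)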
  c10::optional<bool> incompatible_amp = c10::nullopt;

  // The effective autocast state at the current point: the innermost open
  // autocast scope, or the state inherited from the enclosing block
  auto current_state = [&] {
    return autocast_stack.empty() ? initial_state
                                  : autocast_stack.top().context;
  };

  for (Node* node : block->nodes()) {
    switch (node->kind()) {
      case prim::CallFunction:
        // TODO: limit this to AMP-related nodes
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
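        // The call happens while the local autocast state differs from the
        // block's initial state; the pass cannot propagate autocasting into
        // the callee, so flag the graph and reject it if an explicit autocast
        // scope was already encountered (see prim::Enter/prim::Exit below).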
        TORCH_INTERNAL_ASSERT(
            !incompatible_amp.has_value() || incompatible_amp.value(),
            "Calls are not expected with AMP & JIT");
        incompatible_amp = true;
        break;

      case prim::CallMethod:
        // TODO: limit this to AMP-related nodes
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
          const auto& name = node->s(attr::name);
          const auto& function = class_type->getMethod(name);
          if (!function.isGraphFunction()) {
            TORCH_INTERNAL_ASSERT(
                !incompatible_amp.has_value() || incompatible_amp.value(),
                "Calls are not expected with AMP & JIT");
            incompatible_amp = true;
          }
        } else {
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || incompatible_amp.value(),
              "Unexpected prim::CallMethod form with AMP & JIT");
          incompatible_amp = true;
        }
        break;

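      // prim::Enter marks entry into a `with` statement; parseAutocast()
      // yields a scope to push only when the context manager fed to the node
      // is an autocast instance.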
      case prim::Enter:
        if (auto autocast_scope =
                parseAutocast(node->input(), current_state())) {
          if (node->hasUses()) {
            // TODO: better error message
            AT_ERROR("`with autocast() as ...` is not supported");
          }
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.push(*autocast_scope);
        }
        break;

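      // prim::Exit closes the innermost `with` scope; if it belongs to an
      // autocast instance, pop the matching scope pushed by prim::Enter.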
      case prim::Exit:
        if (isAutocastNode(node->input(0))) {
          TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
          TORCH_INTERNAL_ASSERT(autocast_stack.top().instance == node->input());
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.pop();
        }
        break;

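      // The op lists below mirror the eager-mode autocast cast policies (see
      // the CastPolicy comments); in-place/mutating schemas are left untouched
      // since casting their inputs could break the in-place semantics.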
      // CastPolicy::fp16 (cast all inputs to float16)
      case aten::_convolution:
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
      case aten::conv_tbc:
      case aten::conv_transpose1d:
      case aten::convolution:
      case aten::cudnn_convolution:
      case aten::cudnn_convolution_transpose:
      case aten::prelu:
      case aten::addmm:
      case aten::addmv:
      case aten::addr:
      case aten::matmul:
      case aten::mm:
      case aten::mv:
      case aten::linear:
      case aten::addbmm:
      case aten::baddbmm:
      case aten::bmm:
      case aten::chain_matmul:
      case aten::_thnn_fused_lstm_cell:
      case aten::_thnn_fused_gru_cell:
      case aten::lstm_cell:
      case aten::gru_cell:
      case aten::rnn_tanh_cell:
      case aten::rnn_relu_cell:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_reduced_precision, current_state());
        }
        break;

      // CastPolicy::fp32 (cast all inputs to float32)
      case aten::native_layer_norm:
      case aten::acos:
      case aten::asin:
      case aten::cosh:
      case aten::erfinv:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log10:
      case aten::log2:
      case aten::log1p:
      case aten::reciprocal:
      case aten::rsqrt:
      case aten::sinh:
      case aten::tan:
      case aten::pow:
      case aten::softplus:
      case aten::gelu:
      case aten::layer_norm:
      case aten::group_norm:
      case aten::frobenius_norm:
      case aten::nuclear_norm:
      case aten::cosine_similarity:
      case aten::cosine_embedding_loss:
      case aten::nll_loss:
      case aten::nll_loss2d:
      case aten::hinge_embedding_loss:
      case aten::kl_div:
      case aten::l1_loss:
      case aten::smooth_l1_loss:
      case aten::mse_loss:
      case aten::margin_ranking_loss:
      case aten::multilabel_margin_loss:
      case aten::soft_margin_loss:
      case aten::triplet_margin_loss:
      case aten::multi_margin_loss:
      case aten::binary_cross_entropy_with_logits:
      case aten::dist:
      case aten::pdist:
      case aten::cdist:
      case aten::renorm:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // CastPolicy::fp32_set_opt_dtype
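      // (inputs are cast to fp32 only when no explicit dtype argument was
      // supplied, so a caller-requested dtype always takes precedence)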
      case aten::prod:
      case aten::softmax:
      case aten::log_softmax:
      case aten::cumprod:
      case aten::cumsum:
      case aten::sum:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // CastPolicy::promote (promote inputs to the widest type)
      case aten::addcdiv:
      case aten::addcmul:
      case aten::atan2:
      case aten::bilinear:
      case aten::cat:
      case aten::_cat:
      case aten::cross:
      case aten::dot:
      case aten::equal:
      case aten::index_put:
      case aten::stack:
      case aten::tensordot:
      // add, sub, mul, div were added to JIT autocast because ATen's implicit
      // type promotion is not visible to the JIT and could cause a dtype
      // mismatch in the backward pass
      // see [Note: implicit type promotion in Autocast]
      case aten::add:
      case aten::sub:
      case aten::mul:
      case aten::div:
        if (!node->schema().is_mutable()) {
          castInputsToWidestType(node, current_state());
        }
        break;

      // Banned in autocast, see binary_cross_entropy_banned()
      case aten::binary_cross_entropy:
        AT_ERROR("Unsafe to autocast");
    }

    // process sub-blocks, if any
    for (Block* sub_block : node->blocks()) {
      handleBlock(sub_block, current_state());
    }
  }

  // Sanity check: every autocast scope entered in this block must also have
  // been exited (no unbalanced prim::Enter/prim::Exit)
  TORCH_INTERNAL_ASSERT(autocast_stack.empty());
}
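
A minimal driver sketch for context, separate from the pass itself: it assumes
the entry point Autocast() and the setAutocastMode() toggle declared in
torch/csrc/jit/passes/autocast.h and the parseIR() helper from
torch/csrc/jit/ir/irparser.h; the function name runAutocastSketch is purely
illustrative. Note that handleBlock() only inserts
aten::_autocast_to_reduced_precision / aten::_autocast_to_full_precision casts
when autocasting is active, either via the ambient eager autocast state
captured in initial_state or via explicit autocast prim::Enter/prim::Exit
scopes in the graph, so this sketch may leave the graph unchanged.

#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/autocast.h>

// Illustrative sketch: parse a tiny graph containing aten::mm (one of the
// CastPolicy::fp16 ops above) and run the autocast pass over it.
void runAutocastSketch() {
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(
      R"IR(
graph(%a : Tensor, %b : Tensor):
  %c : Tensor = aten::mm(%a, %b)
  return (%c)
)IR",
      graph.get());

  // Enable the JIT autocast pass, then run it; the pass is expected to call
  // handleBlock() on graph->block() with the ambient autocast state.
  torch::jit::setAutocastMode(true);
  torch::jit::Autocast(graph);

  // Print the (possibly rewritten) graph; with autocasting active, the inputs
  // of aten::mm would appear wrapped in aten::_autocast_to_reduced_precision.
  graph->dump();
}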