in torch/csrc/jit/passes/autocast.cpp [230:451]
void handleBlock(Block* block, AutocastContext initial_state) {
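// Tracks the autocast scopes entered (prim::Enter) but not yet exited
// within this block, innermost scope on top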
std::stack<AutocastScope> autocast_stack;
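// Tri-state flag: unset until we see either a call that can't be handled
// under a locally changed autocast state (set to true) or an explicit
// autocast scope (set to false); the asserts below ensure the two never mix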
c10::optional<bool> incompatible_amp = c10::nullopt;
// The autocast context currently in effect (innermost scope, or the block's
// initial state)
auto current_state = [&] {
return autocast_stack.empty() ? initial_state
: autocast_stack.top().context;
};
for (Node* node : block->nodes()) {
switch (node->kind()) {
case prim::CallFunction:
// TODO: limit this to AMP-related nodes only
if (current_state() == initial_state) {
// if the current autocasting state is the same as the global state,
// then autocasting will be done correctly on subsequent method and
// function calls
if (current_state()) {
castTensorInputs(
node, aten::_autocast_to_full_precision, current_state());
}
break;
}
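// The call happens under an autocast state that differs from the global
// one, so (per the comment above) autocasting inside the callee would not
// be handled correctly; flag the graph as incompatible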
TORCH_INTERNAL_ASSERT(
!incompatible_amp.has_value() || incompatible_amp.value(),
"Calls are not expected with AMP & JIT");
incompatible_amp = true;
break;
case prim::CallMethod:
// TODO: limit this to AMP-related nodes only
if (current_state() == initial_state) {
// if the current autocasting state is the same as the global state,
// then autocasting will be done correctly on subsequent method and
// function calls
if (current_state()) {
castTensorInputs(
node, aten::_autocast_to_full_precision, current_state());
}
break;
}
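// Only calls to non-graph (builtin) methods, or call forms we don't
// recognize, are flagged as incompatible here; scripted graph methods are
// left alone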
if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
const auto& name = node->s(attr::name);
const auto& function = class_type->getMethod(name);
if (!function.isGraphFunction()) {
TORCH_INTERNAL_ASSERT(
!incompatible_amp.has_value() || incompatible_amp.value(),
"Calls are not expected with AMP & JIT");
incompatible_amp = true;
}
} else {
TORCH_INTERNAL_ASSERT(
!incompatible_amp.has_value() || incompatible_amp.value(),
"Unexpected prim::CallMethod form with AMP & JIT");
incompatible_amp = true;
}
break;
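// prim::Enter marks entry into a `with` statement; parseAutocast() yields
// a scope only when the context manager is an autocast instance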
case prim::Enter:
if (auto autocast_scope =
parseAutocast(node->input(), current_state())) {
if (node->hasUses()) {
// TODO: better error message
AT_ERROR("`with autocast() as ...` is not supported");
}
TORCH_INTERNAL_ASSERT(
!incompatible_amp.has_value() || !incompatible_amp.value(),
"Unsupported case by AMP & JIT");
incompatible_amp = false;
autocast_stack.push(*autocast_scope);
}
break;
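// prim::Exit closes the innermost `with` block; pop the matching autocast
// scope if this Exit belongs to an autocast instance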
case prim::Exit:
if (isAutocastNode(node->input(0))) {
TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
TORCH_INTERNAL_ASSERT(autocast_stack.top().instance == node->input());
TORCH_INTERNAL_ASSERT(
!incompatible_amp.has_value() || !incompatible_amp.value(),
"Unsupported case by AMP & JIT");
incompatible_amp = false;
autocast_stack.pop();
}
break;
// CastPolicy::fp16 (cast all inputs to the reduced precision dtype, fp16 by default)
case aten::_convolution:
case aten::conv1d:
case aten::conv2d:
case aten::conv3d:
case aten::conv_tbc:
case aten::conv_transpose1d:
case aten::convolution:
case aten::cudnn_convolution:
case aten::cudnn_convolution_transpose:
case aten::prelu:
case aten::addmm:
case aten::addmv:
case aten::addr:
case aten::matmul:
case aten::mm:
case aten::mv:
case aten::linear:
case aten::addbmm:
case aten::baddbmm:
case aten::bmm:
case aten::chain_matmul:
case aten::_thnn_fused_lstm_cell:
case aten::_thnn_fused_gru_cell:
case aten::lstm_cell:
case aten::gru_cell:
case aten::rnn_tanh_cell:
case aten::rnn_relu_cell:
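// In-place (mutating) overloads are skipped: casting their inputs would
// redirect the write to a temporary rather than the original tensor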
if (!node->schema().is_mutable()) {
castTensorInputs(
node, aten::_autocast_to_reduced_precision, current_state());
}
break;
// CastPolicy::fp32 (cast all inputs to float32)
case aten::native_layer_norm:
case aten::acos:
case aten::asin:
case aten::cosh:
case aten::erfinv:
case aten::exp:
case aten::expm1:
case aten::log:
case aten::log10:
case aten::log2:
case aten::log1p:
case aten::reciprocal:
case aten::rsqrt:
case aten::sinh:
case aten::tan:
case aten::pow:
case aten::softplus:
case aten::gelu:
case aten::layer_norm:
case aten::group_norm:
case aten::frobenius_norm:
case aten::nuclear_norm:
case aten::cosine_similarity:
case aten::cosine_embedding_loss:
case aten::nll_loss:
case aten::nll_loss2d:
case aten::hinge_embedding_loss:
case aten::kl_div:
case aten::l1_loss:
case aten::smooth_l1_loss:
case aten::mse_loss:
case aten::margin_ranking_loss:
case aten::multilabel_margin_loss:
case aten::soft_margin_loss:
case aten::triplet_margin_loss:
case aten::multi_margin_loss:
case aten::binary_cross_entropy_with_logits:
case aten::dist:
case aten::pdist:
case aten::cdist:
case aten::renorm:
if (!node->schema().is_mutable()) {
castTensorInputs(
node, aten::_autocast_to_full_precision, current_state());
}
break;
// CastPolicy::fp32_set_opt_dtype
case aten::prod:
case aten::softmax:
case aten::log_softmax:
case aten::cumprod:
case aten::cumsum:
case aten::sum:
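// These ops take an optional dtype argument; an explicitly requested dtype
// takes precedence, so only cast when no dtype was supplied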
if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
castTensorInputs(
node, aten::_autocast_to_full_precision, current_state());
}
break;
// CastPolicy::promote (promote inputs to the widest type)
case aten::addcdiv:
case aten::addcmul:
case aten::atan2:
case aten::bilinear:
case aten::cat:
case aten::_cat:
case aten::cross:
case aten::dot:
case aten::equal:
case aten::index_put:
case aten::stack:
case aten::tensordot:
// add, sub, mul and div are included in JIT autocast because ATen's
// implicit type promotion is not visible to the JIT and could cause a
// dtype mismatch in the backward pass
// see [Note: implicit type promotion in Autocast]
case aten::add:
case aten::sub:
case aten::mul:
case aten::div:
if (!node->schema().is_mutable()) {
castInputsToWidestType(node, current_state());
}
break;
// Banned in autocast, see binary_cross_entropy_banned()
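// (the safe alternative, binary_cross_entropy_with_logits, is cast to
// fp32 by the policy above)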
case aten::binary_cross_entropy:
AT_ERROR("Unsafe to autocast");
}
// process sub-blocks, if any
for (Block* sub_block : node->blocks()) {
handleBlock(sub_block, current_state());
}
}
// Sanity check: make sure there's no unbalanced transition
TORCH_INTERNAL_ASSERT(autocast_stack.empty());
}