void PropagateOnNode()

in torch/csrc/jit/codegen/cuda/type_inference.cpp [81:514]


  void PropagateOnNode(Node* node) {
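    // Propagates tensor type metadata (scalar type, device, and shape where
    // known) from the node's inputs to its outputs, dispatching on the node
    // kind. The helpers used below (unary_type, unary_float_type, binary_type,
    // binary_broadcast_type, getInputTensorType, ...) are defined elsewhere in
    // this file, outside of this excerpt.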
    switch (node->kind()) {
      // Constant:
      case prim::Constant: {
        if (node->output()->type()->isSubtypeOf(TensorType::get())) {
          node->output()->inferTypeFrom(node->t(attr::value));
        }
        break;
      }
      // unary operations
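      // unary_type() presumably forwards the input tensor's metadata to the
      // output without changing the scalar type (integral inputs stay
      // integral).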
      case aten::threshold:
      case aten::clamp:
      case aten::abs:
      case aten::neg:
      case aten::ceil:
      case aten::floor:
      case aten::round:
      case aten::trunc:
      case aten::frac:
      case aten::relu:
      case aten::silu:
      case aten::gelu:
      case aten::softplus:
      case aten::bitwise_not:
      // TODO: rand_like should support cast.
      case aten::rand_like: {
        node->output()->setType(unary_type(node));
        break;
      }
      // unary float operations
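      // unary_float_type() presumably promotes integral/bool inputs to the
      // default floating point dtype, mirroring eager-mode behavior of these
      // ops.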
      case aten::log:
      case aten::log10:
      case aten::log1p:
      case aten::log2:
      case aten::lgamma:
      case aten::exp:
      case aten::expm1:
      case aten::erf:
      case aten::erfc:
      case aten::cos:
      case aten::acos:
      case aten::cosh:
      case aten::sin:
      case aten::asin:
      case aten::sinh:
      case aten::tan:
      case aten::atan:
      case aten::atanh:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::reciprocal:
      case aten::sigmoid:
      case aten::tanh: {
        node->output()->setType(unary_float_type(node));
        break;
      }
      // binary float
      case aten::atan2: {
        node->output()->setType(binary_float_type(node));
        break;
      }
      // binary operations that forward meta info and broadcast shape:
      case aten::gelu_backward:
      case aten::mul:
      case aten::div:
      case aten::min:
      case aten::max:
      // TODO: first operand for pow can be Tensor / Scalar
      case aten::pow:
      case aten::remainder:
      case aten::threshold_backward:
      case aten::fmod:
      case aten::lerp:
      // add/sub can take a third (alpha) argument, which contributes to
      // neither type promotion nor shape.
      // TODO: Include alpha check for add/sub
      case aten::add:
      case aten::sub: {
        node->output()->setType(binary_type(node));
        break;
      }
      // "and" and "or" accept int or bool tensors: if both inputs are bool the
      // result is bool, if both are int the result is int; any other
      // combination would already have errored out.
      case aten::__and__:
      case aten::__or__: {
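        // binary_broadcast_type(lhs, rhs[, dtype]) presumably merges the two
        // input TensorTypes (broadcasting their shapes) and, when a dtype is
        // given, overrides the promoted scalar type with it.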
        const auto promoted_type = binary_broadcast_type(
            getInputTensorType(node, 0, true),
            getInputTensorType(node, 1, true),
            node->input(0)->type()->cast<TensorType>()->scalarType() ==
                    at::ScalarType::Bool
                ? at::ScalarType::Bool
                : at::ScalarType::Int);
        node->output()->setType(promoted_type);
        break;
      }
      // Real int ops
      case aten::__xor__:
      case aten::__lshift__:
      case aten::__rshift__: {
        const auto promoted_type = binary_broadcast_type(
            getInputTensorType(node, 0, true),
            getInputTensorType(node, 1, true),
            at::ScalarType::Int);
        node->output()->setType(promoted_type);
        break;
      }
      // binary comparison
      case aten::lt:
      case aten::le:
      case aten::gt:
      case aten::ge:
      case aten::ne:
      case aten::eq: {
        const auto promoted_type = binary_broadcast_type(
            getInputTensorType(node, 0, false),
            getInputTensorType(node, 1, true),
            at::ScalarType::Bool);
        node->output()->setType(promoted_type);
        break;
      }
      case aten::where: {
        const auto promoted_type = binary_broadcast_type(
            getInputTensorType(node, 1, true),
            getInputTensorType(node, 2, true));
        node->output()->setType(promoted_type);
        break;
      }
      case aten::addcmul: {
        auto promoted_type = binary_broadcast_type(
            getInputTensorType(node, 1, true),
            getInputTensorType(node, 2, true));
        promoted_type = binary_broadcast_type(
            promoted_type, getInputTensorType(node, 0, true));
        node->output()->setType(promoted_type);
        break;
      }
      case aten::native_dropout_backward:
      case aten::dropout: {
        node->output()->setType(getInputTensorType(node, 0));
        break;
      }
      case aten::native_dropout: {
        auto out_type = getInputTensorType(node, 0);
        node->output(0)->setType(out_type);

        auto mask_type = TensorType::create(
            at::ScalarType::Bool, *out_type->device(), c10::nullopt, false);

        node->output(1)->setType(mask_type);
        break;
      }
      case aten::instance_norm:
      case aten::batch_norm: {
        node->output()->setType(getInputTensorType(node, 0));
        break;
      }
      case aten::_batch_norm_impl_index_backward: {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list = constant_as<c10::List<bool>>(node->input(10));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(),
            "Missing output mask for batch_norm_backward");
        std::vector<int> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<int>(value));
        }

        auto grad_input_type = getInputTensorType(node, 1);
        if (output_mask[0]) {
          node->output(0)->setType(grad_input_type);
        }

        if (output_mask[1]) {
          if (auto weight_type = getInputTensorType(node, 3, true)) {
            auto acc_weight_type =
                weight_type->withScalarType(toAccumulateType(weight_type));
            node->output(1)->setType(acc_weight_type);
          }
        }

        // TODO: Use shape information from weight tensor
        // OR get dtype information for bias tensor
        if (output_mask[2]) {
          auto bias_type = TensorType::create(
              toAccumulateType(grad_input_type),
              *grad_input_type->device(),
              c10::nullopt,
              c10::nullopt);
          node->output(2)->setType(bias_type);
        }
        break;
      }
      case aten::_batch_norm_impl_index: {
        auto out_type = getInputTensorType(node, 0);
        node->output(0)->setType(out_type);

        auto mean_invstd_type = TensorType::create(
            toAccumulateType(out_type),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        node->output(1)->setType(mean_invstd_type);
        node->output(2)->setType(mean_invstd_type);

        // TODO: not that it matters here, but mark the correct type.
        auto reserve_type = TensorType::create(
            *out_type->scalarType(),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        node->output(3)->setType(reserve_type);
        node->output(4)->setType(IntType::get());
        break;
      }
      case aten::native_batch_norm: {
        auto out_type = getInputTensorType(node, 0);
        node->output(0)->setType(out_type);

        auto mean_invstd_type = TensorType::create(
            toAccumulateType(out_type),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        node->output(1)->setType(mean_invstd_type);
        node->output(2)->setType(mean_invstd_type);
        break;
      }
      case aten::layer_norm: {
        node->output(0)->setType(getInputTensorType(node, 0));
        break;
      }
      case aten::native_layer_norm: {
        auto out_type = getInputTensorType(node, 0);
        node->output(0)->setType(out_type);

        auto mean_invstd_type = TensorType::create(
            toAccumulateType(out_type),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        node->output(1)->setType(mean_invstd_type);
        node->output(2)->setType(mean_invstd_type);
        break;
      }
      case aten::native_layer_norm_backward: {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list = constant_as<c10::List<bool>>(node->input(7));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(), "output mask for layer_norm_backward");
        std::vector<int> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<int>(value));
        }

        if (output_mask[0]) {
          node->output(0)->setType(getInputTensorType(node, 0));
        }

        if (output_mask[1]) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          if (auto weight_type = getInputTensorType(node, 5, true)) {
            node->output(1)->setType(weight_type);
          }
        }

        if (output_mask[2]) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          if (auto bias_type = getInputTensorType(node, 6, true)) {
            node->output(2)->setType(bias_type);
          }
        }
        break;
      }
      case aten::softmax: {
        auto out_type = getInputTensorType(node, 0);

        // accept dtype input to `aten::softmax` node
        if (!node->input(2)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          if (auto opt_ivalue = toIValue(node->input(2))) {
            out_type = out_type->withScalarType(opt_ivalue->toScalarType());
          }
        }
        node->output()->setType(out_type);
        break;
      }
      case aten::_softmax: {
        auto out_type = getInputTensorType(node, 0);

        const auto half_to_float = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            half_to_float.has_value(),
            "half_to_float bool doesn't have a value.");
        if (half_to_float.value()) {
          out_type = out_type->withScalarType(at::ScalarType::Float);
        }

        node->output()->setType(out_type);
        break;
      }
      case aten::_softmax_backward_data: {
        auto out_type = getInputTensorType(node, 0);
        if (auto opt_ivalue = toIValue(node->input(3))) {
          out_type = out_type->withScalarType(opt_ivalue->toScalarType());
        }
        node->output()->setType(out_type);
        break;
      }
      case aten::amax:
      case aten::mean:
      case aten::sum: {
        auto out_type = getInputTensorType(node, 0);

        // accept dtype input for `aten::sum` and `aten::mean` nodes
        if (node->kind() == aten::mean || node->kind() == aten::sum) {
          if (!node->input(3)->type()->isSubtypeOf(
                  static_cast<c10::TypePtr>(NoneType::get()))) {
            if (auto opt_ivalue = toIValue(node->input(3))) {
              out_type = out_type->withScalarType(opt_ivalue->toScalarType());
            }
          }
        }
        const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
        const auto keepdim = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            dims.has_value() && keepdim.has_value(),
            "Shape inference cannot handle options.");
        node->output()->setType(
            unary_reduce_type(out_type, dims->vec(), keepdim.value()));
        break;
      }
      case aten::sum_to_size:
      case aten::_grad_sum_to_size: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        node->output()->setType(out_type->withDim(c10::nullopt));
        break;
      }
      /*
      // TODO: Enable view in parser by detecting non-alias view operation
      case aten::view:
      case aten::reshape: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        auto size_optional = constant_as<c10::List<int64_t>>(node->input(1));
        TORCH_INTERNAL_ASSERT(
            size_optional.has_value(), "The size parameter is required.");
        auto new_size = size_optional->vec();
        node->output()->setType(out_type->withSizes(new_size));
        break;
      }
      */
      case aten::type_as: {
        const auto type0 = getInputTensorType(node, 0);
        const auto type1 = getInputTensorType(node, 1);
        node->output()->setType(type0->withScalarType(type1->scalarType()));
        break;
      }
      case aten::to: {
        const auto type0 = getInputTensorType(node, 0);
        const auto out_dtype = toIValue(node->input(1));
        TORCH_CHECK(out_dtype, "No output type specified");
        node->output()->setType(
            type0->withScalarType(out_dtype->toScalarType()));
        break;
      }
      case prim::add_optional: {
        const auto type0 = getInputTensorType(node, 0);
        // const auto type1 = getInputTensorType(node, 1, true);
        // note: add_optional is meant to replace an in-place add on input0, so
        // we simply forward input0's dtype.
        TORCH_CHECK(type0 != nullptr);
        node->output()->setType(type0);
        break;
      }
      case aten::_autocast_to_reduced_precision: {
        const auto in_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(in_type),
            "Type and device propagation has failed, or was not provided enough information.");
        const auto in_scalar_type = in_type->scalarType();
        const auto in_device = in_type->device();
        const auto cuda_enabled = constant_as<bool>(node->input(1));
        const auto cpu_enabled = constant_as<bool>(node->input(2));
        const auto cuda_dtype = constant_as<c10::ScalarType>(node->input(3));
        const auto cpu_dtype = constant_as<c10::ScalarType>(node->input(4));
        TORCH_CHECK(
            cuda_enabled.has_value() && cpu_enabled.has_value() &&
                cuda_dtype.has_value() && cpu_dtype.has_value(),
            "_autocast_to_reduced_precision requires all scalar inputs to be constant.");
        if (in_scalar_type == at::ScalarType::Float) {
          if (in_device->is_cuda() && cuda_enabled.value()) {
            node->output()->setType(
                in_type->withScalarType(cuda_dtype.value()));
            break;
          } else if (in_device->is_cpu() && cpu_enabled.value()) {
            node->output()->setType(in_type->withScalarType(cpu_dtype.value()));
            break;
          }
        }
        node->output()->setType(in_type);
        break;
      }
      case aten::_autocast_to_full_precision: {
        const auto in_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(in_type),
            "Type and device propagation has failed, or was not provided enough information.");
        const auto in_scalar_type = in_type->scalarType();
        const auto in_device = in_type->device();
        const auto cuda_enabled = constant_as<bool>(node->input(1));
        const auto cpu_enabled = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            cuda_enabled.has_value() && cpu_enabled.has_value(),
            "_autocast_to_full_precision requires enable flag to be constant.");

        if ((in_scalar_type == at::ScalarType::Half ||
             in_scalar_type == at::ScalarType::BFloat16) &&
            ((in_device->is_cuda() && cuda_enabled.value()) ||
             (in_device->is_cpu() && cpu_enabled.value()))) {
          node->output()->setType(
              in_type->withScalarType(at::ScalarType::Float));
        } else {
          node->output()->setType(in_type);
        }
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "type inference failed, unrecognized operation encountered:",
            node->kind().toDisplayString());
        // TODO: generate a proper error log, as this probably means something
        //       unexpected happened.
        break;
    }
  }
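
  // Illustration only (not part of the original file): a minimal sketch of how
  // a driver might walk a fusion graph and call PropagateOnNode on every node
  // in topological order. The name PropagateOnBlock is hypothetical; the
  // actual driver in type_inference.cpp may be structured differently.
  void PropagateOnBlock(Block* block) {
    for (Node* node : block->nodes()) {
      PropagateOnNode(node);
    }
  }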