static void registerJitOperator()

in torch/csrc/jit/codegen/cuda/parser.cpp [679:2395]


  static void registerJitOperator() {
    // Register a parse function for each JIT operator.
    // This is a one-time lookup; our hash registry indexes on the operator
    // pointer in OperatorRegistry.
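    //
    // As used below, REGISTER_PARSE_RULE takes the operator, a parse body
    // (with `node` and `value_map` in scope), and two optional callbacks: a
    // fusibility predicate and an OperatorType query. Passing nullptr appears
    // to select the default behavior; this reading is inferred from the call
    // sites in this function, not from the macro definition.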

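    // Binary ops that carry an `alpha` scaling factor (add/sub). When alpha
    // is the constant 1, we lower to a plain binary op; otherwise we dispatch
    // to the corresponding *_alpha helper recorded in op_mapping.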
    std::array<const char*, kNumBinaryOpsWithAlpha> BinaryOpWithAlpha = {
        "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
        "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
        "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
        "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"};
    for (auto signature : BinaryOpWithAlpha) {
      auto ptr_op = getOperatorForLiteral(signature);
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*);
            static std::unordered_map<
                Symbol,
                std::pair<BinaryOpType, BinaryOpWithAlphaType>>
                op_mapping(
                    {{aten::add,
                      std::make_pair(
                          BinaryOpType::Add,
                          static_cast<BinaryOpWithAlphaType>(&add_alpha))},
                     {aten::sub,
                      std::make_pair(
                          BinaryOpType::Sub,
                          static_cast<BinaryOpWithAlphaType>(&sub_alpha))}});
            // TODO: handle the scaling factor when it is not the constant 1.
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt,
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto lhs = list_val.front();
            list_val.pop_front();
            auto rhs = list_val.front();
            list_val.pop_front();
            Val* alpha = value_map[node->inputs()[2]->unique()];

            auto out = alpha->isOneInt()
                ? binaryOp(
                      op_mapping[node->kind()].first,
                      lhs,
                      rhs,
                      TypePromotion::default_op_config)
                : op_mapping[node->kind()].second(lhs, rhs, alpha);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

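    // Binary ops that always promote to floating point (div/atan2), lowered
    // with TypePromotion::float_op_config.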
    std::array<const char*, kNumBinaryFloatOps> BinaryFloatOp = {
        "aten::div(Tensor self, Tensor other) -> Tensor",
        "aten::div(Tensor self, Scalar other) -> Tensor",
        "aten::atan2(Tensor self, Tensor other) -> Tensor"};
    for (auto signature : BinaryFloatOp) {
      auto ptr_op = getOperatorForLiteral(signature);
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            static std::unordered_map<Symbol, BinaryOpType> op_mapping(
                {{aten::div, BinaryOpType::Div},
                 {aten::atan2, BinaryOpType::Atan2}});

            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt,
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto lhs = list_val.front();
            list_val.pop_front();
            auto rhs = list_val.front();
            list_val.pop_front();

            auto out = binaryOp(
                op_mapping[node->kind()],
                lhs,
                rhs,
                TypePromotion::float_op_config);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

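    // Binary ops lowered with the default type-promotion rules; this covers
    // arithmetic, min/max, pow, remainder/fmod, and the bitwise/shift ops.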
    std::array<const char*, kNumBinaryCastOps> BinaryCastOp = {
        "aten::mul(Tensor self, Tensor other) -> Tensor",
        "aten::mul(Tensor self, Scalar other) -> Tensor",
        "aten::max(Tensor self, Tensor other) -> Tensor",
        "aten::min(Tensor self, Tensor other) -> Tensor",
        "aten::pow(Tensor self, Tensor exponent) -> Tensor",
        "aten::pow(Tensor self, Scalar exponent) -> Tensor",
        "aten::pow(Scalar self, Tensor exponent) -> Tensor",
        "aten::remainder(Tensor self, Tensor other) -> Tensor",
        "aten::fmod(Tensor self, Tensor other) -> Tensor",
        "aten::__and__(Tensor self, Tensor other) -> Tensor",
        "aten::__or__(Tensor self, Tensor other) -> Tensor",
        "aten::__xor__(Tensor self, Tensor other) -> Tensor",
        "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
        "aten::__rshift__(Tensor self, Tensor other) -> Tensor"};
    for (auto signature : BinaryCastOp) {
      auto ptr_op = getOperatorForLiteral(signature);
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            static std::unordered_map<Symbol, BinaryOpType> op_mapping(
                {{aten::mul, BinaryOpType::Mul},
                 {aten::min, BinaryOpType::Min},
                 {aten::max, BinaryOpType::Max},
                 {aten::pow, BinaryOpType::Pow},
                 {aten::remainder, BinaryOpType::Remainder},
                 {aten::fmod, BinaryOpType::Fmod},
                 {aten::__and__, BinaryOpType::And},
                 {aten::__or__, BinaryOpType::Or},
                 {aten::__xor__, BinaryOpType::Xor},
                 {aten::__lshift__, BinaryOpType::Lshift},
                 {aten::__rshift__, BinaryOpType::Rshift}});

            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt,
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto lhs = list_val.front();
            list_val.pop_front();
            auto rhs = list_val.front();
            list_val.pop_front();

            auto out = binaryOp(
                op_mapping[node->kind()],
                lhs,
                rhs,
                TypePromotion::default_op_config);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

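    // Comparison ops, lowered with TypePromotion::comparison_op_config, which
    // types the result as a boolean.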
    std::array<const char*, kNumBinaryComparisonOps> BinaryOp = {
        "aten::eq(Tensor self, Tensor other) -> Tensor",
        "aten::eq(Tensor self, Scalar other) -> Tensor",
        "aten::ne(Tensor self, Tensor other) -> Tensor",
        "aten::ne(Tensor self, Scalar other) -> Tensor",
        "aten::ge(Tensor self, Tensor other) -> Tensor",
        "aten::ge(Tensor self, Scalar other) -> Tensor",
        "aten::gt(Tensor self, Tensor other) -> Tensor",
        "aten::gt(Tensor self, Scalar other) -> Tensor",
        "aten::le(Tensor self, Tensor other) -> Tensor",
        "aten::le(Tensor self, Scalar other) -> Tensor",
        "aten::lt(Tensor self, Tensor other) -> Tensor",
        "aten::lt(Tensor self, Scalar other) -> Tensor"};
    for (auto signature : BinaryOp) {
      auto ptr_op = getOperatorForLiteral(signature);
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            static std::unordered_map<Symbol, BinaryOpType> op_mapping(
                {{aten::lt, BinaryOpType::LT},
                 {aten::le, BinaryOpType::LE},
                 {aten::gt, BinaryOpType::GT},
                 {aten::ge, BinaryOpType::GE},
                 {aten::ne, BinaryOpType::NE},
                 {aten::eq, BinaryOpType::Eq}});

            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt,
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto lhs = list_val.front();
            list_val.pop_front();
            auto rhs = list_val.front();
            list_val.pop_front();

            auto out = binaryOp(
                op_mapping[node->kind()],
                lhs,
                rhs,
                TypePromotion::comparison_op_config);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

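    // Unary ops lowered with the default type-promotion behavior.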
    std::array<const char*, kNumUnaryOps> UnaryOp = {
        "aten::abs(Tensor self) -> Tensor",
        "aten::bitwise_not(Tensor self) -> Tensor",
        "aten::ceil(Tensor self) -> Tensor",
        "aten::floor(Tensor self) -> Tensor",
        "aten::frac(Tensor self) -> Tensor",
        "aten::neg(Tensor self) -> Tensor",
        "aten::relu(Tensor self) -> Tensor",
        "aten::round(Tensor self) -> Tensor",
        "aten::silu(Tensor self) -> Tensor",
        "aten::trunc(Tensor self) -> Tensor",
    };
    for (auto signature : UnaryOp) {
      auto ptr_op = getOperatorForLiteral(signature);
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            static std::unordered_map<Symbol, UnaryOpType> op_mapping({
                {aten::abs, UnaryOpType::Abs},
                {aten::bitwise_not, UnaryOpType::Not},
                {aten::ceil, UnaryOpType::Ceil},
                {aten::floor, UnaryOpType::Floor},
                {aten::frac, UnaryOpType::Frac},
                {aten::neg, UnaryOpType::Neg},
                {aten::relu, UnaryOpType::Relu},
                {aten::round, UnaryOpType::Round},
                {aten::silu, UnaryOpType::Silu},
                {aten::trunc, UnaryOpType::Trunc},
            });
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt, value_map[node->inputs()[0]->unique()]);
            auto operand = list_val.front();
            list_val.pop_front();
            auto out = unaryOp(op_mapping[node->kind()], operand);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

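    // Unary ops that always compute in floating point (log/exp/trig/etc.),
    // lowered with TypePromotion::float_op_config.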
    std::array<const char*, kNumUnaryFloatOps> UnaryFloatOp = {
        "aten::log(Tensor self) -> Tensor",
        "aten::log10(Tensor self) -> Tensor",
        "aten::log1p(Tensor self) -> Tensor",
        "aten::log2(Tensor self) -> Tensor",
        "aten::lgamma(Tensor self) -> Tensor",
        "aten::exp(Tensor self) -> Tensor",
        "aten::expm1(Tensor self) -> Tensor",
        "aten::erf(Tensor self) -> Tensor",
        "aten::erfc(Tensor self) -> Tensor",
        "aten::cos(Tensor self) -> Tensor",
        "aten::acos(Tensor self) -> Tensor",
        "aten::cosh(Tensor self) -> Tensor",
        "aten::sin(Tensor self) -> Tensor",
        "aten::asin(Tensor self) -> Tensor",
        "aten::sinh(Tensor self) -> Tensor",
        "aten::tan(Tensor self) -> Tensor",
        "aten::atan(Tensor self) -> Tensor",
        "aten::tanh(Tensor self) -> Tensor",
        "aten::atanh(Tensor self) -> Tensor",
        "aten::sqrt(Tensor self) -> Tensor",
        "aten::rsqrt(Tensor self) -> Tensor",
        "aten::reciprocal(Tensor self) -> Tensor",
        "aten::sigmoid(Tensor self) -> Tensor"};
    for (auto signature : UnaryFloatOp) {
      auto ptr_op = getOperatorForLiteral(signature);
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            static std::unordered_map<Symbol, UnaryOpType> op_mapping({
                {aten::log, UnaryOpType::Log},
                {aten::log10, UnaryOpType::Log10},
                {aten::log1p, UnaryOpType::Log1p},
                {aten::log2, UnaryOpType::Log2},
                {aten::lgamma, UnaryOpType::Lgamma},
                {aten::exp, UnaryOpType::Exp},
                {aten::expm1, UnaryOpType::Expm1},
                {aten::erf, UnaryOpType::Erf},
                {aten::erfc, UnaryOpType::Erfc},
                {aten::cos, UnaryOpType::Cos},
                {aten::acos, UnaryOpType::Acos},
                {aten::cosh, UnaryOpType::Cosh},
                {aten::sin, UnaryOpType::Sin},
                {aten::asin, UnaryOpType::Asin},
                {aten::sinh, UnaryOpType::Sinh},
                {aten::tan, UnaryOpType::Tan},
                {aten::tanh, UnaryOpType::Tanh},
                {aten::atan, UnaryOpType::Atan},
                {aten::atanh, UnaryOpType::Atanh},
                {aten::sqrt, UnaryOpType::Sqrt},
                {aten::rsqrt, UnaryOpType::Rsqrt},
                {aten::reciprocal, UnaryOpType::Reciprocal},
                {aten::sigmoid, UnaryOpType::Sigmoid},
            });
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt, value_map[node->inputs()[0]->unique()]);
            auto operand = list_val.front();
            list_val.pop_front();
            auto out = unaryOp(
                op_mapping[node->kind()],
                operand,
                TypePromotion::float_op_config);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()]);
            auto operand = list_val.front();
            list_val.pop_front();

            auto out = randlike(operand);
            value_map.emplace(node->output()->unique(), out);
          },
          nullptr,
          nullptr);
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()]);
            auto operand = list_val.front();
            list_val.pop_front();
            auto& beta = value_map[node->inputs()[1]->unique()];
            auto& threshold = value_map[node->inputs()[2]->unique()];
            auto out = softplus(operand, beta, threshold);
            value_map.emplace(node->output()->unique(), out);
          },
          nullptr,
          nullptr);
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()]);
            auto operand = list_val.front();
            list_val.pop_front();
            auto& th = value_map[node->inputs()[1]->unique()];
            auto& value = value_map[node->inputs()[2]->unique()];

            auto out = threshold(operand, th, value);
            value_map.emplace(node->output()->unique(), out);
          },
          nullptr,
          nullptr);
    }

    { // LTC uses threshold_backward for relu_backward
      auto ptr_op = getOperatorForLiteral(
          "aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt,
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto grad_output = list_val.front();
            list_val.pop_front();
            auto input = list_val.front();
            list_val.pop_front();
            auto& threshold = value_map[node->inputs()[2]->unique()];

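            // relu/threshold backward decomposes as:
            //   grad_input = grad_output * (input > threshold)
            // with the boolean mask cast back to the input dtype.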
            auto comparison = binaryOp(
                BinaryOpType::GT,
                input,
                threshold,
                TypePromotion::comparison_op_config);
            auto mask = castOp(input->getDataType().value(), comparison);
            auto out = mul(grad_output, mask);

            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt, value_map[node->inputs()[0]->unique()]);
            auto operand = list_val.front();
            list_val.pop_front();
            // A missing min/max has no entry in value_map; default to the
            // widest float range. Note that numeric_limits<float>::min() is
            // the smallest positive normal value, so the lower bound must use
            // lowest() instead.
            Val* low = value_map.count(node->inputs()[1]->unique()) != 0
                ? *value_map[node->inputs()[1]->unique()]
                : new Double(std::numeric_limits<float>::lowest());
            Val* high = value_map.count(node->inputs()[2]->unique()) != 0
                ? *value_map[node->inputs()[2]->unique()]
                : new Double(std::numeric_limits<float>::max());

            auto out = clamp(operand, low, high);
            value_map.emplace(node->output()->unique(), out);
          },
          nullptr,
          nullptr);
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()],
                value_map[node->inputs()[2]->unique()]);
            auto condition = list_val.front();
            list_val.pop_front();
            auto x = list_val.front();
            list_val.pop_front();
            auto y = list_val.front();
            list_val.pop_front();

            auto out = where(condition, x, y);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

    {
      std::array<const char*, kNumLerpOps> LerpOp = {
          "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
          "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"};
      for (auto signature : LerpOp) {
        auto ptr_op = getOperatorForLiteral(signature);
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              MemoryFormat format;
              std::list<Val*> list_val;
              std::tie(format, list_val) = getConsistentValues(
                  MemoryFormat::Contiguous(),
                  value_map[node->inputs()[0]->unique()],
                  value_map[node->inputs()[1]->unique()],
                  value_map[node->inputs()[2]->unique()]);
              auto self = list_val.front();
              list_val.pop_front();
              auto end = list_val.front();
              list_val.pop_front();
              auto weight = list_val.front();
              list_val.pop_front();

              auto out = lerp(self, end, weight);
              value_map.emplace(
                  node->output()->unique(), ValueHolder(out, format));
            },
            nullptr,
            nullptr);
      }
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt,
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()],
                value_map[node->inputs()[2]->unique()],
                value_map[node->inputs()[3]->unique()]);
            auto self = list_val.front();
            list_val.pop_front();
            auto tensor1 = list_val.front();
            list_val.pop_front();
            auto tensor2 = list_val.front();
            list_val.pop_front();
            auto value = list_val.front();
            list_val.pop_front();

            auto out = addcmul(self, tensor1, tensor2, value);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

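    // native_dropout: when `train` is a constant true we emit the dropout
    // decomposition (output plus mask); otherwise the input passes through
    // and a placeholder TensorView stands in for the mask.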
    {
      auto ptr_op = getOperatorForLiteral(
          "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto input = list_val.front();
            list_val.pop_front();
            auto prob = list_val.front();
            list_val.pop_front();
            auto train = constant_as<bool>(node->input(2));

            TORCH_INTERNAL_ASSERT(
                train.has_value(), "dropout needs constant `train` flag");

            if (train.value()) {
              auto result = dropout(input->as<TensorView>(), prob);

              value_map.emplace(node->output(0)->unique(), result.output);
              value_map.emplace(node->output(1)->unique(), result.mask);
            } else {
              value_map.emplace(node->output(0)->unique(), input);
              value_map.emplace(
                  node->output(1)->unique(),
                  ValueHolder(TensorViewBuilder().build(), format));
            }
          },
          nullptr,
          nullptr);
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::dropout(Tensor input, float p, bool train) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto input = list_val.front();
            list_val.pop_front();
            auto prob = list_val.front();
            list_val.pop_front();

            auto train = constant_as<bool>(node->input(2));
            TORCH_INTERNAL_ASSERT(
                train.has_value(), "dropout needs constant `train` flag");

            if (train.value()) {
              auto result = dropout(input->as<TensorView>(), prob);

              value_map.emplace(node->output()->unique(), result.output);
            } else {
              value_map.emplace(node->output()->unique(), input);
            }
          },
          nullptr,
          nullptr);
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()],
                value_map[node->inputs()[2]->unique()]);
            auto grad = list_val.front();
            list_val.pop_front();
            auto mask = list_val.front();
            list_val.pop_front();
            auto scale = list_val.front();
            list_val.pop_front();

            auto output = dropout_backward(
                grad->as<TensorView>(), mask->as<TensorView>(), scale);
            value_map.emplace(node->output()->unique(), output);
          },
          nullptr,
          nullptr);
    }

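    // instance_norm forward. When running stats are present, they must be
    // fusion inputs (asserted below).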
    {
      std::array<const char*, kNumInstancenormFwd> InstanceNormFwd = {
          "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
      for (auto signature : InstanceNormFwd) {
        auto ptr_op = getOperatorForLiteral(signature);
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              auto fusion = FusionGuard::getCurFusion();

              // TODO: handle channels last
              MemoryFormat format;
              std::list<Val*> list_val;
              std::tie(format, list_val) = getConsistentValues(
                  MemoryFormat::Contiguous(),
                  value_map[node->inputs()[0]->unique()]);
              auto input_t = list_val.front();
              list_val.pop_front();
              auto input = input_t->as<TensorView>();

              TensorView* weight = nullptr;
              if (!node->input(1)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                weight = value_map[node->input(1)->unique()]->as<TensorView>();
              }

              TensorView* bias = nullptr;
              if (!node->input(2)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                bias = value_map[node->input(2)->unique()]->as<TensorView>();
              }

              TensorView* running_mean = nullptr;
              if (!node->input(3)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                running_mean =
                    value_map[node->input(3)->unique()]->as<TensorView>();
                TORCH_INTERNAL_ASSERT(
                    fusion->hasInput(running_mean),
                    "IO tensor `instance_norm::running_mean` can only be an input tensor to the fusion");
              }

              TensorView* running_var = nullptr;
              if (!node->input(4)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                running_var =
                    value_map[node->input(4)->unique()]->as<TensorView>();
                TORCH_INTERNAL_ASSERT(
                    fusion->hasInput(running_var),
                    "IO tensor `instance_norm::running_var` can only be an input tensor to the fusion");
              }

              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              auto use_input_stats = constant_as<bool>(node->input(5));
              TORCH_INTERNAL_ASSERT(
                  use_input_stats.has_value(),
                  "The use_input_stats (bool) parameter is required.");
              const bool kUseInputStats = use_input_stats.value();

              Val* momentum_ptr = nullptr;
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              if (auto momentum = constant_as<float>(node->input(6))) {
                momentum_ptr = new Double(momentum.value());
              } else {
                // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                momentum_ptr = value_map[node->input(6)->unique()];
              }

              Val* eps_ptr = nullptr;
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              if (auto eps = constant_as<float>(node->input(7))) {
                eps_ptr = new Double(eps.value());
              } else {
                // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                eps_ptr = value_map[node->input(7)->unique()];
              }

              auto result = instance_norm(
                  input,
                  weight,
                  bias,
                  running_mean,
                  running_var,
                  kUseInputStats,
                  momentum_ptr,
                  eps_ptr);

              if (node->kind() ==
                  c10::Symbol::fromQualString("aten::instance_norm")) {
                value_map.emplace(node->output()->unique(), result.output);
              }
            },
            [](const Node* node) -> bool { return true; },
            [](const Node* node) -> OperatorType {
              return OperatorType::Normalization;
            });
      }
    }

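    // batch_norm forward variants. All three schemas share one parse body;
    // channels-last permutations are preserved, while any other permutation
    // falls back to the contiguous format.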
    {
      std::array<const char*, kNumBatchnormFwd> BatchNormFwd = {
          "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)",
          "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
          "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
      for (auto signature : BatchNormFwd) {
        auto ptr_op = getOperatorForLiteral(signature);
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              MemoryFormat format;
              Val* operand = nullptr;
              std::tie(format, operand) =
                  value_map[node->input(0)->unique()].getEntry();
              if (format.hasPermutation() && !format.isChannelsLast()) {
                format = MemoryFormat::Contiguous();
                operand = value_map[node->input(0)->unique()].maybeConvertValue(
                    format);
              }
              auto input = operand->as<TensorView>();

              TensorView* weight = nullptr;
              if (!node->input(1)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                weight = value_map[node->input(1)->unique()]->as<TensorView>();
              }

              TensorView* bias = nullptr;
              if (!node->input(2)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                bias = value_map[node->input(2)->unique()]->as<TensorView>();
              }

              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              auto training = constant_as<bool>(node->input(5));
              TORCH_INTERNAL_ASSERT(
                  training.has_value(),
                  "The training (bool) parameter is required.");
              const bool kTraining = training.value();

              TensorView* running_mean = nullptr;
              if (!node->input(3)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                running_mean =
                    value_map[node->input(3)->unique()]->as<TensorView>();
              }

              TensorView* running_var = nullptr;
              if (!node->input(4)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                running_var =
                    value_map[node->input(4)->unique()]->as<TensorView>();
              }

              Val* momentum_ptr = nullptr;
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              if (auto momentum = constant_as<float>(node->input(6))) {
                momentum_ptr = new Double(momentum.value());
              } else {
                // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                momentum_ptr = value_map[node->input(6)->unique()];
              }

              Val* eps_ptr = nullptr;
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              if (auto eps = constant_as<float>(node->input(7))) {
                eps_ptr = new Double(eps.value());
              } else {
                // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                eps_ptr = value_map[node->input(7)->unique()];
              }

              auto result = batch_norm(
                  input,
                  weight,
                  bias,
                  running_mean,
                  running_var,
                  kTraining,
                  momentum_ptr,
                  eps_ptr,
                  format.isChannelsLast());

              if (node->kind() ==
                      c10::Symbol::fromQualString("aten::native_batch_norm") ||
                  node->kind() ==
                      c10::Symbol::fromQualString(
                          "aten::_batch_norm_impl_index")) {
                // TODO: output 3 & 4 are not created
                //       we are not creating these outputs because codegen
                //       currently lacks the support.
                value_map.emplace(
                    node->output(0)->unique(),
                    ValueHolder(result.output, format));
                value_map.emplace(node->output(1)->unique(), result.mean);
                value_map.emplace(node->output(2)->unique(), result.invstd);
              } else if (
                  node->kind() ==
                  c10::Symbol::fromQualString("aten::batch_norm")) {
                value_map.emplace(
                    node->output()->unique(),
                    ValueHolder(result.output, format));
              }
            },
            [](const Node* node) -> bool { return true; },
            [](const Node* node) -> OperatorType {
              return OperatorType::Normalization;
            });
      }
    }

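    // batch_norm backward. Gradients are emitted per output_mask entry; a
    // masked-off gradient gets an empty placeholder TensorView instead.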
    {
      auto ptr_op = getOperatorForLiteral(
          "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            // discard impl_index and reservedSpace since we don't use them
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt,
                value_map[node->inputs()[1]->unique()],
                value_map[node->inputs()[2]->unique()]);
            if (format.hasPermutation() && !format.isChannelsLast()) {
              std::tie(format, list_val) = getConsistentValues(
                  MemoryFormat::Contiguous(),
                  value_map[node->inputs()[1]->unique()],
                  value_map[node->inputs()[2]->unique()]);
            }
            auto operand0 = list_val.front();
            list_val.pop_front();
            auto operand1 = list_val.front();
            list_val.pop_front();
            auto input = operand0->as<TensorView>();
            auto grad_out = operand1->as<TensorView>();

            TensorView* weight = nullptr;
            if (!node->input(3)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              weight = value_map[node->input(3)->unique()]->as<TensorView>();
            }

            TensorView* running_mean = nullptr;
            if (!node->input(4)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              running_mean =
                  value_map[node->input(4)->unique()]->as<TensorView>();
            }

            TensorView* running_var = nullptr;
            if (!node->input(5)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              running_var =
                  value_map[node->input(5)->unique()]->as<TensorView>();
            }

            TensorView* save_mean = nullptr;
            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            if (!node->input(6)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              save_mean = value_map[node->input(6)->unique()]->as<TensorView>();
            }

            TensorView* save_invstd = nullptr;
            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            if (!node->input(7)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              save_invstd =
                  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                  value_map[node->input(7)->unique()]->as<TensorView>();
            }

            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            auto training = constant_as<bool>(node->input(8));
            TORCH_INTERNAL_ASSERT(
                training.has_value(),
                "The training (bool) parameter is required.");
            const bool kTraining = training.value();

            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            Val* eps_ptr = nullptr;
            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            if (auto eps = constant_as<float>(node->input(9))) {
              eps_ptr = new Double(eps.value());
            } else {
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              // eps is input 9; read the dynamic value from the same input.
              eps_ptr = value_map[node->input(9)->unique()];
            }

            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            auto out_mask_list = constant_as<c10::List<bool>>(node->input(10));
            TORCH_INTERNAL_ASSERT(
                out_mask_list.has_value(),
                "batch_norm_backward requires a constant output_mask");
            std::vector<bool> output_mask;
            for (const auto value : out_mask_list->vec()) {
              output_mask.emplace_back(static_cast<bool>(value));
            }

            // TODO: merge this validation with the output_mask handling below.
            if (kTraining) {
              TORCH_INTERNAL_ASSERT(
                  save_mean != nullptr && save_invstd != nullptr,
                  "When training=True, save_mean and save_invstd are required.");
            } else {
              // TODO: is this a legitimate assumption? We could run with
              // track_running_stats == false && training == false, which
              // should just go through the case above.
              TORCH_INTERNAL_ASSERT(
                  running_mean != nullptr && running_var != nullptr,
                  "When training=False, running_mean and running_var are required.");
            }

            auto grads = batch_norm_backward(
                input,
                grad_out,
                weight,
                running_mean,
                running_var,
                save_mean,
                save_invstd,
                kTraining,
                eps_ptr,
                output_mask,
                format.isChannelsLast());

            if (output_mask[0]) {
              TORCH_INTERNAL_ASSERT(grads.grad_input != nullptr);
              value_map.emplace(
                  node->output(0)->unique(),
                  ValueHolder(grads.grad_input, format));
            } else {
              TORCH_INTERNAL_ASSERT(grads.grad_input == nullptr);
              value_map.emplace(
                  node->output(0)->unique(),
                  ValueHolder(TensorViewBuilder().build(), format));
            }

            if (output_mask[1]) {
              TORCH_INTERNAL_ASSERT(grads.grad_weight != nullptr);
              value_map.emplace(node->output(1)->unique(), grads.grad_weight);
            } else {
              TORCH_INTERNAL_ASSERT(grads.grad_weight == nullptr);
              value_map.emplace(
                  node->output(1)->unique(), TensorViewBuilder().build());
            }

            if (output_mask[2]) {
              TORCH_INTERNAL_ASSERT(grads.grad_bias != nullptr);
              value_map.emplace(node->output(2)->unique(), grads.grad_bias);
            } else {
              TORCH_INTERNAL_ASSERT(grads.grad_bias == nullptr);
              value_map.emplace(
                  node->output(2)->unique(), TensorViewBuilder().build());
            }
          },
          [](const Node* node) -> bool { return true; },
          [](const Node* node) -> OperatorType {
            return OperatorType::Normalization;
          });
    }

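    // layer_norm forward. native_layer_norm exposes (output, mean, invstd);
    // layer_norm only exposes the normalized output.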
    {
      std::array<const char*, kNumLayernormFwd> LayerNormFwd = {
          "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
          "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"};
      for (auto signature : LayerNormFwd) {
        auto ptr_op = getOperatorForLiteral(signature);
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              MemoryFormat format;
              std::list<Val*> list_val;
              std::tie(format, list_val) = getConsistentValues(
                  MemoryFormat::Contiguous(),
                  value_map[node->inputs()[0]->unique()]);
              auto input_t = list_val.front();
              list_val.pop_front();
              auto input = input_t->as<TensorView>();

              auto norm_shape_optional =
                  constant_as<c10::List<int64_t>>(node->input(1));
              TORCH_INTERNAL_ASSERT(
                  norm_shape_optional.has_value(),
                  "The normalized_shape list is required.");
              auto norm_shape = norm_shape_optional->vec();

              TensorView* weight = nullptr;
              if (!node->input(2)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                weight = value_map[node->input(2)->unique()]->as<TensorView>();
              }

              TensorView* bias = nullptr;
              if (!node->input(3)->type()->isSubtypeOf(
                      static_cast<c10::TypePtr>(NoneType::get()))) {
                bias = value_map[node->input(3)->unique()]->as<TensorView>();
              }

              Val* eps_ptr = nullptr;
              if (auto eps = constant_as<float>(node->input(4))) {
                eps_ptr = new Double(eps.value());
              } else {
                eps_ptr = value_map[node->input(4)->unique()];
              }

              auto result =
                  layer_norm(input, norm_shape, weight, bias, eps_ptr);

              if (node->kind() ==
                  c10::Symbol::fromQualString("aten::native_layer_norm")) {
                value_map.emplace(node->output(0)->unique(), result.output);
                value_map.emplace(node->output(1)->unique(), result.mean);
                value_map.emplace(node->output(2)->unique(), result.invstd);
              } else if (
                  node->kind() ==
                  c10::Symbol::fromQualString("aten::layer_norm")) {
                value_map.emplace(node->output()->unique(), result.output);
              }
            },
            // TODO: #ProfileIValue List should update this
            [](const Node* node) -> bool { return true; },
            [](const Node* node) -> OperatorType {
              return OperatorType::Normalization;
            });
      }
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto grad_out_t = list_val.front();
            list_val.pop_front();
            auto input_t = list_val.front();
            list_val.pop_front();
            auto grad_out = grad_out_t->as<TensorView>();
            auto input = input_t->as<TensorView>();

            auto norm_shape_optional =
                constant_as<c10::List<int64_t>>(node->input(2));
            TORCH_INTERNAL_ASSERT(
                norm_shape_optional.has_value(),
                "The normalized_shape list is required.");
            auto norm_shape = norm_shape_optional->vec();

            auto mean = value_map[node->input(3)->unique()]->as<TensorView>();
            auto rstd = value_map[node->input(4)->unique()]->as<TensorView>();

            TensorView* weight = nullptr;
            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            if (!node->input(5)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              weight = value_map[node->input(5)->unique()]->as<TensorView>();
            }

            TensorView* bias = nullptr;
            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            if (!node->input(6)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              bias = value_map[node->input(6)->unique()]->as<TensorView>();
            }

            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            auto output_mask_optional =
                constant_as<c10::List<bool>>(node->input(7));
            TORCH_INTERNAL_ASSERT(
                output_mask_optional.has_value(),
                "layer_norm_backward requires a constant output_mask");
            std::vector<bool> output_mask = output_mask_optional->vec();

            auto grad = layer_norm_backward(
                grad_out,
                input,
                norm_shape,
                mean,
                rstd,
                weight,
                bias,
                output_mask);

            if (output_mask[0]) {
              TORCH_INTERNAL_ASSERT(grad.grad_input != nullptr);
              value_map.emplace(node->output(0)->unique(), grad.grad_input);
            } else {
              TORCH_INTERNAL_ASSERT(grad.grad_input == nullptr);
              value_map.emplace(
                  node->output(0)->unique(), TensorViewBuilder().build());
            }

            if (output_mask[1] && weight != nullptr) {
              TORCH_INTERNAL_ASSERT(grad.grad_weight != nullptr);
              value_map.emplace(node->output(1)->unique(), grad.grad_weight);
            } else {
              TORCH_INTERNAL_ASSERT(grad.grad_weight == nullptr);
              value_map.emplace(
                  node->output(1)->unique(), TensorViewBuilder().build());
            }

            if (output_mask[2] && bias != nullptr) {
              TORCH_INTERNAL_ASSERT(grad.grad_bias != nullptr);
              value_map.emplace(node->output(2)->unique(), grad.grad_bias);
            } else {
              TORCH_INTERNAL_ASSERT(grad.grad_bias == nullptr);
              value_map.emplace(
                  node->output(2)->unique(), TensorViewBuilder().build());
            }
          },
          // TODO: #ProfileIValue List should update this
          [](const Node* node) -> bool { return true; },
          [](const Node* node) -> OperatorType {
            return OperatorType::Normalization;
          });
    }

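    // softmax. Fusion requires a constant `dim`, and a non-None `dtype` is
    // only accepted when it is constant (see the fusibility predicate below).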
    {
      auto ptr_op = getOperatorForLiteral(
          "aten::softmax.int(Tensor self, int dim, int? dtype) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()]);
            auto input_t = list_val.front();
            list_val.pop_front();
            auto input = input_t->as<TensorView>();

            auto dim_value = constant_as<int>(node->input(1));
            TORCH_INTERNAL_ASSERT(
                dim_value.has_value(), "dim in softmax is not valid");

            auto output = softmax(input, dim_value.value());
            value_map.emplace(node->output()->unique(), output);
          },
          [](const Node* node) -> bool {
            if (node->inputs()[1]->node()->kind() != prim::Constant) {
              return false;
            }
            // TODO: support dynamic input by profiling it
            if (!node->inputs()[2]->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get())) &&
                node->inputs()[2]->node()->kind() != prim::Constant) {
              return false;
            }
            return true;
          },
          [](const Node* node) -> OperatorType {
            return OperatorType::Normalization;
          });
    }

    { // LTC uses this op for softmax
      auto ptr_op = getOperatorForLiteral(
          "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()]);
            auto input_t = list_val.front();
            list_val.pop_front();
            auto input = input_t->as<TensorView>();

            auto dim_value = constant_as<int>(node->input(1));
            TORCH_INTERNAL_ASSERT(
                dim_value.has_value(), "dim in softmax is not valid");

            auto output = softmax(input, dim_value.value());
            value_map.emplace(node->output()->unique(), output);
          },
          [](const Node* node) -> bool {
            if (node->inputs()[1]->node()->kind() != prim::Constant) {
              return false;
            }
            if (node->inputs()[2]->node()->kind() != prim::Constant) {
              return false;
            } else {
              const auto half_to_float = constant_as<bool>(node->input(2));
              TORCH_INTERNAL_ASSERT(
                  half_to_float.has_value(), "Bool half_to_float is not valid");
              auto input_tensor_type =
                  node->input(0)->type()->cast<TensorType>();
              if (half_to_float.value() &&
                  input_tensor_type->scalarType() != at::ScalarType::Half) {
                return false;
              }
            }
            return true;
          },
          [](const Node* node) -> OperatorType {
            return OperatorType::Normalization;
          });
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            auto grad_output =
                value_map[node->input(0)->unique()]->as<TensorView>();

            auto output = value_map[node->input(1)->unique()]->as<TensorView>();

            auto dim_value = constant_as<int>(node->input(2));
            TORCH_INTERNAL_ASSERT(
                dim_value.has_value(), "dim in softmax is not valid");

            // input_dtype is ignored here; type_inference handles it.
            auto grad_input =
                softmax_backward(grad_output, output, dim_value.value());

            value_map.emplace(node->output()->unique(), grad_input);
          },
          [](const Node* node) -> bool {
            if (node->inputs()[2]->node()->kind() != prim::Constant) {
              return false;
            }
            return true;
          },
          [](const Node* node) -> OperatorType {
            return OperatorType::Normalization;
          });
    }

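    // sum over a constant dim-list. The fusibility predicate restricts any
    // requested output dtype to half/bfloat16/float/double and rejects
    // dynamic dims or keepdim.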
    {
      auto ptr_op = getOperatorForLiteral(
          "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            // TODO: support channels last in sum
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()]);
            auto self = list_val.front();
            list_val.pop_front();
            auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
            TORCH_INTERNAL_ASSERT(
                dims_list.has_value(),
                "aten::sum cannot be fused with dynamic axes");
            std::vector<int> dims;
            for (const auto dim : dims_list->vec()) {
              dims.emplace_back(static_cast<int>(dim));
            }
            auto keepdim = constant_as<bool>(node->input(2));
            TORCH_INTERNAL_ASSERT(
                keepdim.has_value(),
                "aten::sum cannot be fused with dynamic keepdim");
            auto out = sum(self->as<TensorView>(), dims, keepdim.value());
            value_map.emplace(node->output()->unique(), out);
          },
          [](const Node* node) -> bool {
            // TODO: support cast of output types
            if (!node->inputs()[3]->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              // We can only handle half, bfloat16, float, and double outputs;
              if (const auto opt_ivalue = toIValue(node->input(3))) {
                const auto scalar_type = opt_ivalue->toScalarType();
                if (scalar_type == at::ScalarType::Double ||
                    scalar_type == at::ScalarType::Float ||
                    scalar_type == at::ScalarType::BFloat16 ||
                    scalar_type == at::ScalarType::Half) {
                  return true;
                }
              }
              return false;
            }
            // we don't support dynamic reduction axes;
            if (node->inputs()[1]->node()->kind() != prim::Constant) {
              return false;
            }
            // we don't support dynamic keepdim yet;
            if (node->inputs()[2]->node()->kind() != prim::Constant) {
              return false;
            }
            return true;
          },
          [](const Node* node) -> OperatorType {
            return OperatorType::Reduction;
          });
    }

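    // mean is decomposed as a sum divided by the product of the reduced
    // extents (with negative axes normalized first).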
    {
      auto ptr_op = getOperatorForLiteral(
          "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()]);
            auto operand = list_val.front();
            list_val.pop_front();
            auto self = operand->as<TensorView>();
            auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
            TORCH_INTERNAL_ASSERT(
                dims_list.has_value(),
                "aten::mean cannot be fused with dynamic axes");
            std::vector<int> dims;
            for (const auto dim : dims_list->vec()) {
              dims.emplace_back(static_cast<int>(dim));
            }
            auto keepdim = constant_as<bool>(node->input(2));
            TORCH_INTERNAL_ASSERT(
                keepdim.has_value(),
                "aten::mean cannot be fused with dynamic keepdim");
            auto o_sum = sum(self, dims, keepdim.value());
            Val* num_features = new Double(1);
            for (auto axis : dims) {
              if (axis < 0) {
                axis += int(self->nDims());
              }
              num_features =
                  mul(num_features, self->domain()->domain()[axis]->extent());
            }
            auto out = div(o_sum, num_features);
            value_map.emplace(node->output()->unique(), out);
          },
          [](const Node* node) -> bool {
            // TODO: support cast of output types
            if (!node->inputs()[3]->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              // We can only handle half, bfloat16, float, and double outputs;
              if (const auto opt_ivalue = toIValue(node->input(3))) {
                const auto scalar_type = opt_ivalue->toScalarType();
                if (scalar_type == at::ScalarType::Double ||
                    scalar_type == at::ScalarType::Float ||
                    scalar_type == at::ScalarType::BFloat16 ||
                    scalar_type == at::ScalarType::Half) {
                  return true;
                }
              }
              return false;
            }
            // we don't support dynamic reduction axes;
            if (node->inputs()[1]->node()->kind() != prim::Constant) {
              return false;
            }
            // we don't support dynamic keepdim yet;
            if (node->inputs()[2]->node()->kind() != prim::Constant) {
              return false;
            }
            return true;
          },
          [](const Node* node) -> OperatorType {
            return OperatorType::Reduction;
          });
    }
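
    // A minimal sketch of the decomposition performed above (illustrative
    // shapes, not from the source): for self of shape [N, C, H, W] and
    // dims = {-1, -2},
    //
    //   o_sum        = sum(self, {-1, -2}, keepdim)
    //   num_features = extent(W) * extent(H)   // -1 -> 3, -2 -> 2 after
    //                                          // adding nDims() == 4
    //   out          = o_sum / num_features
    //
    // i.e. mean is lowered to a sum reduction followed by a pointwise
    // division by the product of the reduced extents.
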
    {
      std::array<const char*, kNumSumToSize> SumToSize = {
          "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
          "aten::sum_to_size(Tensor self, int[] size) -> Tensor"};
      for (auto signature : SumToSize) {
        auto ptr_op = getOperatorForLiteral(signature);
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              MemoryFormat format;
              std::list<Val*> list_val;
              std::tie(format, list_val) = getConsistentValues(
                  MemoryFormat::Contiguous(),
                  value_map[node->inputs()[0]->unique()]);
              auto self = list_val.front();
              list_val.pop_front();
              auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
              TORCH_INTERNAL_ASSERT(
                  size_to.has_value(),
                  "aten::sum cannot be fused with dynamic axes");
              if (!size_to->empty()) {
                auto out = sum_to(self->as<TensorView>(), size_to->vec());
                value_map.emplace(node->output()->unique(), out);
              } else {
                // We are introducing alias here!
                value_map.emplace(node->output()->unique(), self);
              }
            },
            [](const Node* node) -> bool {
              // we don't support dynamic reduction axes;
              if (node->inputs()[1]->node()->kind() != prim::Constant) {
                return false;
              }
              return true;
              // auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
              // return size_to.has_value() && !size_to->empty();
            },
            [](const Node* node) -> OperatorType {
              auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
              // technically an empty size_to should never occur here, as a
              // specialized _grad_sum_to_size should have been removed by an
              // earlier optimization pass; guard against a missing value
              // nonetheless
              if (!size_to.has_value() || size_to->empty()) {
                return OperatorType::ElementWise;
              } else {
                return OperatorType::ReductionToSize;
              }
            });
      }
    }
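
    // For reference, sum_to reduces broadcast dimensions until the output
    // matches the requested size, which is why a non-empty size maps to
    // OperatorType::ReductionToSize while an empty size degenerates into a
    // pure alias. A hedged example of the semantics (assumed shapes):
    //
    //   self: [4, 3, 5], size: [3, 1]
    //   -> sum over dim 0 (squeezed) and dim 2 (kept), output shape [3, 1]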

    {
      std::array<const char*, kNumAutocastOps> AutocastOps = {
          "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)",
          "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)"};
      for (auto signature : AutocastOps) {
        auto ptr_op = getOperatorForLiteral(signature);
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              MemoryFormat format;
              std::list<Val*> list_val;
              std::tie(format, list_val) = getConsistentValues(
                  c10::nullopt, value_map[node->inputs()[0]->unique()]);
              auto self = list_val.front();
              list_val.pop_front();

              auto out = set(self);
              value_map.emplace(
                  node->output()->unique(), ValueHolder(out, format));
            },
            nullptr,
            nullptr);
      }
    }
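
    // Note that both autocast ops above are parsed as set(self), i.e. a
    // plain copy: the fusion keeps its internal math in FP32 (see the
    // aten::to rule below), so the precision change itself does not need to
    // be materialized inside the fusion.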

    // Limiting aten::to implementation to only change the dtype of a tensor
    {
      auto ptr_op = getOperatorForLiteral(
          "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt, value_map[node->inputs()[0]->unique()]);
            auto self = list_val.front();
            list_val.pop_front();

            // we need a static dtype for the cast
            TORCH_INTERNAL_ASSERT(
                node->input(1)->node()->kind() == prim::Constant);
            auto dtype = toIValue(node->input(1))->toScalarType();

            // We want to keep our internal fusion math in FP32;
            // shape inference will continue to propagate the right
            // output types unchanged.
            if (dtype == at::ScalarType::Half) {
              dtype = at::ScalarType::Float;
            }
            if (dtype == at::ScalarType::BFloat16) {
              dtype = at::ScalarType::Float;
            }

            auto out = castOp(aten_to_data_type(dtype), self);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }
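
    // Editorial example (not from the source): with the rule above, a cast
    // such as x.to(torch.float16) is lowered to castOp(Float, x) inside the
    // fusion; the half-precision output type is restored outside the fusion
    // by shape/type inference, per the FP32-internal-math comment above.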

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::type_as(Tensor self, Tensor other) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt, value_map[node->inputs()[0]->unique()]);
            auto self = list_val.front();
            list_val.pop_front();

            // TODO: switch to the PyTorch dtype, as it is closer to the
            // truth. For now, the reality is that PyTorch IR profiling
            // information can be missing even with the profiling executor,
            // due to upstream transformations between the profiling runs and
            // the fusion pass.
            auto opt_dtype =
                value_map[node->inputs()[1]->unique()]->getDataType();
            TORCH_INTERNAL_ASSERT(opt_dtype.has_value());

            auto out = castOp(opt_dtype.value(), self);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }

    {
      // We are not fusing `linear` yet, because we can't codegen an
      // efficient gemm. However, we still need this entry so the profiling
      // executor (PE) inserts profile nodes for it.
      // During the fusion pass, we decompose linear into gemm + elementwise.
      auto ptr_op = getOperatorForLiteral(
          "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            // this entry exists only so that input tensors get profiled;
            TORCH_INTERNAL_ASSERT(false, "not implemented yet");
          },
          [](const Node* node) -> bool {
            // We only profile `linear` layer with bias.
            if (node->input(2)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              return false;
            }
            return true;
          });
    }
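
    // A hedged sketch of the post-decomposition graph this sets up (assumed
    // shape, not verbatim compiler output): after the fusion pass splits
    // linear into gemm + elementwise, the bias addition surfaces as the
    // prim::add_optional node handled by the next rule, roughly
    //
    //   %mm  : Tensor = aten::matmul(%input, %weight_t)
    //   %out : Tensor = prim::add_optional(%mm, %bias)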

    {
      auto ptr_op = getOperatorForLiteral(
          "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            // bias is optional; forward the input unchanged when it is None
            if (node->input(1)->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
              // forwarding the value;
              value_map.emplace(
                  node->output()->unique(),
                  value_map[node->inputs()[0]->unique()]);
            } else {
              MemoryFormat format;
              std::list<Val*> list_val;
              std::tie(format, list_val) = getConsistentValues(
                  c10::nullopt,
                  value_map[node->inputs()[0]->unique()],
                  value_map[node->inputs()[1]->unique()]);
              auto lhs = list_val.front();
              list_val.pop_front();
              auto rhs = list_val.front();
              list_val.pop_front();

              auto out = binaryOp(
                  BinaryOpType::Add,
                  lhs,
                  rhs,
                  TypePromotion::default_op_config);
              value_map.emplace(
                  node->output()->unique(), ValueHolder(out, format));
            }
          },
          nullptr,
          nullptr);
    }

    {
      auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt, value_map[node->inputs()[0]->unique()]);
            auto self = list_val.front();
            list_val.pop_front();
            auto out = gelu(self);
            value_map.emplace(
                node->output()->unique(), ValueHolder(out, format));
          },
          nullptr,
          nullptr);
    }
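
    // For reference, aten::gelu here (no `approximate` argument in the
    // schema) is the exact, erf-based GELU:
    //
    //   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))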

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                c10::nullopt,
                value_map[node->inputs()[0]->unique()],
                value_map[node->inputs()[1]->unique()]);
            auto grad_out = list_val.front();
            list_val.pop_front();
            auto self = list_val.front();
            list_val.pop_front();

            auto grad_in = gelu_backward(grad_out, self);
            value_map.emplace(
                node->output()->unique(), ValueHolder(grad_in, format));
          },
          nullptr,
          nullptr);
    }
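
    // The matching analytic derivative, for reference (standard calculus,
    // not quoted from the source): with phi the standard normal pdf and Phi
    // its cdf,
    //
    //   d/dx gelu(x) = Phi(x) + x * phi(x)
    //                = 0.5 * (1 + erf(x / sqrt(2)))
    //                  + x * exp(-x^2 / 2) / sqrt(2 * pi)
    //
    // so grad_in = grad_out * (d/dx gelu(x)).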

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor");
      REGISTER_PARSE_RULE(
          ptr_op,
          {
            MemoryFormat format;
            std::list<Val*> list_val;
            std::tie(format, list_val) = getConsistentValues(
                MemoryFormat::Contiguous(),
                value_map[node->inputs()[0]->unique()]);
            auto self = list_val.front();
            list_val.pop_front();
            auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
            TORCH_INTERNAL_ASSERT(
                dims_list.has_value(),
                "aten::amax cannot be fused with dynamic axes");
            std::vector<int> dims;
            for (const auto dim : dims_list->vec()) {
              dims.emplace_back(static_cast<int>(dim));
            }
            auto keepdim = constant_as<bool>(node->input(2));
            TORCH_INTERNAL_ASSERT(
                keepdim.has_value(),
                "aten::amax cannot be fused with dynamic keepdim");

            auto out = max(self->as<TensorView>(), dims, keepdim.value());
            value_map.emplace(node->output()->unique(), out);
          },
          [](const Node* node) -> bool {
            // we don't support dynamic reduction axes;
            if (node->inputs()[1]->node()->kind() != prim::Constant) {
              return false;
            }
            // we don't support dynamic keepdim yet;
            if (node->inputs()[2]->node()->kind() != prim::Constant) {
              return false;
            }
            return true;
          },
          [](const Node* node) -> OperatorType {
            return OperatorType::Reduction;
          });
    }
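
    // amax shares the constant-axes/keepdim constraints of the sum and mean
    // rules above, but lowers to a max reduction; unlike sum/mean it takes
    // no dtype argument, so the fusibility lambda needs no output-type check.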

    /*
    // TODO: Enable view in parser by detecting non-alias view operation
    {
      std::array<const char*, kNumViewSize> View = {
          "aten::view(Tensor(a) self, int[] size) -> Tensor(a)",
          "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"};
      for (auto signature : View) {
        auto ptr_op = getOperatorForLiteral(signature);
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              auto self_value = node->inputs()[0];
              auto self = value_map[self_value->unique()]->as<TensorView>();

              auto self_type = self_value->type()->cast<c10::TensorType>();
              TORCH_INTERNAL_ASSERT(self_type != nullptr);
              auto self_sizes = getTensorSizes(self_type);

              auto size_optional =
                  constant_as<c10::List<int64_t>>(node->input(1));
              TORCH_INTERNAL_ASSERT(
                  size_optional.has_value(), "The size parameter is required.");

              auto output = view(self, self_sizes, size_optional->vec());
              value_map.emplace(node->output()->unique(), output);
            },
            nullptr,
            nullptr);
      }
    }
    */
  }