in torch/csrc/jit/codegen/cuda/parser.cpp [679:2395]
static void registerJitOperator() {
// Register parse-function for each JIT operator;
// This is a one-time look up, our hash registry indexes on the pointer in
// OperatorRegistry.
std::array<const char*, kNumBinaryOpsWithAlpha> BinaryOpWithAlpha = {
    "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
    "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
    "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
    "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"};
// add/sub carrying a scaling factor `alpha` as third input. All four schemas
// share one parse rule; the node kind selects the concrete operation.
for (const char* schema : BinaryOpWithAlpha) {
  auto op_handle = getOperatorForLiteral(schema);
  REGISTER_PARSE_RULE(
      op_handle,
      {
        // Each node kind maps to a plain binary op type (used when alpha is
        // the integer one) and to a helper that takes alpha explicitly.
        using AlphaFn = Val* (*)(Val*, Val*, Val*);
        using OpEntry = std::pair<BinaryOpType, AlphaFn>;
        static std::unordered_map<Symbol, OpEntry> op_mapping(
            {{aten::sub,
              std::make_pair(
                  BinaryOpType::Sub, static_cast<AlphaFn>(&sub_alpha))},
             {aten::add,
              std::make_pair(
                  BinaryOpType::Add, static_cast<AlphaFn>(&add_alpha))}});
        // TODO: handle scaling factor when it's not constant 1;
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        // The returned list holds the two operands in argument order.
        auto operand_it = operands.begin();
        Val* lhs = *operand_it;
        ++operand_it;
        Val* rhs = *operand_it;
        Val* alpha = value_map[node->inputs()[2]->unique()];
        Val* out = nullptr;
        if (alpha->isOneInt()) {
          out = binaryOp(
              op_mapping[node->kind()].first,
              lhs,
              rhs,
              TypePromotion::default_op_config);
        } else {
          out = op_mapping[node->kind()].second(lhs, rhs, alpha);
        }
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
std::array<const char*, kNumBinaryFloatOps> BinaryFloatOp = {
    "aten::div(Tensor self, Tensor other) -> Tensor",
    "aten::div(Tensor self, Scalar other) -> Tensor",
    "aten::atan2(Tensor self, Tensor other) -> Tensor"};
// Binary ops that are built with TypePromotion::float_op_config.
for (const char* schema : BinaryFloatOp) {
  auto op_handle = getOperatorForLiteral(schema);
  REGISTER_PARSE_RULE(
      op_handle,
      {
        // Node kind -> fuser binary op type.
        static std::unordered_map<Symbol, BinaryOpType> op_mapping(
            {{aten::atan2, BinaryOpType::Atan2},
             {aten::div, BinaryOpType::Div}});
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        // Operands come back in argument order: lhs first, then rhs.
        auto operand_it = operands.begin();
        Val* lhs = *operand_it;
        ++operand_it;
        Val* rhs = *operand_it;
        auto out = binaryOp(
            op_mapping[node->kind()],
            lhs,
            rhs,
            TypePromotion::float_op_config);
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
std::array<const char*, kNumBinaryCastOps> BinaryCastOp = {
    "aten::mul(Tensor self, Tensor other) -> Tensor",
    "aten::mul(Tensor self, Scalar other) -> Tensor",
    "aten::max(Tensor self, Tensor other) -> Tensor",
    "aten::min(Tensor self, Tensor other) -> Tensor",
    "aten::pow(Tensor self, Tensor exponent) -> Tensor",
    "aten::pow(Tensor self, Scalar exponent) -> Tensor",
    "aten::pow(Scalar self, Tensor exponent) -> Tensor",
    "aten::remainder(Tensor self, Tensor other) -> Tensor",
    "aten::fmod(Tensor self, Tensor other) -> Tensor",
    "aten::__and__(Tensor self, Tensor other) -> Tensor",
    "aten::__or__(Tensor self, Tensor other) -> Tensor",
    "aten::__xor__(Tensor self, Tensor other) -> Tensor",
    "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
    "aten::__rshift__(Tensor self, Tensor other) -> Tensor"};
// Binary ops that are built with TypePromotion::default_op_config.
for (const char* schema : BinaryCastOp) {
  auto op_handle = getOperatorForLiteral(schema);
  REGISTER_PARSE_RULE(
      op_handle,
      {
        // Node kind -> fuser binary op type.
        static std::unordered_map<Symbol, BinaryOpType> op_mapping(
            {{aten::mul, BinaryOpType::Mul},
             {aten::max, BinaryOpType::Max},
             {aten::min, BinaryOpType::Min},
             {aten::pow, BinaryOpType::Pow},
             {aten::fmod, BinaryOpType::Fmod},
             {aten::remainder, BinaryOpType::Remainder},
             {aten::__and__, BinaryOpType::And},
             {aten::__or__, BinaryOpType::Or},
             {aten::__xor__, BinaryOpType::Xor},
             {aten::__lshift__, BinaryOpType::Lshift},
             {aten::__rshift__, BinaryOpType::Rshift}});
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        // Operands come back in argument order: lhs first, then rhs.
        auto operand_it = operands.begin();
        Val* lhs = *operand_it;
        ++operand_it;
        Val* rhs = *operand_it;
        auto out = binaryOp(
            op_mapping[node->kind()],
            lhs,
            rhs,
            TypePromotion::default_op_config);
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
std::array<const char*, kNumBinaryComparisonOps> BinaryOp = {
    "aten::eq(Tensor self, Tensor other) -> Tensor",
    "aten::eq(Tensor self, Scalar other) -> Tensor",
    "aten::ne(Tensor self, Tensor other) -> Tensor",
    "aten::ne(Tensor self, Scalar other) -> Tensor",
    "aten::ge(Tensor self, Tensor other) -> Tensor",
    "aten::ge(Tensor self, Scalar other) -> Tensor",
    "aten::gt(Tensor self, Tensor other) -> Tensor",
    "aten::gt(Tensor self, Scalar other) -> Tensor",
    "aten::le(Tensor self, Tensor other) -> Tensor",
    "aten::le(Tensor self, Scalar other) -> Tensor",
    "aten::lt(Tensor self, Tensor other) -> Tensor",
    "aten::lt(Tensor self, Scalar other) -> Tensor"};
// Comparison ops, built with TypePromotion::comparison_op_config.
for (const char* schema : BinaryOp) {
  auto op_handle = getOperatorForLiteral(schema);
  REGISTER_PARSE_RULE(
      op_handle,
      {
        // Node kind -> fuser comparison op type.
        static std::unordered_map<Symbol, BinaryOpType> op_mapping(
            {{aten::eq, BinaryOpType::Eq},
             {aten::ne, BinaryOpType::NE},
             {aten::ge, BinaryOpType::GE},
             {aten::gt, BinaryOpType::GT},
             {aten::le, BinaryOpType::LE},
             {aten::lt, BinaryOpType::LT}});
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        // Operands come back in argument order: lhs first, then rhs.
        auto operand_it = operands.begin();
        Val* lhs = *operand_it;
        ++operand_it;
        Val* rhs = *operand_it;
        auto out = binaryOp(
            op_mapping[node->kind()],
            lhs,
            rhs,
            TypePromotion::comparison_op_config);
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
std::array<const char*, kNumUnaryOps> UnaryOp = {
    "aten::abs(Tensor self) -> Tensor",
    "aten::bitwise_not(Tensor self) -> Tensor",
    "aten::ceil(Tensor self) -> Tensor",
    "aten::floor(Tensor self) -> Tensor",
    "aten::frac(Tensor self) -> Tensor",
    "aten::neg(Tensor self) -> Tensor",
    "aten::relu(Tensor self) -> Tensor",
    "aten::round(Tensor self) -> Tensor",
    "aten::silu(Tensor self) -> Tensor",
    "aten::trunc(Tensor self) -> Tensor",
};
// Unary ops built through the two-argument unaryOp overload (no explicit
// type-promotion config is passed).
for (const char* schema : UnaryOp) {
  auto op_handle = getOperatorForLiteral(schema);
  REGISTER_PARSE_RULE(
      op_handle,
      {
        // Node kind -> fuser unary op type.
        static std::unordered_map<Symbol, UnaryOpType> op_mapping({
            {aten::trunc, UnaryOpType::Trunc},
            {aten::silu, UnaryOpType::Silu},
            {aten::round, UnaryOpType::Round},
            {aten::relu, UnaryOpType::Relu},
            {aten::neg, UnaryOpType::Neg},
            {aten::frac, UnaryOpType::Frac},
            {aten::floor, UnaryOpType::Floor},
            {aten::ceil, UnaryOpType::Ceil},
            {aten::bitwise_not, UnaryOpType::Not},
            {aten::abs, UnaryOpType::Abs},
        });
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            c10::nullopt, value_map[node->inputs()[0]->unique()]);
        // Single operand: the first (and only) entry of the list.
        Val* self = operands.front();
        auto out = unaryOp(op_mapping[node->kind()], self);
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
std::array<const char*, kNumUnaryFloatOps> UnaryFloatOp = {
    "aten::log(Tensor self) -> Tensor",
    "aten::log10(Tensor self) -> Tensor",
    "aten::log1p(Tensor self) -> Tensor",
    "aten::log2(Tensor self) -> Tensor",
    "aten::lgamma(Tensor self) -> Tensor",
    "aten::exp(Tensor self) -> Tensor",
    "aten::expm1(Tensor self) -> Tensor",
    "aten::erf(Tensor self) -> Tensor",
    "aten::erfc(Tensor self) -> Tensor",
    "aten::cos(Tensor self) -> Tensor",
    "aten::acos(Tensor self) -> Tensor",
    "aten::cosh(Tensor self) -> Tensor",
    "aten::sin(Tensor self) -> Tensor",
    "aten::asin(Tensor self) -> Tensor",
    "aten::sinh(Tensor self) -> Tensor",
    "aten::tan(Tensor self) -> Tensor",
    "aten::atan(Tensor self) -> Tensor",
    "aten::tanh(Tensor self) -> Tensor",
    "aten::atanh(Tensor self) -> Tensor",
    "aten::sqrt(Tensor self) -> Tensor",
    "aten::rsqrt(Tensor self) -> Tensor",
    "aten::reciprocal(Tensor self) -> Tensor",
    "aten::sigmoid(Tensor self) -> Tensor"};
// Unary ops that are built with TypePromotion::float_op_config.
for (const char* schema : UnaryFloatOp) {
  auto op_handle = getOperatorForLiteral(schema);
  REGISTER_PARSE_RULE(
      op_handle,
      {
        // Node kind -> fuser unary op type.
        static std::unordered_map<Symbol, UnaryOpType> op_mapping({
            {aten::log, UnaryOpType::Log},
            {aten::log2, UnaryOpType::Log2},
            {aten::log10, UnaryOpType::Log10},
            {aten::log1p, UnaryOpType::Log1p},
            {aten::lgamma, UnaryOpType::Lgamma},
            {aten::exp, UnaryOpType::Exp},
            {aten::expm1, UnaryOpType::Expm1},
            {aten::erf, UnaryOpType::Erf},
            {aten::erfc, UnaryOpType::Erfc},
            {aten::sin, UnaryOpType::Sin},
            {aten::asin, UnaryOpType::Asin},
            {aten::sinh, UnaryOpType::Sinh},
            {aten::cos, UnaryOpType::Cos},
            {aten::acos, UnaryOpType::Acos},
            {aten::cosh, UnaryOpType::Cosh},
            {aten::tan, UnaryOpType::Tan},
            {aten::atan, UnaryOpType::Atan},
            {aten::tanh, UnaryOpType::Tanh},
            {aten::atanh, UnaryOpType::Atanh},
            {aten::sqrt, UnaryOpType::Sqrt},
            {aten::rsqrt, UnaryOpType::Rsqrt},
            {aten::reciprocal, UnaryOpType::Reciprocal},
            {aten::sigmoid, UnaryOpType::Sigmoid},
        });
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            c10::nullopt, value_map[node->inputs()[0]->unique()]);
        // Single operand: the first (and only) entry of the list.
        Val* self = operands.front();
        auto out = unaryOp(
            op_mapping[node->kind()], self, TypePromotion::float_op_config);
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
{
  // aten::rand_like: only the `self` tensor (input 0) is read; it is
  // requested in contiguous format and handed to randlike(). The remaining
  // schema arguments are not consulted by this rule.
  auto op_handle = getOperatorForLiteral(
      "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()]);
        Val* self = operands.front();
        auto out = randlike(self);
        value_map.emplace(node->output()->unique(), out);
      },
      nullptr,
      nullptr);
}
{
  // aten::softplus: tensor operand is requested in contiguous format; the
  // `beta` and `threshold` scalars are taken from value_map as-is.
  auto op_handle = getOperatorForLiteral(
      "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()]);
        Val* self = operands.front();
        auto& beta_entry = value_map[node->inputs()[1]->unique()];
        auto& threshold_entry = value_map[node->inputs()[2]->unique()];
        auto out = softplus(self, beta_entry, threshold_entry);
        value_map.emplace(node->output()->unique(), out);
      },
      nullptr,
      nullptr);
}
{
  // aten::threshold: tensor operand is requested in contiguous format; the
  // `threshold` and `value` scalars are taken from value_map as-is.
  auto op_handle = getOperatorForLiteral(
      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()]);
        Val* self = operands.front();
        auto& cutoff_entry = value_map[node->inputs()[1]->unique()];
        auto& fill_entry = value_map[node->inputs()[2]->unique()];
        auto out = threshold(self, cutoff_entry, fill_entry);
        value_map.emplace(node->output()->unique(), out);
      },
      nullptr,
      nullptr);
}
{ // LTC uses threshold_backward for relu_backward
  auto op_handle = getOperatorForLiteral(
      "aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        // Computed as grad_output * cast(self > threshold), where the cast
        // targets the data type of `self`.
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        auto operand_it = operands.begin();
        Val* grad = *operand_it;
        ++operand_it;
        Val* self = *operand_it;
        auto& cutoff_entry = value_map[node->inputs()[2]->unique()];
        auto is_above = binaryOp(
            BinaryOpType::GT,
            self,
            cutoff_entry,
            TypePromotion::comparison_op_config);
        auto mask = castOp(self->getDataType().value(), is_above);
        auto out = mul(grad, mask);
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
{
  // aten::clamp with optional bounds.
  //
  // A bound that has no entry in value_map (presumably a `None` argument --
  // confirm against how constants are populated) is replaced with the widest
  // finite float limit so that side of the clamp becomes a no-op.
  auto ptr_op = getOperatorForLiteral(
      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> list_val;
        std::tie(format, list_val) = getConsistentValues(
            c10::nullopt, value_map[node->inputs()[0]->unique()]);
        auto operand = list_val.front();
        list_val.pop_front();
        // numeric_limits<float>::min() is the smallest *positive* normal
        // value, which would lift every negative input to ~1e-38. The
        // most negative finite value is lowest().
        Val* low = value_map.count(node->inputs()[1]->unique()) != 0
            ? *value_map[node->inputs()[1]->unique()]
            : new Double(std::numeric_limits<float>::lowest());
        Val* high = value_map.count(node->inputs()[2]->unique()) != 0
            ? *value_map[node->inputs()[2]->unique()]
            : new Double(std::numeric_limits<float>::max());
        auto out = clamp(operand, low, high);
        // The operand was fetched without forcing a memory format, so the
        // output must carry the operand's format, like the other rules that
        // pass c10::nullopt to getConsistentValues.
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, format));
      },
      nullptr,
      nullptr);
}
{
  // aten::where: all three tensors are requested in contiguous format and
  // forwarded to where() in schema order.
  auto op_handle = getOperatorForLiteral(
      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()],
            value_map[node->inputs()[2]->unique()]);
        auto operand_it = operands.begin();
        Val* predicate = *operand_it;
        ++operand_it;
        Val* if_true = *operand_it;
        ++operand_it;
        Val* if_false = *operand_it;
        auto out = where(predicate, if_true, if_false);
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
{
  // aten::lerp, for both the Scalar-weight and Tensor-weight schemas. All
  // three inputs go through getConsistentValues with a contiguous format.
  std::array<const char*, kNumLerpOps> lerp_schemas = {
      "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
      "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"};
  for (const char* schema : lerp_schemas) {
    auto op_handle = getOperatorForLiteral(schema);
    REGISTER_PARSE_RULE(
        op_handle,
        {
          MemoryFormat mem_format;
          std::list<Val*> operands;
          std::tie(mem_format, operands) = getConsistentValues(
              MemoryFormat::Contiguous(),
              value_map[node->inputs()[0]->unique()],
              value_map[node->inputs()[1]->unique()],
              value_map[node->inputs()[2]->unique()]);
          auto operand_it = operands.begin();
          Val* start = *operand_it;
          ++operand_it;
          Val* stop = *operand_it;
          ++operand_it;
          Val* blend = *operand_it;
          auto out = lerp(start, stop, blend);
          value_map.emplace(
              node->output()->unique(), ValueHolder(out, mem_format));
        },
        nullptr,
        nullptr);
  }
}
{
  // aten::addcmul: the three tensors and the `value` scalar are fetched
  // together (no memory format is forced) and forwarded to addcmul().
  auto op_handle = getOperatorForLiteral(
      "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()],
            value_map[node->inputs()[2]->unique()],
            value_map[node->inputs()[3]->unique()]);
        auto operand_it = operands.begin();
        Val* base = *operand_it;
        ++operand_it;
        Val* factor_a = *operand_it;
        ++operand_it;
        Val* factor_b = *operand_it;
        ++operand_it;
        Val* scale = *operand_it;
        auto out = addcmul(base, factor_a, factor_b, scale);
        value_map.emplace(
            node->output()->unique(), ValueHolder(out, mem_format));
      },
      nullptr,
      nullptr);
}
{
  // aten::native_dropout: produces (output, mask). `train` must be a
  // compile-time constant. When it is false the input is passed through and
  // a freshly built TensorViewBuilder() tensor is registered for the mask.
  auto op_handle = getOperatorForLiteral(
      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        auto operand_it = operands.begin();
        Val* data = *operand_it;
        ++operand_it;
        Val* drop_prob = *operand_it;
        auto train_flag = constant_as<bool>(node->input(2));
        TORCH_INTERNAL_ASSERT(
            train_flag.has_value(), "dropout needs constant `train` flag");
        if (!train_flag.value()) {
          value_map.emplace(node->output(0)->unique(), data);
          value_map.emplace(
              node->output(1)->unique(),
              ValueHolder(TensorViewBuilder().build(), mem_format));
        } else {
          auto dropped = dropout(data->as<TensorView>(), drop_prob);
          value_map.emplace(node->output(0)->unique(), dropped.output);
          value_map.emplace(node->output(1)->unique(), dropped.mask);
        }
      },
      nullptr,
      nullptr);
}
{
  // aten::dropout: `train` must be a compile-time constant. When it is
  // false the input value itself is registered as the node output.
  auto op_handle = getOperatorForLiteral(
      "aten::dropout(Tensor input, float p, bool train) -> Tensor");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        auto operand_it = operands.begin();
        Val* data = *operand_it;
        ++operand_it;
        Val* drop_prob = *operand_it;
        auto train_flag = constant_as<bool>(node->input(2));
        TORCH_INTERNAL_ASSERT(
            train_flag.has_value(), "dropout needs constant `train` flag");
        if (!train_flag.value()) {
          value_map.emplace(node->output()->unique(), data);
        } else {
          auto dropped = dropout(data->as<TensorView>(), drop_prob);
          value_map.emplace(node->output()->unique(), dropped.output);
        }
      },
      nullptr,
      nullptr);
}
{
  // aten::native_dropout_backward: grad_output, mask and scale are fetched
  // in contiguous format and forwarded to dropout_backward().
  auto op_handle = getOperatorForLiteral(
      "aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor");
  REGISTER_PARSE_RULE(
      op_handle,
      {
        MemoryFormat mem_format;
        std::list<Val*> operands;
        std::tie(mem_format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()],
            value_map[node->inputs()[2]->unique()]);
        auto operand_it = operands.begin();
        Val* grad_val = *operand_it;
        ++operand_it;
        Val* mask_val = *operand_it;
        ++operand_it;
        Val* scale_val = *operand_it;
        auto grad_input = dropout_backward(
            grad_val->as<TensorView>(),
            mask_val->as<TensorView>(),
            scale_val);
        value_map.emplace(node->output()->unique(), grad_input);
      },
      nullptr,
      nullptr);
}
{
// Parse rule for aten::instance_norm.
//
// The input tensor is requested in contiguous format. Each optional tensor
// argument (weight, bias, running_mean, running_var) stays nullptr when the
// JIT input's type is a subtype of NoneType. `use_input_stats` must be a
// constant; `momentum` and `eps` are wrapped in a new Double when constant
// and otherwise taken from value_map. `cudnn_enabled` (input 8) is not read.
std::array<const char*, kNumInstancenormFwd> InstanceNormFwd = {
"aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
for (auto signature : InstanceNormFwd) {
auto ptr_op = getOperatorForLiteral(signature);
REGISTER_PARSE_RULE(
ptr_op,
{
auto fusion = FusionGuard::getCurFusion();
// TODO: handle channels last
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
MemoryFormat::Contiguous(),
value_map[node->inputs()[0]->unique()]);
auto input_t = list_val.front();
list_val.pop_front();
auto input = input_t->as<TensorView>();
// Optional affine parameters: left as nullptr when passed as None.
TensorView* weight = nullptr;
if (!node->input(1)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
weight = value_map[node->input(1)->unique()]->as<TensorView>();
}
TensorView* bias = nullptr;
if (!node->input(2)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
bias = value_map[node->input(2)->unique()]->as<TensorView>();
}
// Optional running statistics: when present they are asserted to be
// inputs of the current fusion.
TensorView* running_mean = nullptr;
if (!node->input(3)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
running_mean =
value_map[node->input(3)->unique()]->as<TensorView>();
TORCH_INTERNAL_ASSERT(
fusion->hasInput(running_mean),
"IO_tensor `instance_norm::running_mean` can only be input tensor to fusion");
}
TensorView* running_var = nullptr;
if (!node->input(4)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
running_var =
value_map[node->input(4)->unique()]->as<TensorView>();
TORCH_INTERNAL_ASSERT(
fusion->hasInput(running_var),
"IO_tensor `instance_norm::running_var` can only be input tensor to fusion");
}
// use_input_stats has to be known at parse time.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto use_input_stats = constant_as<bool>(node->input(5));
TORCH_INTERNAL_ASSERT(
use_input_stats.has_value(),
"The use_input_stats (bool) parameter is required.");
const bool kUseInputStats = use_input_stats.value();
// momentum: constant -> new Double; otherwise the value_map entry.
Val* momentum_ptr = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if (auto momentum = constant_as<float>(node->input(6))) {
momentum_ptr = new Double(momentum.value());
} else {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
momentum_ptr = value_map[node->input(6)->unique()];
}
// eps: constant -> new Double; otherwise the value_map entry.
Val* eps_ptr = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if (auto eps = constant_as<float>(node->input(7))) {
eps_ptr = new Double(eps.value());
} else {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
eps_ptr = value_map[node->input(7)->unique()];
}
auto result = instance_norm(
input,
weight,
bias,
running_mean,
running_var,
kUseInputStats,
momentum_ptr,
eps_ptr);
// Only result.output is registered; aten::instance_norm is the only
// schema in InstanceNormFwd, so this kind check mirrors that list.
if (node->kind() ==
c10::Symbol::fromQualString("aten::instance_norm")) {
value_map.emplace(node->output()->unique(), result.output);
}
},
// Second callback unconditionally returns true; third tags the operator
// as OperatorType::Normalization. NOTE(review): the exact role of these
// callbacks is defined by REGISTER_PARSE_RULE, which is not visible here.
[](const Node* node) -> bool { return true; },
[](const Node* node) -> OperatorType {
return OperatorType::Normalization;
});
}
}
{
// Parse rule shared by aten::_batch_norm_impl_index, aten::native_batch_norm
// and aten::batch_norm.
//
// The input keeps its memory format when that format is channels-last;
// any other permuted format is converted to contiguous first. Optional
// tensors stay nullptr when their JIT input type is a subtype of NoneType.
// `training` must be a constant; `momentum` and `eps` are wrapped in a new
// Double when constant and otherwise taken from value_map.
std::array<const char*, kNumBatchnormFwd> BatchNormFwd = {
"aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)",
"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
for (auto signature : BatchNormFwd) {
auto ptr_op = getOperatorForLiteral(signature);
REGISTER_PARSE_RULE(
ptr_op,
{
MemoryFormat format;
Val* operand = nullptr;
std::tie(format, operand) =
value_map[node->input(0)->unique()].getEntry();
// A permutation other than channels-last is not passed through:
// fall back to the contiguous representation of the input.
if (format.hasPermutation() && !format.isChannelsLast()) {
format = MemoryFormat::Contiguous();
operand = value_map[node->input(0)->unique()].maybeConvertValue(
format);
}
auto input = operand->as<TensorView>();
// Optional affine parameters: left as nullptr when passed as None.
TensorView* weight = nullptr;
if (!node->input(1)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
weight = value_map[node->input(1)->unique()]->as<TensorView>();
}
TensorView* bias = nullptr;
if (!node->input(2)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
bias = value_map[node->input(2)->unique()]->as<TensorView>();
}
// training has to be known at parse time.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto training = constant_as<bool>(node->input(5));
TORCH_INTERNAL_ASSERT(
training.has_value(),
"The training (bool) parameter is required.");
const bool kTraining = training.value();
// Optional running statistics: left as nullptr when passed as None.
TensorView* running_mean = nullptr;
if (!node->input(3)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
running_mean =
value_map[node->input(3)->unique()]->as<TensorView>();
}
TensorView* running_var = nullptr;
if (!node->input(4)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
running_var =
value_map[node->input(4)->unique()]->as<TensorView>();
}
// momentum: constant -> new Double; otherwise the value_map entry.
Val* momentum_ptr = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if (auto momentum = constant_as<float>(node->input(6))) {
momentum_ptr = new Double(momentum.value());
} else {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
momentum_ptr = value_map[node->input(6)->unique()];
}
// eps: constant -> new Double; otherwise the value_map entry.
Val* eps_ptr = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
if (auto eps = constant_as<float>(node->input(7))) {
eps_ptr = new Double(eps.value());
} else {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
eps_ptr = value_map[node->input(7)->unique()];
}
// The last argument tells batch_norm whether `input` is channels-last.
auto result = batch_norm(
input,
weight,
bias,
running_mean,
running_var,
kTraining,
momentum_ptr,
eps_ptr,
format.isChannelsLast());
// Multi-output variants register output/mean/invstd; only the main
// output carries the (possibly channels-last) memory format.
if (node->kind() ==
c10::Symbol::fromQualString("aten::native_batch_norm") ||
node->kind() ==
c10::Symbol::fromQualString(
"aten::_batch_norm_impl_index")) {
// TODO: output 3 & 4 are not created
// we are not creating these outputs because codegen
// currently lacks the support.
value_map.emplace(
node->output(0)->unique(),
ValueHolder(result.output, format));
value_map.emplace(node->output(1)->unique(), result.mean);
value_map.emplace(node->output(2)->unique(), result.invstd);
} else if (
node->kind() ==
c10::Symbol::fromQualString("aten::batch_norm")) {
value_map.emplace(
node->output()->unique(),
ValueHolder(result.output, format));
}
},
// Second callback unconditionally returns true; third tags the operator
// as OperatorType::Normalization. NOTE(review): the exact role of these
// callbacks is defined by REGISTER_PARSE_RULE, which is not visible here.
[](const Node* node) -> bool { return true; },
[](const Node* node) -> OperatorType {
return OperatorType::Normalization;
});
}
}
{
  // Parse rule for aten::_batch_norm_impl_index_backward.
  //
  // Input layout (see the schema below): 0 impl_index, 1 input,
  // 2 grad_output, 3 weight, 4 running_mean, 5 running_var, 6 save_mean,
  // 7 save_var_transform, 8 train, 9 eps, 10 output_mask, 11 reservedSpace.
  //
  // `input` and `grad_output` are brought to a common memory format; that
  // format is kept only when it is channels-last, otherwise both are
  // converted to contiguous. Optional tensors stay nullptr when their JIT
  // input type is a subtype of NoneType. `train` and `output_mask` must be
  // constants. Outputs whose mask bit is false are registered as freshly
  // built TensorViewBuilder() tensors.
  auto ptr_op = getOperatorForLiteral(
      "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        // discard impl_index and reservedSpace since we don't use them
        MemoryFormat format;
        std::list<Val*> list_val;
        std::tie(format, list_val) = getConsistentValues(
            c10::nullopt,
            value_map[node->inputs()[1]->unique()],
            value_map[node->inputs()[2]->unique()]);
        // A permutation other than channels-last is not passed through:
        // re-fetch both operands in contiguous format.
        if (format.hasPermutation() && !format.isChannelsLast()) {
          std::tie(format, list_val) = getConsistentValues(
              MemoryFormat::Contiguous(),
              value_map[node->inputs()[1]->unique()],
              value_map[node->inputs()[2]->unique()]);
        }
        auto operand0 = list_val.front();
        list_val.pop_front();
        auto operand1 = list_val.front();
        list_val.pop_front();
        auto input = operand0->as<TensorView>();
        auto grad_out = operand1->as<TensorView>();

        TensorView* weight = nullptr;
        if (!node->input(3)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          weight = value_map[node->input(3)->unique()]->as<TensorView>();
        }

        TensorView* running_mean = nullptr;
        if (!node->input(4)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          running_mean =
              value_map[node->input(4)->unique()]->as<TensorView>();
        }

        TensorView* running_var = nullptr;
        if (!node->input(5)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          running_var =
              value_map[node->input(5)->unique()]->as<TensorView>();
        }

        TensorView* save_mean = nullptr;
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        if (!node->input(6)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          save_mean = value_map[node->input(6)->unique()]->as<TensorView>();
        }

        TensorView* save_invstd = nullptr;
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        if (!node->input(7)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          save_invstd =
              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
              value_map[node->input(7)->unique()]->as<TensorView>();
        }

        // train has to be known at parse time.
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto training = constant_as<bool>(node->input(8));
        TORCH_INTERNAL_ASSERT(
            training.has_value(),
            "The training (bool) parameter is required.");
        const bool kTraining = training.value();

        // eps is input 9 in this schema. The non-constant fallback must read
        // the same input; input 7 is save_var_transform, not eps.
        Val* eps_ptr = nullptr;
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        if (auto eps = constant_as<float>(node->input(9))) {
          eps_ptr = new Double(eps.value());
        } else {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          eps_ptr = value_map[node->input(9)->unique()];
        }

        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list = constant_as<c10::List<bool>>(node->input(10));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(),
            "output mask for batch_norm_backward");
        std::vector<bool> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<bool>(value));
        }

        // TODO: merge this loop below.
        if (kTraining) {
          TORCH_INTERNAL_ASSERT(
              save_mean != nullptr && save_invstd != nullptr,
              "When training=True, save_mean and save_invstd are required.");
        } else {
          // TODO: this is not a legit assumption? Can't we run with
          // track_running_stats == false && training == false
          // which should just run through the case above.
          TORCH_INTERNAL_ASSERT(
              running_mean != nullptr && running_var != nullptr,
              "When training=False, running_mean and running_var are required.");
        }

        // The last argument tells batch_norm_backward whether the operands
        // are channels-last.
        auto grads = batch_norm_backward(
            input,
            grad_out,
            weight,
            running_mean,
            running_var,
            save_mean,
            save_invstd,
            kTraining,
            eps_ptr,
            output_mask,
            format.isChannelsLast());

        // For each output, the mask bit and the returned gradient pointer
        // must agree. Only grad_input carries the memory format.
        if (output_mask[0]) {
          TORCH_INTERNAL_ASSERT(grads.grad_input != nullptr);
          value_map.emplace(
              node->output(0)->unique(),
              ValueHolder(grads.grad_input, format));
        } else {
          TORCH_INTERNAL_ASSERT(grads.grad_input == nullptr);
          value_map.emplace(
              node->output(0)->unique(),
              ValueHolder(TensorViewBuilder().build(), format));
        }

        if (output_mask[1]) {
          TORCH_INTERNAL_ASSERT(grads.grad_weight != nullptr);
          value_map.emplace(node->output(1)->unique(), grads.grad_weight);
        } else {
          TORCH_INTERNAL_ASSERT(grads.grad_weight == nullptr);
          value_map.emplace(
              node->output(1)->unique(), TensorViewBuilder().build());
        }

        if (output_mask[2]) {
          TORCH_INTERNAL_ASSERT(grads.grad_bias != nullptr);
          value_map.emplace(node->output(2)->unique(), grads.grad_bias);
        } else {
          TORCH_INTERNAL_ASSERT(grads.grad_bias == nullptr);
          value_map.emplace(
              node->output(2)->unique(), TensorViewBuilder().build());
        }
      },
      // Second callback unconditionally returns true; third tags the
      // operator as OperatorType::Normalization. NOTE(review): the exact
      // role of these callbacks is defined by REGISTER_PARSE_RULE, which is
      // not visible here.
      [](const Node* node) -> bool { return true; },
      [](const Node* node) -> OperatorType {
        return OperatorType::Normalization;
      });
}
{
// Parse rule shared by aten::native_layer_norm and aten::layer_norm.
//
// The input tensor is requested in contiguous format. `normalized_shape`
// must be a constant int list. weight/bias stay nullptr when their JIT
// input type is a subtype of NoneType. `eps` is wrapped in a new Double
// when constant and otherwise taken from value_map. The `cudnn_enable`
// argument of aten::layer_norm (input 5) is not read.
std::array<const char*, kNumLayernormFwd> LayerNormFwd = {
"aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"};
for (auto signature : LayerNormFwd) {
auto ptr_op = getOperatorForLiteral(signature);
REGISTER_PARSE_RULE(
ptr_op,
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
MemoryFormat::Contiguous(),
value_map[node->inputs()[0]->unique()]);
auto input_t = list_val.front();
list_val.pop_front();
auto input = input_t->as<TensorView>();
// normalized_shape has to be known at parse time.
auto norm_shape_optional =
constant_as<c10::List<int64_t>>(node->input(1));
TORCH_INTERNAL_ASSERT(
norm_shape_optional.has_value(),
"The Normalized_Shape list is required.");
auto norm_shape = norm_shape_optional->vec();
// Optional affine parameters: left as nullptr when passed as None.
TensorView* weight = nullptr;
if (!node->input(2)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
weight = value_map[node->input(2)->unique()]->as<TensorView>();
}
TensorView* bias = nullptr;
if (!node->input(3)->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
bias = value_map[node->input(3)->unique()]->as<TensorView>();
}
// eps: constant -> new Double; otherwise the value_map entry.
Val* eps_ptr = nullptr;
if (auto eps = constant_as<float>(node->input(4))) {
eps_ptr = new Double(eps.value());
} else {
eps_ptr = value_map[node->input(4)->unique()];
}
auto result =
layer_norm(input, norm_shape, weight, bias, eps_ptr);
// native_layer_norm registers output/mean/invstd; layer_norm registers
// only the normalized output.
if (node->kind() ==
c10::Symbol::fromQualString("aten::native_layer_norm")) {
value_map.emplace(node->output(0)->unique(), result.output);
value_map.emplace(node->output(1)->unique(), result.mean);
value_map.emplace(node->output(2)->unique(), result.invstd);
} else if (
node->kind() ==
c10::Symbol::fromQualString("aten::layer_norm")) {
value_map.emplace(node->output()->unique(), result.output);
}
},
// Second callback unconditionally returns true; third tags the operator
// as OperatorType::Normalization. NOTE(review): the exact role of these
// callbacks is defined by REGISTER_PARSE_RULE, which is not visible here.
// TODO: #ProfileIValue List should update this
[](const Node* node) -> bool { return true; },
[](const Node* node) -> OperatorType {
return OperatorType::Normalization;
});
}
}
{
  // aten::native_layer_norm_backward: lowers the node to layer_norm_backward
  // and registers one value per gradient output (grad_input, grad_weight,
  // grad_bias).
  auto ptr_op = getOperatorForLiteral(
      "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        // grad_out and input go through getConsistentValues with
        // MemoryFormat::Contiguous() as the requested format.
        MemoryFormat format;
        std::list<Val*> operands;
        std::tie(format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        auto operand_it = operands.begin();
        auto grad_out = (*operand_it)->as<TensorView>();
        ++operand_it;
        auto input = (*operand_it)->as<TensorView>();

        // normalized_shape has to be available as a constant int list.
        auto norm_shape_optional =
            constant_as<c10::List<int64_t>>(node->input(2));
        TORCH_INTERNAL_ASSERT(
            norm_shape_optional.has_value(),
            "The Normalized_Shape list is required.");
        auto norm_shape = norm_shape_optional->vec();

        auto saved_mean =
            value_map[node->input(3)->unique()]->as<TensorView>();
        auto saved_rstd =
            value_map[node->input(4)->unique()]->as<TensorView>();

        // weight / bias are optional: they stay nullptr when the
        // corresponding input is typed as None.
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        const bool has_weight = !node->input(5)->type()->isSubtypeOf(
            static_cast<c10::TypePtr>(NoneType::get()));
        TensorView* weight = nullptr;
        if (has_weight) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          weight = value_map[node->input(5)->unique()]->as<TensorView>();
        }
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        const bool has_bias = !node->input(6)->type()->isSubtypeOf(
            static_cast<c10::TypePtr>(NoneType::get()));
        TensorView* bias = nullptr;
        if (has_bias) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          bias = value_map[node->input(6)->unique()]->as<TensorView>();
        }

        // output_mask has to be a constant bool list.
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto output_mask_optional =
            constant_as<c10::List<bool>>(node->input(7));
        TORCH_INTERNAL_ASSERT(
            output_mask_optional.has_value(),
            "output mask for layer_norm_backward");
        std::vector<bool> output_mask = output_mask_optional->vec();

        auto grad = layer_norm_backward(
            grad_out,
            input,
            norm_shape,
            saved_mean,
            saved_rstd,
            weight,
            bias,
            output_mask);

        // Every node output gets an entry in value_map: the computed gradient
        // when it was requested, otherwise a freshly built TensorView.
        const bool emit_grad_input = output_mask[0];
        if (!emit_grad_input) {
          TORCH_INTERNAL_ASSERT(grad.grad_input == nullptr);
          value_map.emplace(
              node->output(0)->unique(), TensorViewBuilder().build());
        } else {
          TORCH_INTERNAL_ASSERT(grad.grad_input != nullptr);
          value_map.emplace(node->output(0)->unique(), grad.grad_input);
        }

        const bool emit_grad_weight = output_mask[1] && weight != nullptr;
        if (!emit_grad_weight) {
          TORCH_INTERNAL_ASSERT(grad.grad_weight == nullptr);
          value_map.emplace(
              node->output(1)->unique(), TensorViewBuilder().build());
        } else {
          TORCH_INTERNAL_ASSERT(grad.grad_weight != nullptr);
          value_map.emplace(node->output(1)->unique(), grad.grad_weight);
        }

        const bool emit_grad_bias = output_mask[2] && bias != nullptr;
        if (!emit_grad_bias) {
          TORCH_INTERNAL_ASSERT(grad.grad_bias == nullptr);
          value_map.emplace(
              node->output(2)->unique(), TensorViewBuilder().build());
        } else {
          TORCH_INTERNAL_ASSERT(grad.grad_bias != nullptr);
          value_map.emplace(node->output(2)->unique(), grad.grad_bias);
        }
      },
      // TODO: #ProfileIValue List should update this
      [](const Node* node) -> bool { return true; },
      [](const Node* node) -> OperatorType {
        return OperatorType::Normalization;
      });
}
{
  // aten::softmax.int: lowered to softmax() along a constant dim.
  auto ptr_op = getOperatorForLiteral(
      "aten::softmax.int(Tensor self, int dim, int? dtype) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> operands;
        std::tie(format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()]);
        auto input = operands.front()->as<TensorView>();

        auto dim_value = constant_as<int>(node->input(1));
        TORCH_INTERNAL_ASSERT(
            dim_value.has_value(), "dim in softmax is not valid");

        auto result = softmax(input, dim_value.value());
        value_map.emplace(node->output()->unique(), result);
      },
      [](const Node* node) -> bool {
        // dim has to come from a prim::Constant node.
        const bool dim_is_constant =
            node->inputs()[1]->node()->kind() == prim::Constant;
        if (!dim_is_constant) {
          return false;
        }
        // TODO: support dynamic input by profiling it
        // dtype is accepted when it is None or a constant.
        const bool dtype_is_none = node->inputs()[2]->type()->isSubtypeOf(
            static_cast<c10::TypePtr>(NoneType::get()));
        return dtype_is_none ||
            node->inputs()[2]->node()->kind() == prim::Constant;
      },
      [](const Node* node) -> OperatorType {
        return OperatorType::Normalization;
      });
}
{ // LTC uses this op for softmax
  auto ptr_op = getOperatorForLiteral(
      "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> operands;
        std::tie(format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()]);
        auto input = operands.front()->as<TensorView>();

        auto dim_value = constant_as<int>(node->input(1));
        TORCH_INTERNAL_ASSERT(
            dim_value.has_value(), "dim in softmax is not valid");

        auto result = softmax(input, dim_value.value());
        value_map.emplace(node->output()->unique(), result);
      },
      [](const Node* node) -> bool {
        // Both dim and half_to_float have to come from prim::Constant nodes.
        if (node->inputs()[1]->node()->kind() != prim::Constant ||
            node->inputs()[2]->node()->kind() != prim::Constant) {
          return false;
        }
        const auto half_to_float = constant_as<bool>(node->input(2));
        TORCH_INTERNAL_ASSERT(
            half_to_float.has_value(), "Bool half_to_float is not valid");
        if (!half_to_float.value()) {
          return true;
        }
        // half_to_float is only accepted when the input's scalar type is
        // known to be Half.
        auto input_tensor_type = node->input(0)->type()->cast<TensorType>();
        if (input_tensor_type->scalarType() != at::ScalarType::Half) {
          return false;
        }
        return true;
      },
      [](const Node* node) -> OperatorType {
        return OperatorType::Normalization;
      });
}
{
  // aten::_softmax_backward_data: lowered to softmax_backward() along a
  // constant dim.
  auto ptr_op = getOperatorForLiteral(
      "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        auto grad_out_tv =
            value_map[node->input(0)->unique()]->as<TensorView>();
        auto fwd_out_tv =
            value_map[node->input(1)->unique()]->as<TensorView>();
        auto dim_value = constant_as<int>(node->input(2));
        TORCH_INTERNAL_ASSERT(
            dim_value.has_value(), "dim in softmax is not valid");
        // input_dtype here is ignored! type_inference handles it
        auto grad_in_tv =
            softmax_backward(grad_out_tv, fwd_out_tv, dim_value.value());
        value_map.emplace(node->output()->unique(), grad_in_tv);
      },
      [](const Node* node) -> bool {
        // Accepted only when dim comes from a prim::Constant node.
        return node->inputs()[2]->node()->kind() == prim::Constant;
      },
      [](const Node* node) -> OperatorType {
        return OperatorType::Normalization;
      });
}
{
  // aten::sum.dim_IntList: lowered to sum() over a constant list of axes with
  // a constant keepdim flag.
  auto ptr_op = getOperatorForLiteral(
      "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        // TODO: support channels last in sum
        MemoryFormat format;
        std::list<Val*> list_val;
        std::tie(format, list_val) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()]);
        auto self = list_val.front();
        list_val.pop_front();
        auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
        TORCH_INTERNAL_ASSERT(
            dims_list.has_value(),
            "aten::sum cannot be fused with dynamic axes");
        // sum() is called with int axes; narrow from the int64_t list.
        std::vector<int> dims;
        for (const auto dim : dims_list->vec()) {
          dims.emplace_back(static_cast<int>(dim));
        }
        auto keepdim = constant_as<bool>(node->input(2));
        TORCH_INTERNAL_ASSERT(
            keepdim.has_value(),
            "aten::sum cannot be fused with dynamic keepdim");
        auto out = sum(self->as<TensorView>(), dims, keepdim.value());
        value_map.emplace(node->output()->unique(), out);
      },
      [](const Node* node) -> bool {
        // TODO: support cast of output types
        if (!node->inputs()[3]->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          // An explicit dtype must be a constant, and we can only handle
          // output as half, bfloat16, float, and double.
          const auto opt_ivalue = toIValue(node->input(3));
          if (!opt_ivalue.has_value()) {
            return false;
          }
          const auto scalar_type = opt_ivalue->toScalarType();
          const bool supported_dtype =
              scalar_type == at::ScalarType::Double ||
              scalar_type == at::ScalarType::Float ||
              scalar_type == at::ScalarType::BFloat16 ||
              scalar_type == at::ScalarType::Half;
          if (!supported_dtype) {
            return false;
          }
          // Do NOT return true here: the axes / keepdim checks below must
          // still run, otherwise the parse rule above would hit its
          // TORCH_INTERNAL_ASSERT on dynamic axes or keepdim.
        }
        // we don't support dynamic reduction axes;
        if (node->inputs()[1]->node()->kind() != prim::Constant) {
          return false;
        }
        // we don't support dynamic keepdim yet;
        if (node->inputs()[2]->node()->kind() != prim::Constant) {
          return false;
        }
        return true;
      },
      [](const Node* node) -> OperatorType {
        return OperatorType::Reduction;
      });
}
{
  // aten::mean.dim: lowered as sum() over the constant axes divided by the
  // product of the extents of the reduced axes.
  auto ptr_op = getOperatorForLiteral(
      "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> list_val;
        std::tie(format, list_val) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()]);
        auto operand = list_val.front();
        list_val.pop_front();
        auto self = operand->as<TensorView>();
        auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
        TORCH_INTERNAL_ASSERT(
            dims_list.has_value(),
            "aten::mean cannot be fused with dynamic axes");
        // sum() is called with int axes; narrow from the int64_t list.
        std::vector<int> dims;
        for (const auto dim : dims_list->vec()) {
          dims.emplace_back(static_cast<int>(dim));
        }
        auto keepdim = constant_as<bool>(node->input(2));
        TORCH_INTERNAL_ASSERT(
            keepdim.has_value(),
            "aten::mean cannot be fused with dynamic keepdim");
        auto o_sum = sum(self, dims, keepdim.value());
        // Accumulate the number of reduced elements as the product of the
        // extents of every reduced axis.
        Val* num_features = new Double(1);
        for (auto axis : dims) {
          // Negative axes are wrapped by adding the tensor's rank.
          if (axis < 0) {
            axis += static_cast<int>(self->nDims());
          }
          num_features =
              mul(num_features, self->domain()->domain()[axis]->extent());
        }
        auto out = div(o_sum, num_features);
        value_map.emplace(node->output()->unique(), out);
      },
      [](const Node* node) -> bool {
        // TODO: support cast of output types
        if (!node->inputs()[3]->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          // An explicit dtype must be a constant, and we can only handle
          // output as half, bfloat16, float, and double.
          const auto opt_ivalue = toIValue(node->input(3));
          if (!opt_ivalue.has_value()) {
            return false;
          }
          const auto scalar_type = opt_ivalue->toScalarType();
          const bool supported_dtype =
              scalar_type == at::ScalarType::Double ||
              scalar_type == at::ScalarType::Float ||
              scalar_type == at::ScalarType::BFloat16 ||
              scalar_type == at::ScalarType::Half;
          if (!supported_dtype) {
            return false;
          }
          // Do NOT return true here: the axes / keepdim checks below must
          // still run, otherwise the parse rule above would hit its
          // TORCH_INTERNAL_ASSERT on dynamic axes or keepdim.
        }
        // we don't support dynamic reduction axes;
        if (node->inputs()[1]->node()->kind() != prim::Constant) {
          return false;
        }
        // we don't support dynamic keepdim yet;
        if (node->inputs()[2]->node()->kind() != prim::Constant) {
          return false;
        }
        return true;
      },
      [](const Node* node) -> OperatorType {
        return OperatorType::Reduction;
      });
}
{
  // aten::_grad_sum_to_size / aten::sum_to_size share one parse rule: with a
  // non-empty constant size the node becomes sum_to(); with an empty size the
  // input value is forwarded unchanged.
  std::array<const char*, kNumSumToSize> SumToSize = {
      "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
      "aten::sum_to_size(Tensor self, int[] size) -> Tensor"};
  for (const char* signature : SumToSize) {
    auto ptr_op = getOperatorForLiteral(signature);
    REGISTER_PARSE_RULE(
        ptr_op,
        {
          MemoryFormat format;
          std::list<Val*> operands;
          std::tie(format, operands) = getConsistentValues(
              MemoryFormat::Contiguous(),
              value_map[node->inputs()[0]->unique()]);
          Val* self = operands.front();

          auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
          TORCH_INTERNAL_ASSERT(
              size_to.has_value(),
              "aten::sum cannot be fused with dynamic axes");
          if (size_to->empty()) {
            // We are introducing alias here!
            value_map.emplace(node->output()->unique(), self);
          } else {
            auto reduced = sum_to(self->as<TensorView>(), size_to->vec());
            value_map.emplace(node->output()->unique(), reduced);
          }
        },
        [](const Node* node) -> bool {
          // we don't support dynamic reduction axes;
          return node->inputs()[1]->node()->kind() == prim::Constant;
          // auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
          // return size_to.has_value() && !size_to->empty();
        },
        [](const Node* node) -> OperatorType {
          // NOTE(review): size_to is dereferenced without a has_value()
          // check; this presumably relies on the size input being a
          // constant - confirm that callers guarantee it.
          auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
          // technically size_to->empty() should never occur, as specialized
          // _grad_sum_to_size should have been removed by optimization pass
          return size_to->empty() ? OperatorType::ElementWise
                                  : OperatorType::ReductionToSize;
        });
  }
}
{
  // Both autocast ops are lowered to a plain set() of the input; the cast
  // arguments (enabled flags / dtypes) are not read by this parse rule.
  std::array<const char*, kNumAutocastOps> AutocastOps = {
      "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)",
      "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)"};
  for (const char* signature : AutocastOps) {
    auto ptr_op = getOperatorForLiteral(signature);
    REGISTER_PARSE_RULE(
        ptr_op,
        {
          MemoryFormat format;
          std::list<Val*> operands;
          std::tie(format, operands) = getConsistentValues(
              c10::nullopt, value_map[node->inputs()[0]->unique()]);
          Val* self = operands.front();
          auto copied = set(self);
          // The output keeps the memory format reported for the input.
          value_map.emplace(
              node->output()->unique(), ValueHolder(copied, format));
        },
        nullptr,
        nullptr);
  }
}
// Limiting aten::to implementation to only change the dtype of a tensor
{
  // aten::to.dtype: lowered to a castOp; only the dtype argument is read.
  auto ptr_op = getOperatorForLiteral(
      "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> operands;
        std::tie(format, operands) = getConsistentValues(
            c10::nullopt, value_map[node->inputs()[0]->unique()]);
        Val* self = operands.front();

        // we need static type for cast
        TORCH_INTERNAL_ASSERT(
            node->input(1)->node()->kind() == prim::Constant);
        auto target_dtype = toIValue(node->input(1))->toScalarType();

        // We want to keep our internal fusion math in FP32
        // Shape Inference will continue to propagate the right
        // type to outputs unchanged.
        const bool is_reduced_precision =
            target_dtype == at::ScalarType::Half ||
            target_dtype == at::ScalarType::BFloat16;
        if (is_reduced_precision) {
          target_dtype = at::ScalarType::Float;
        }

        auto casted = castOp(aten_to_data_type(target_dtype), self);
        value_map.emplace(
            node->output()->unique(), ValueHolder(casted, format));
      },
      nullptr,
      nullptr);
}
{
  // aten::type_as: casts `self` to the data type recorded for `other` in
  // value_map.
  auto ptr_op = getOperatorForLiteral(
      "aten::type_as(Tensor self, Tensor other) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> operands;
        std::tie(format, operands) = getConsistentValues(
            c10::nullopt, value_map[node->inputs()[0]->unique()]);
        Val* self = operands.front();

        // TODO: switch to PyTorch dtype as it's closer to truth.
        // For now, reality is that PyTorch IR profiling information could
        // be missing even with profiling executor, due to upstream
        // transformations between profiling runs to fusion pass.
        auto opt_dtype =
            value_map[node->inputs()[1]->unique()]->getDataType();
        TORCH_INTERNAL_ASSERT(opt_dtype.has_value());

        auto casted = castOp(opt_dtype.value(), self);
        value_map.emplace(
            node->output()->unique(), ValueHolder(casted, format));
      },
      nullptr,
      nullptr);
}
{
  // We are not fusing `linear` yet, because we can't codegen efficient gemm
  // However, we still need this here, so PE would insert profile node for
  // this node.
  // During fusion pass, We decompose linear into gemm + elementwise.
  auto ptr_op = getOperatorForLiteral(
      "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        // this entry is created so we do profile input tensors;
        TORCH_INTERNAL_ASSERT(false, "not implemented yet");
      },
      [](const Node* node) -> bool {
        // We only profile `linear` layer with bias.
        const bool bias_is_none = node->input(2)->type()->isSubtypeOf(
            static_cast<c10::TypePtr>(NoneType::get()));
        return !bias_is_none;
      });
}
{
  // prim::add_optional: adds the optional bias to the input, or forwards the
  // input's value unchanged when the bias is typed as None.
  auto ptr_op = getOperatorForLiteral(
      "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        // this entry is created so we do profile input tensors;
        const bool bias_is_none = node->input(1)->type()->isSubtypeOf(
            static_cast<c10::TypePtr>(NoneType::get()));
        if (!bias_is_none) {
          MemoryFormat format;
          std::list<Val*> operands;
          std::tie(format, operands) = getConsistentValues(
              c10::nullopt,
              value_map[node->inputs()[0]->unique()],
              value_map[node->inputs()[1]->unique()]);
          auto operand_it = operands.begin();
          Val* lhs = *operand_it;
          ++operand_it;
          Val* rhs = *operand_it;
          auto sum_val = binaryOp(
              BinaryOpType::Add,
              lhs,
              rhs,
              TypePromotion::default_op_config);
          value_map.emplace(
              node->output()->unique(), ValueHolder(sum_val, format));
        } else {
          // forwarding the value;
          value_map.emplace(
              node->output()->unique(),
              value_map[node->inputs()[0]->unique()]);
        }
      },
      nullptr,
      nullptr);
}
{
  // aten::gelu: lowered to gelu() on the input, keeping the memory format
  // reported by getConsistentValues.
  auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> operands;
        std::tie(format, operands) = getConsistentValues(
            c10::nullopt, value_map[node->inputs()[0]->unique()]);
        Val* self = operands.front();
        auto activated = gelu(self);
        value_map.emplace(
            node->output()->unique(), ValueHolder(activated, format));
      },
      nullptr,
      nullptr);
}
{
  // aten::gelu_backward: lowered to gelu_backward(grad, self) with both
  // operands brought to a consistent memory format.
  auto ptr_op = getOperatorForLiteral(
      "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> operands;
        std::tie(format, operands) = getConsistentValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
        auto operand_it = operands.begin();
        Val* grad_out = *operand_it;
        ++operand_it;
        Val* self = *operand_it;
        auto grad_in = gelu_backward(grad_out, self);
        value_map.emplace(
            node->output()->unique(), ValueHolder(grad_in, format));
      },
      nullptr,
      nullptr);
}
{
  // aten::amax: lowered to max() over a constant list of axes with a
  // constant keepdim flag.
  auto ptr_op = getOperatorForLiteral(
      "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor");
  REGISTER_PARSE_RULE(
      ptr_op,
      {
        MemoryFormat format;
        std::list<Val*> operands;
        std::tie(format, operands) = getConsistentValues(
            MemoryFormat::Contiguous(),
            value_map[node->inputs()[0]->unique()]);
        Val* self = operands.front();

        auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
        TORCH_INTERNAL_ASSERT(
            dims_list.has_value(),
            "aten::amax cannot be fused with dynamic axes");
        // max() is called with int axes; narrow from the int64_t list.
        std::vector<int> dims;
        for (const auto axis64 : dims_list->vec()) {
          dims.push_back(static_cast<int>(axis64));
        }

        auto keepdim = constant_as<bool>(node->input(2));
        TORCH_INTERNAL_ASSERT(
            keepdim.has_value(),
            "aten::amax cannot be fused with dynamic keepdim");

        auto reduced = max(self->as<TensorView>(), dims, keepdim.value());
        value_map.emplace(node->output()->unique(), reduced);
      },
      [](const Node* node) -> bool {
        // we don't support dynamic reduction axes;
        const bool axes_are_constant =
            node->inputs()[1]->node()->kind() == prim::Constant;
        if (!axes_are_constant) {
          return false;
        }
        // we don't support dynamic keepdim yet;
        return node->inputs()[2]->node()->kind() == prim::Constant;
      },
      [](const Node* node) -> OperatorType {
        return OperatorType::Reduction;
      });
}
/*
// TODO: Enable view in parser by detecting non-alias view operation
{
std::array<const char*, kNumViewSize> View = {
"aten::view(Tensor(a) self, int[] size) -> Tensor(a)",
"aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"};
for (auto signature : View) {
auto ptr_op = getOperatorForLiteral(signature);
REGISTER_PARSE_RULE(
ptr_op,
{
auto self_value = node->inputs()[0];
auto self = value_map[self_value->unique()]->as<TensorView>();
auto self_type = self_value->type()->cast<c10::TensorType>();
TORCH_INTERNAL_ASSERT(self_type != nullptr);
auto self_sizes = getTensorSizes(self_type);
auto size_optional =
constant_as<c10::List<int64_t>>(node->input(1));
TORCH_INTERNAL_ASSERT(
size_optional.has_value(), "The size parameter is required.");
auto output = view(self, self_sizes, size_optional->vec());
value_map.emplace(node->output()->unique(), output);
},
nullptr,
nullptr);
}
}
*/
}