in torch/csrc/jit/tensorexpr/tensorexpr_init.cpp [59:912]
void initTensorExprBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
// Tensor Expr Classes
auto te = m.def_submodule("_te");
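  // Dtype wraps a scalar type; Python dtype objects convert implicitly via
  // parsePythonDtype, and each scalar type is also exposed as a read-only
  // class attribute. Illustrative Python sketch (module path assumed):
  //   te = torch._C._te
  //   dt = te.Dtype.Float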
auto dtype_class =
py::class_<Dtype>(te, "Dtype").def(py::init(&parsePythonDtype));
py::implicitly_convertible<py::object, Dtype>();
#define DTYPE_SINGLETON_ACCESSOR(ctype, name) \
dtype_class.def_property_readonly_static( \
#name, [](py::object) { return k##name; }); // NOLINT
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_SINGLETON_ACCESSOR)
#undef DTYPE_SINGLETON_ACCESSOR
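  // ExprHandle wraps an expression node and overloads the usual arithmetic,
  // comparison, and bitwise operators plus pointwise math methods.
  // Illustrative Python sketch (x, y assumed to be ExprHandles):
  //   e = (x + y) * te.ExprHandle.float(2.0)
  //   s = e.sin().cast(te.Dtype.Int)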
auto expr_handle_class =
py::class_<ExprHandle>(te, "ExprHandle")
.def(
"__str__",
[](const ExprHandle& self) {
std::stringstream ss;
ss << self;
return ss.str();
})
.def(py::self + py::self)
.def(py::self * py::self)
.def(py::self - py::self)
.def(py::self / py::self)
.def(py::self % py::self)
.def(py::self == py::self)
.def(py::self != py::self)
.def(py::self > py::self)
.def(py::self >= py::self)
.def(py::self < py::self)
.def(py::self <= py::self)
.def(py::self & py::self)
.def(py::self | py::self)
.def(py::self ^ py::self)
.def(py::self << py::self)
.def(py::self >> py::self)
.def(
"__pow__",
[](const ExprHandle& self, const ExprHandle& other) {
return pow(self, other);
})
.def("sin", [](const ExprHandle& self) { return sin(self); })
.def("cos", [](const ExprHandle& self) { return cos(self); })
.def("tan", [](const ExprHandle& self) { return tan(self); })
.def("asin", [](const ExprHandle& self) { return asin(self); })
.def("acos", [](const ExprHandle& self) { return acos(self); })
.def("atan", [](const ExprHandle& self) { return atan(self); })
.def("sinh", [](const ExprHandle& self) { return sinh(self); })
.def("cosh", [](const ExprHandle& self) { return cosh(self); })
.def("tanh", [](const ExprHandle& self) { return tanh(self); })
.def("sigmoid", [](const ExprHandle& self) { return sigmoid(self); })
.def("exp", [](const ExprHandle& self) { return exp(self); })
.def("expm1", [](const ExprHandle& self) { return expm1(self); })
.def(
"abs",
[](const ExprHandle& self) { return tensorexpr::abs(self); })
.def("log", [](const ExprHandle& self) { return log(self); })
.def(
"fast_tanh",
[](const ExprHandle& self) { return fast_tanh(self); })
.def(
"fast_sigmoid",
[](const ExprHandle& self) { return fast_sigmoid(self); })
.def(
"fast_log", [](const ExprHandle& self) { return fast_log(self); })
.def("log_vml", [](const ExprHandle& self) { return log_vml(self); })
.def("log2", [](const ExprHandle& self) { return log2(self); })
.def("log10", [](const ExprHandle& self) { return log10(self); })
.def("log1p", [](const ExprHandle& self) { return log1p(self); })
.def("erf", [](const ExprHandle& self) { return erf(self); })
.def("erfc", [](const ExprHandle& self) { return erfc(self); })
.def(
"sqrt",
[](const ExprHandle& self) { return tensorexpr::sqrt(self); })
.def("rsqrt", [](const ExprHandle& self) { return rsqrt(self); })
.def("ceil", [](const ExprHandle& self) { return ceil(self); })
.def("floor", [](const ExprHandle& self) { return floor(self); })
.def("round", [](const ExprHandle& self) { return round(self); })
.def("trunc", [](const ExprHandle& self) { return trunc(self); })
.def("frac", [](const ExprHandle& self) { return frac(self); })
.def("lgamma", [](const ExprHandle& self) { return lgamma(self); })
.def("isnan", [](const ExprHandle& self) { return isnan(self); })
.def(
"cast",
[](const ExprHandle& self, const Dtype& dt) {
return Cast::make(dt, self);
})
#define EXPRHANDLE_INIT(ctype, name) \
.def(py::init([](ctype val) { return name##Imm::make(val); }))
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_INIT)
#undef EXPRHANDLE_INIT
;
#define EXPRHANDLE_IMPL_CONV(ctype, name) \
py::implicitly_convertible<ctype, ExprHandle>();
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_IMPL_CONV)
#undef EXPRHANDLE_IMPL_CONV
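  // The pointwise math ops are also exposed as free functions on the
  // submodule (te.sin, te.pow, te.atan2, ...), along with the ternary
  // te.ifThenElse(cond, true_val, false_val) defined below.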
te.def(
"ifThenElse",
[](const ExprHandle& c, const ExprHandle& t, const ExprHandle& f) {
return ifThenElse(c, t, f);
});
te.def("sin", [](const ExprHandle& v1) { return sin(v1); });
te.def("cos", [](const ExprHandle& v1) { return cos(v1); });
te.def("tan", [](const ExprHandle& v1) { return tan(v1); });
te.def("asin", [](const ExprHandle& v1) { return asin(v1); });
te.def("acos", [](const ExprHandle& v1) { return acos(v1); });
te.def("atan", [](const ExprHandle& v1) { return atan(v1); });
te.def("sinh", [](const ExprHandle& v1) { return sinh(v1); });
te.def("cosh", [](const ExprHandle& v1) { return cosh(v1); });
te.def("tanh", [](const ExprHandle& v1) { return tanh(v1); });
te.def("sigmoid", [](const ExprHandle& v1) { return sigmoid(v1); });
te.def("exp", [](const ExprHandle& v1) { return exp(v1); });
te.def("expm1", [](const ExprHandle& v1) { return expm1(v1); });
te.def("abs", [](const ExprHandle& v1) { return abs(v1); });
te.def("log", [](const ExprHandle& v1) { return log(v1); });
te.def("log2", [](const ExprHandle& v1) { return log2(v1); });
te.def("log10", [](const ExprHandle& v1) { return log10(v1); });
te.def("log1p", [](const ExprHandle& v1) { return log1p(v1); });
te.def("erf", [](const ExprHandle& v1) { return erf(v1); });
te.def("erfc", [](const ExprHandle& v1) { return erfc(v1); });
te.def("sqrt", [](const ExprHandle& v1) { return sqrt(v1); });
te.def("rsqrt", [](const ExprHandle& v1) { return rsqrt(v1); });
te.def("ceil", [](const ExprHandle& v1) { return ceil(v1); });
te.def("floor", [](const ExprHandle& v1) { return floor(v1); });
te.def("round", [](const ExprHandle& v1) { return round(v1); });
te.def("trunc", [](const ExprHandle& v1) { return trunc(v1); });
te.def("frac", [](const ExprHandle& v1) { return frac(v1); });
te.def("lgamma", [](const ExprHandle& v1) { return lgamma(v1); });
te.def("isnan", [](const ExprHandle& v1) { return isnan(v1); });
te.def("atan2", [](const ExprHandle& v1, const ExprHandle& v2) {
return atan2(v1, v2);
});
te.def("pow", [](const ExprHandle& v1, const ExprHandle& v2) {
return pow(v1, v2);
});
te.def("fmod", [](const ExprHandle& v1, const ExprHandle& v2) {
return fmod(v1, v2);
});
te.def("remainder", [](const ExprHandle& v1, const ExprHandle& v2) {
return remainder(v1, v2);
});
#define EXPRHANDLE_CTOR(ctype, name) \
expr_handle_class.def_static(#ctype, [](ctype v) { return ExprHandle(v); });
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_CTOR)
#undef EXPRHANDLE_CTOR
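  // VarHandle names a scalar (loop/index) variable; BufHandle names an N-d
  // buffer and supports indexed load/store. Illustrative Python sketch:
  //   i = te.VarHandle("i", te.Dtype.Int)
  //   buf = te.BufHandle("buf", [te.ExprHandle.int(64)], te.Dtype.Float)
  //   ld = buf.load([i])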
py::class_<VarHandle, ExprHandle>(te, "VarHandle")
.def(
"__str__",
[](const ExprHandle& self) {
std::stringstream ss;
ss << self;
return ss.str();
})
.def(py::init<Dtype>())
.def(py::init<const std::string&, Dtype>());
py::class_<BufHandle, ExprHandle>( // NOLINT
te,
"BufHandle")
.def(
py::init<const std::string&, const std::vector<ExprHandle>&, Dtype>())
.def(py::init<const std::vector<ExprHandle>&, Dtype>())
.def(py::init<Dtype>())
.def(
"__hash__",
[](const BufHandle& self) {
return std::hash<BufPtr>()(self.node());
})
.def(
"__eq__",
[](const BufHandle& self, const BufHandle& other) {
return self.node() == other.node();
})
.def(
"load",
[](BufHandle& self, const std::vector<ExprHandle>& v) {
return Load::make(self, v);
})
.def(
"load",
[](BufHandle& self, const ExprHandle& v) {
return Load::make(self, {v});
})
.def(
"store",
[](BufHandle& self,
const std::vector<ExprHandle>& i,
const ExprHandle& v) { return Store::make(self, i, v); })
.def(
"store",
[](BufHandle& self, const ExprHandle& i, const ExprHandle& v) {
return Store::make(self, {i}, v);
});
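  // A Tensor pairs an output buffer with the statement computing it; it is
  // usually produced by te.Compute or te.Reduce rather than built directly.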
py::class_<Tensor>(te, "Tensor")
.def(
py::init([](BufHandle& b, StmtPtr s) { return Tensor(b.node(), s); }))
.def(
"load",
[](Tensor& self, const std::vector<ExprHandle>& v) {
return self.load(v);
})
.def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
.def("stmt", &Tensor::stmt);
py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
.def_static("make", &Cast::make)
.def(
"src_value",
[](CastPtr& self) { return ExprHandle(self->src_value()); })
.def("set_src_value", [](CastPtr& self, const ExprHandle& value) {
self->set_src_value(value.node());
});
py::class_<DimArg>(te, "DimArg")
.def(py::init<const ExprHandle&>())
.def(py::init<const ExprHandle&, const std::string&>());
py::implicitly_convertible<ExprHandle, DimArg>();
py::implicitly_convertible<int32_t, DimArg>();
py::implicitly_convertible<int64_t, DimArg>();
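  // te.Compute builds a Tensor from a Python callable over the index
  // variables; this fixed-arity form supports up to 4 dims, while
  // te.Compute2 passes all index vars as a single list. Illustrative
  // Python sketch:
  //   A = te.Compute("A", [te.DimArg(64, "i")], lambda i: i + 1.0)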
te.def(
"Compute",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
py::function func) {
if (dim_args.size() == 1) {
return Compute(func_name, dim_args, [&func](const VarHandle& a) {
return py::cast<ExprHandle>(func(a));
});
} else if (dim_args.size() == 2) {
return Compute(
func_name,
dim_args,
[&func](const VarHandle& a, const VarHandle& b) {
return py::cast<ExprHandle>(func(a, b));
});
} else if (dim_args.size() == 3) {
return Compute(
func_name,
dim_args,
[&func](
const VarHandle& a, const VarHandle& b, const VarHandle& c) {
return py::cast<ExprHandle>(func(a, b, c));
});
} else if (dim_args.size() == 4) {
return Compute(
func_name,
dim_args,
[&func](
const VarHandle& a,
const VarHandle& b,
const VarHandle& c,
const VarHandle& d) {
return py::cast<ExprHandle>(func(a, b, c, d));
});
} else {
throw std::runtime_error("Too many args");
}
},
py::return_value_policy::reference);
te.def(
"Compute2",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
py::function func) {
return Compute(
func_name, dim_args, [&func](const std::vector<VarHandle>& dims) {
return py::cast<ExprHandle>(func(dims));
});
},
py::return_value_policy::reference);
py::class_<Reducer>(te, "Reducer")
.def(py::init<
ExprHandle,
std::function<ExprHandle(ExprHandle, ExprHandle)>>());
py::class_<Sum, Reducer>(te, "Sum").def(py::init<>());
py::class_<Maximum, Reducer>(te, "Maximum").def(py::init<Dtype>());
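  // te.Reduce combines a Reducer (te.Sum, te.Maximum, or a custom Reducer)
  // with a body read from a Tensor, a BufHandle, or a Python callable.
  // Illustrative Python sketch (A assumed to be a 1-D Tensor):
  //   S = te.Reduce("sum", [], te.Sum(), A, [te.DimArg(64, "i")])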
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const Reducer& reducer,
Tensor buffer,
const std::vector<DimArg>& reduce_args) {
return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
},
py::return_value_policy::reference);
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<DimArg>& reduce_args) {
return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
},
py::return_value_policy::reference);
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const Reducer& reducer,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
body_func,
const std::vector<DimArg>& reduce_args) {
return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
},
py::return_value_policy::reference);
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const Reducer& reducer,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
init_func,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
body_func,
const std::vector<DimArg>& reduce_args) {
        return Reduce(
            func_name, dim_args, reducer, init_func, body_func, reduce_args);
},
py::return_value_policy::reference);
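  // Statement nodes: constructing a Stmt from a list of statements yields a
  // Block; Store/For/Cond expose `make` factories so loops can be assembled
  // by hand. Illustrative Python sketch (buf, i as above):
  //   st = te.Store.make(buf, [i], buf.load([i]) + te.ExprHandle.float(1.0))
  //   loop = te.For.make(i, te.ExprHandle.int(0), te.ExprHandle.int(64), st)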
py::class_<Stmt, std::shared_ptr<Stmt>>(te, "Stmt")
.def(py::init([](const std::vector<StmtPtr>& stmts) {
return tensorexpr::Block::make(stmts);
}))
.def("__str__", [](Stmt& self) {
std::stringstream ss;
ss << self;
return ss.str();
});
py::class_<Store, Stmt, std::shared_ptr<Store>>(te, "Store")
.def_static(
"make",
[](const BufHandle& buf,
std::vector<ExprHandle>& indices,
const ExprHandle& value) {
return Store::make(buf, indices, value);
});
py::class_<For, Stmt, std::shared_ptr<For>>(te, "For")
.def("index_var", [](For& self) { return VarHandle(self.var()); })
.def("body", &For::body)
.def("set_parallel", &For::set_parallel)
.def(
"set_gpu_block_index",
[](For& self, int block_index) {
self.set_gpu_block_index(block_index);
})
.def(
"set_gpu_thread_index",
[](For& self, int thread_index) {
self.set_gpu_thread_index(thread_index);
})
.def_static(
"make",
[](const VarHandle& var,
const ExprHandle& start,
const ExprHandle& stop,
StmtPtr body) { return For::make(var, start, stop, body); });
py::class_<Cond, Stmt, std::shared_ptr<Cond>>(te, "Cond")
.def_static(
"make",
[](const ExprHandle& condition,
StmtPtr true_stmt,
StmtPtr false_stmt) {
return Cond::make(condition, true_stmt, false_stmt);
})
.def("true_stmt", &Cond::true_stmt)
.def("false_stmt", &Cond::false_stmt);
py::class_<tensorexpr::Block, Stmt, std::shared_ptr<tensorexpr::Block>>(
te, "Block")
.def(py::init([](const std::vector<StmtPtr>& stmts) {
return tensorexpr::Block::make(stmts);
}))
.def("stmts", &tensorexpr::Block::stmts);
py::class_<ExternalCall, Stmt, std::shared_ptr<ExternalCall>>(
te, "ExternalCall")
.def(py::init(&ExternalCall::make));
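  // LoopNest owns the loop-level IR for a set of output tensors and hosts
  // the scheduling transforms (split, fuse, reorder, vectorize, ...).
  // Illustrative Python sketch (A as above):
  //   ln = te.LoopNest([A])
  //   loops = ln.get_loops_for(A)
  //   te.LoopNest.split_with_mask(loops[0], 8)
  //   ln.prepare_for_codegen()
  //   stmt = te.simplify(ln.root_stmt())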
py::class_<LoopNest>(te, "LoopNest")
.def(py::init<const std::vector<Tensor>&>())
.def(py::init<const std::vector<Tensor>&, const std::vector<Tensor>&>())
.def(py::init([](StmtPtr s, const std::vector<BufHandle>& bufs) {
std::unordered_set<BufPtr> buf_nodes;
for (auto& buf : bufs) {
buf_nodes.insert(buf.node());
}
return std::make_unique<LoopNest>(s, buf_nodes);
}))
.def("vectorize_inner_loops", &LoopNest::vectorizeInnerLoops)
.def(
"prepare_for_codegen",
[](LoopNest& self) { return self.prepareForCodegen(); },
py::return_value_policy::reference)
.def(
"get_loop_body_for",
[](const LoopNest& self, Tensor t) { return self.getLoopBodyFor(t); },
py::return_value_policy::reference)
.def(
"get_loop_body_for",
[](const LoopNest& self, BufHandle& b) {
return self.getLoopBodyFor(b.node());
},
py::return_value_policy::reference)
.def(
"get_loops_for",
[](const LoopNest& self, Tensor t) {
return self.getLoopStmtsFor(t);
},
py::return_value_policy::reference)
.def(
"get_all_loopnests_for",
[](const LoopNest& self, const BufHandle& b) {
return self.getAllLoopNestsWritingToBuf(b.node());
},
py::return_value_policy::reference)
.def(
"get_enclosing_loopnest",
[](const LoopNest& self, StmtPtr s) {
return self.getEnclosingLoopNest(s);
},
py::return_value_policy::reference)
.def(
"get_innermost_loops_for",
[](const LoopNest& self, const BufHandle& b) {
return self.getAllInnermostLoopsWritingToBuf(b.node());
},
py::return_value_policy::reference)
.def(
"get_writes_for",
[](const LoopNest& self, const BufHandle& b) {
return self.getAllWritesToBuf(b.node());
},
py::return_value_policy::reference)
.def(
"get_loop_at",
[](const LoopNest& self,
ForPtr root,
const std::vector<int>& indices) {
return self.getLoopAt(root, indices);
},
py::return_value_policy::reference)
.def(
"get_parent_loop",
[](const LoopNest& self, StmtPtr s) { return self.getParentLoop(s); },
py::return_value_policy::reference)
.def_static(
"get_loop_stmts_in_loopnest",
[](ForPtr f, size_t num) {
return LoopNest::getLoopStmtsInLoopNest(f, num);
},
py::return_value_policy::reference)
      .def_static(
"split_with_tail",
[](ForPtr f, int factor) {
ForPtr inner = nullptr, tail = nullptr;
LoopNest::splitWithTail(f, factor, &inner, &tail);
return std::make_tuple(inner, tail);
},
py::return_value_policy::reference)
      .def_static(
"split_with_mask",
[](ForPtr f, int factor) {
ForPtr inner = nullptr;
LoopNest::splitWithMask(f, factor, &inner);
return inner;
},
py::return_value_policy::reference)
      .def_static(
"slice_head",
[](ForPtr f, int factor) {
ForPtr head = nullptr, tail = nullptr;
LoopNest::sliceHead(f, factor, &head, &tail);
return std::make_tuple(head, tail);
},
py::return_value_policy::reference)
      .def_static(
"slice_tail",
[](ForPtr f, int factor) {
ForPtr head = nullptr, tail = nullptr;
LoopNest::sliceTail(f, factor, &head, &tail);
return std::make_tuple(head, tail);
},
py::return_value_policy::reference)
.def_static(
"normalize",
[](ForPtr f) {
LoopNest::normalize(f);
return f;
},
py::return_value_policy::reference)
.def(
"tile",
[](LoopNest& self, ForPtr x, ForPtr y, int x_factor, int y_factor) {
return self.tile(x, y, x_factor, y_factor);
},
py::return_value_policy::reference)
.def_static(
"distribute_loop",
[](ForPtr f) { return LoopNest::distributeLoop(f); },
py::return_value_policy::reference)
.def_static(
"distribute_loop",
[](ForPtr f, const std::unordered_set<StmtPtr>& pivots) {
return LoopNest::distributeLoop(f, pivots);
},
py::return_value_policy::reference)
.def_static(
"distribute_loop_over_inner_loops",
[](ForPtr f) { return LoopNest::distributeLoopOverInnerLoops(f); },
py::return_value_policy::reference)
.def_static(
"unsafe_fuse_loops",
[](const std::vector<ForPtr>& loops) {
ForPtr fused_loop = nullptr;
LoopNest::unsafeFuseLoops(loops, &fused_loop);
return fused_loop;
},
py::return_value_policy::reference)
.def_static(
"fuse_loops",
[](const std::vector<ForPtr>& loops) {
ForPtr fused_loop = nullptr;
LoopNest::fuseLoops(loops, &fused_loop);
return fused_loop;
},
py::return_value_policy::reference)
.def_static(
"reorder",
[](const std::vector<ForPtr>& loops,
const std::vector<size_t>& permutation) {
return LoopNest::reorder(loops, permutation);
},
py::return_value_policy::reference)
.def(
"unroll",
[](const LoopNest& self, ForPtr f) {
StmtPtr unrolled = nullptr;
self.unroll(f, &unrolled);
return unrolled;
},
py::return_value_policy::reference)
      .def_static(
"vectorize",
[](ForPtr f) { LoopNest::vectorize(f); },
py::return_value_policy::reference)
.def_static(
"compress_buffer",
[](BufHandle& buf, StmtPtr stmt) {
return LoopNest::compressBuffer(buf.node(), stmt);
},
py::return_value_policy::reference)
.def_static(
"cache_accesses",
[](const BufHandle& producer,
const std::string& name,
StmtPtr consumer) {
std::pair<BufPtr, StmtPtr> ret =
LoopNest::cacheAccesses(producer.node(), name, consumer);
return std::make_pair(BufHandle(ret.first), ret.second);
},
py::return_value_policy::reference)
.def_static(
"compute_at",
[](StmtPtr s, ForPtr at) { LoopNest::computeAt(s, at); })
.def(
"compute_inline",
[](LoopNest& self, StmtPtr s) { self.computeInline(s); },
py::return_value_policy::reference)
.def(
"compute_inline",
[](LoopNest& self, const BufHandle& b) {
self.computeInline(b.node());
},
py::return_value_policy::reference)
      .def_static(
"rfactor",
[](StmtPtr s, ForPtr target_for) {
BufPtr rfac_buf = nullptr;
LoopNest::rfactor(s, target_for, &rfac_buf);
return BufHandle(rfac_buf);
},
py::return_value_policy::reference)
.def(
"flatten",
[](LoopNest& self, const std::vector<ForPtr>& loops) {
ForPtr flattened = nullptr;
LoopNest::flatten(loops, &flattened);
return flattened;
},
py::return_value_policy::reference)
.def(
"reorder_axis",
&LoopNest::reorderAxis,
py::return_value_policy::reference)
.def("simplify", &LoopNest::simplify, py::return_value_policy::reference)
.def_static("sanitize_names", &LoopNest::sanitizeNames)
.def(
"inline_intermediate_bufs",
[](LoopNest& self, bool allow_duplicated_work) {
self.inlineIntermediateBufs(allow_duplicated_work);
})
.def(
"eliminate_dead_stores",
[](LoopNest& self) { self.eliminateDeadStores(); })
.def(
"__str__",
[](const LoopNest& self) {
std::stringstream ss;
ss << *self.root_stmt();
return ss.str();
})
.def(
"root_stmt",
&LoopNest::root_stmt,
py::return_value_policy::reference);
te.def(
"simplify",
[](StmtPtr stmt) { return IRSimplifier::simplify(stmt); },
py::return_value_policy::reference);
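  // te.lower looks up the standard NNC lowering for an aten op and applies
  // it to already-converted argument values. Illustrative Python sketch
  // (buf assumed to be a BufHandle; op must have a standard lowering):
  //   T = te.lower(
  //       "aten::relu", [buf], [te.ExprHandle.int(64)], te.Dtype.Float)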
te.def(
"lower",
[](std::string op_str,
py::list inputs,
std::vector<ExprHandle> outputShape,
Dtype outputType) {
auto op = c10::Symbol::fromQualString(op_str);
std::vector<ArgValue> argInputs;
for (auto inp : inputs) {
argInputs.push_back(convertPyToArgValue(inp));
}
if (NNCLoweringFunction lowering =
getStandardLoweringFor(op.toQualString())) {
return lowering(
argInputs, outputShape, outputType.scalar_type(), at::kCPU);
}
std::string msg = std::string("Unhandled node kind (in te.lower): ") +
op.toQualString();
throw malformed_input(msg);
});
py::class_<ArgValue>(te, "ArgValue")
.def(py::init([](py::handle inp) {
return std::make_unique<ArgValue>(convertPyToArgValue(inp));
}))
.def(
"as_buf",
[](const ArgValue& self) { return c10::get<BufHandle>(self); })
.def(
"as_var",
[](const ArgValue& self) { return c10::get<VarHandle>(self); })
.def(
"as_float",
[](const ArgValue& self) { return c10::get<double>(self); })
.def(
"as_int",
[](const ArgValue& self) { return c10::get<int64_t>(self); })
.def("as_bool", [](const ArgValue& self) { return c10::get<bool>(self); })
.def(
"as_none",
[](const ArgValue& self) { return c10::get<ArgNone>(self); })
.def(
"as_buflist",
[](const ArgValue& self) { return c10::get<BufList>(self); })
.def("as_intlist", [](const ArgValue& self) {
return c10::get<IntList>(self);
});
py::class_<c10::ScalarType>(te, "ScalarType");
using TSGraph = std::shared_ptr<Graph>;
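  // TensorExprKernel compiles a TorchScript graph end to end; run() executes
  // the compiled kernel and fallback() runs the original graph instead.
  // Illustrative Python sketch (graph, x, y assumed):
  //   kernel = te.TensorExprKernel(graph)
  //   out = kernel.run((x, y))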
py::class_<TensorExprKernel>(te, "TensorExprKernel")
.def(py::init<const TSGraph&>())
.def(
py::init([](const TSGraph& g,
std::unordered_map<std::string, NNCLoweringFunction>
custom_lowerings_str,
std::vector<int64_t> symbolic_shape_inputs,
bool pre_alloc = false) {
std::unordered_map<c10::Symbol, NNCLoweringFunction>
custom_lowerings;
for (auto& kv : custom_lowerings_str) {
custom_lowerings[c10::Symbol::fromQualString(kv.first)] =
kv.second;
}
return std::make_unique<TensorExprKernel>(
g, custom_lowerings, symbolic_shape_inputs, pre_alloc);
}),
py::arg("g"),
py::arg("custom_lowerings_str"),
py::arg("symbolic_shape_inputs") = std::vector<int64_t>(),
py::arg("pre_alloc") = false)
.def(
"run",
[](TensorExprKernel& self, const py::tuple& inputs) {
Stack stack;
            // convert each Python input to a type-inferred IValue
            stack.reserve(inputs.size());
for (auto& obj : inputs) {
stack.push_back(toTypeInferredIValue(obj));
}
auto g_inputs = self.graph()->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
if (stack[i].isTensor()) {
g_inputs[i]->setType(stack[i].type());
}
}
self.run(stack);
return createPyObjectForStack(std::move(stack));
})
.def(
"fallback",
[](TensorExprKernel& self, const py::tuple& inputs) {
Stack stack;
            // convert each Python input to a type-inferred IValue
            stack.reserve(inputs.size());
for (auto& obj : inputs) {
stack.push_back(toTypeInferredIValue(obj));
}
auto g_inputs = self.graph()->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
if (stack[i].isTensor()) {
g_inputs[i]->setType(stack[i].type());
}
}
self.fallback(stack);
return createPyObjectForStack(std::move(stack));
})
.def(
"get_codegen_stmt",
[](TensorExprKernel& self) { return self.getCodeGenStmt(); },
py::return_value_policy::reference)
.def(
"get_code_text",
[](TensorExprKernel& self, const std::string& attr = "") {
return self.getCodeText(attr);
},
py::arg("attr") = "")
.def("recompile", [](TensorExprKernel& self) { self.recompile(); });
py::class_<CodeGen>(te, "CodeGen")
.def(
"call",
[](CodeGen& self, const py::sequence& values) {
std::vector<CodeGen::CallArg> value_ptrs;
value_ptrs.reserve(py::len(values));
for (const auto& value : values) {
if (py::isinstance<py::int_>(value)) {
value_ptrs.emplace_back(value.cast<int64_t>());
} else {
value_ptrs.emplace_back(value.cast<at::Tensor>().data_ptr());
}
}
self.call(value_ptrs);
})
.def(
"call_raw",
[](CodeGen& self, const py::sequence& values) {
std::vector<void*> value_ptrs;
value_ptrs.reserve(py::len(values));
for (const auto& value : values) {
// Tensor.data_ptr() returns an int in python
value_ptrs.emplace_back(
reinterpret_cast<void*>(value.cast<intptr_t>()));
}
self.call_raw(value_ptrs);
})
.def(
"get_code_text",
[](CodeGen& self, const std::string& attr = "") {
return self.getCodeText(attr);
},
py::arg("attr") = "");
py::class_<SimpleIREvaluator, CodeGen>(te, "SimpleIREvaluator"); // NOLINT
#ifdef TORCH_ENABLE_LLVM
py::class_<LLVMCodeGen, CodeGen>(te, "LLVMCodeGen"); // NOLINT
#endif
py::class_<CodeGen::BufferArg>(te, "BufferArg")
.def(py::init<Tensor>())
.def(py::init<const VarHandle&>())
.def(py::init<const BufHandle&>());
py::implicitly_convertible<Tensor, CodeGen::BufferArg>();
py::implicitly_convertible<VarHandle, CodeGen::BufferArg>();
py::implicitly_convertible<BufHandle, CodeGen::BufferArg>();
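  // construct_codegen selects a backend by name ("llvm", "cuda", or
  // "ir_eval") and binds a statement to its buffer arguments. Illustrative
  // Python sketch (stmt assumed to be a Stmt; buf as above, a an at::Tensor):
  //   cg = te.construct_codegen("ir_eval", stmt, [te.BufferArg(buf)])
  //   cg.call([a])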
te.def(
"construct_codegen",
[](const std::string& name,
StmtPtr stmt,
const std::vector<CodeGen::BufferArg>& args) {
CodeGen* cg = nullptr;
if (name == "llvm") {
#ifdef TORCH_ENABLE_LLVM
cg = new LLVMCodeGen(stmt, args);
#else
throw std::runtime_error("PyTorch not compiled with LLVM support!");
#endif
} else if (name == "cuda") {
#ifdef USE_CUDA
cg = new CudaCodeGen(stmt, args);
#else
throw std::runtime_error("PyTorch not compiled with CUDA support!");
#endif
} else if (name == "ir_eval") {
cg = new SimpleIREvaluator(stmt, args);
} else {
throw std::runtime_error(
"construct_codegen() expects 'llvm', 'cuda', or 'ir_eval'");
}
return cg;
});
te.def("annotate_input_shapes", &tensorexpr::annotateInputShapes);
te.def("remove_unused_self_argument", &tensorexpr::removeUnusedSelfArgument);
te.def("make_shapes_symbolic", &tensorexpr::makeShapesSymbolic);
te.def("is_graph_compilable", &tensorexpr::isGraphCompilable);
te.def("fixup_missing_shape_info", &tensorexpr::fixupMissingShapeInfo);
te.def("remove_graph_output", &tensorexpr::removeGraphOutput);
te.def(
"replace_list_output_with_tuple",
&tensorexpr::replaceListOutputWithTuple);
te.def("trim_graph", &tensorexpr::trimGraph);
#ifdef TORCH_ENABLE_LLVM
te.def("set_llvm_target_triple", [](const c10::optional<std::string>& val) {
tensorexpr::LLVMTargetTriple() = val;
});
te.def("set_llvm_target_cpu", [](const c10::optional<std::string>& val) {
tensorexpr::LLVMTargetCPU() = val;
});
te.def("set_llvm_target_attrs", [](const c10::optional<std::string>& val) {
tensorexpr::LLVMTargetAttrs() = val;
});
te.def("set_llvm_aot_workflow", [](bool val) {
tensorexpr::LLVMAOTWorkflow() = val;
});
#endif
}