in torch/csrc/jit/python/script_init.cpp [786:2199]
void initJitScriptBindings(PyObject* module) {
// View the raw PyObject* as a pybind11 module so bindings can be added to it.
auto m = py::handle(module).cast<py::module>();
// Register c10::Capsule as an opaque Python type "Capsule" (no methods bound).
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<c10::Capsule>(m, "Capsule");
auto object_class =
py::class_<Object>(m, "ScriptObject")
.def("_type", [](Module& m) { return m.type(); })
.def(
"_get_method",
[](Object& self, const std::string& name) -> Method {
return self.get_method(name);
},
py::keep_alive<0, 1>())
.def(
"setattr",
[](Object& self, const std::string& name, py::object value) {
if (self.type()->hasConstant(name)) {
TORCH_CHECK(
false,
"Can't set constant '",
name,
"' which has value:",
self.type()->getConstant(name));
}
TypePtr type = self.type()->getAttribute(name);
try {
auto ivalue = toIValue(std::move(value), type);
self.setattr(name, ivalue);
} catch (std::exception& e) {
throw py::cast_error(c10::str(
"Could not cast attribute '",
name,
"' to type ",
type->repr_str(),
": ",
e.what()));
}
})
.def(
"getattr",
[](Object& self, const std::string& name) {
try {
return toPyObject(self.attr(name));
} catch (const ObjectAttributeError& err) {
throw AttributeError("%s", err.what());
}
})
.def(
"__getattr__",
[](Object& self, const std::string& name) -> py::object {
try {
if (name == "__qualname__") {
return py::cast(self.type()->name()->name());
}
if (auto method = self.find_method(name)) {
return py::cast(*method);
}
if (self.has_property(name)) {
auto prop = self.get_property(name);
// wrap the Method into callable PyObject
auto getter_func = py::cast(prop.getter_func);
return getter_func();
}
return toPyObject(self.attr(name));
} catch (const ObjectAttributeError& err) {
throw AttributeError("%s", err.what());
}
})
.def(
"__setattr__",
[](Object& self, const std::string& name, py::object value) {
try {
if (self.has_property(name)) {
auto prop = self.get_property(name);
if (!prop.setter_func.has_value()) {
TORCH_CHECK(false, "can't set attribute");
}
// wrap the Method into callable PyObject
auto setter_func = py::cast(prop.setter_func);
setter_func(value);
return;
}
if (self.type()->hasConstant(name)) {
TORCH_CHECK(
false,
"Can't set constant '",
name,
"' which has value:",
self.type()->getConstant(name));
}
TypePtr type = self.type()->getAttribute(name);
auto ivalue = toIValue(std::move(value), type);
self.setattr(name, ivalue);
} catch (const ObjectAttributeError& err) {
throw AttributeError("%s", err.what());
}
})
.def(
"hasattr",
[](Object& self, const std::string& name) {
return self.hasattr(name);
})
.def(
"_has_method",
[](Object& self, const std::string& name) {
return bool(self.find_method(name));
})
.def(
"_method_names",
[](Object& self) {
return fmap(self.get_methods(), [](const Method& method) {
return method.name();
});
})
.def(
"_properties", [](Object& self) { return self.get_properties(); })
.def("__copy__", &Object::copy)
.def(
"__hash__",
[](const Object& self) {
// Similar to Tensor's `__hash__`, which is `id()`.
return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
})
.def(py::pickle(
[](const Object& self)
-> std::tuple<py::object, std::string> { // __getstate__
if (auto getstate_method = self.find_method("__getstate__")) {
auto object_state = toPyObject((*getstate_method)(Stack{}));
TORCH_INTERNAL_ASSERT(self.type()->name());
return std::make_tuple(
object_state, self.type()->name()->qualifiedName());
}
std::stringstream err;
err << "Tried to serialize object ";
if (auto qualname = self.type()->name()) {
err << qualname->qualifiedName() << " ";
}
err << "which does not have a __getstate__ method defined!";
throw std::runtime_error(err.str());
},
[](const std::tuple<py::object, std::string>& state_tup)
-> Object {
py::object state;
std::string qualname;
std::tie(state, qualname) = state_tup;
auto class_type = getCustomClass(qualname);
TORCH_CHECK(
class_type,
"Tried to deserialize class ",
qualname,
" which is not known to the runtime. "
"If this is a custom C++ class, make "
"sure the appropriate code is linked.");
auto self = Object(c10::ivalue::Object::create(
c10::StrongTypePtr(
std::shared_ptr<torch::jit::CompilationUnit>(),
class_type),
1));
if (auto setstate_method = self.find_method("__setstate__")) {
auto setstate_schema =
setstate_method->function().getSchema();
TORCH_INTERNAL_ASSERT(
setstate_schema.arguments().size() == 2,
"__setstate__ method for class ",
class_type->repr_str(),
" must have exactly 2 arguments!");
auto state_type = setstate_schema.arguments().at(1).type();
(*setstate_method)(Stack{toIValue(state, state_type)});
return self;
}
std::stringstream err;
err << "Tried to deserialize object ";
if (auto qualname = class_type->name()) {
err << qualname->qualifiedName() << " ";
}
err << "which does not have a __setstate__ method defined!";
throw std::runtime_error(err.str());
}));
py::class_<Object::Property>(m, "ScriptObjectProperty")
.def_property_readonly(
"name", [](const Object::Property& self) { return self.name; })
.def_property_readonly(
"getter",
[](const Object::Property& self) { return self.getter_func; })
.def_property_readonly("setter", [](const Object::Property& self) {
return self.setter_func;
});
// Bind each name in magic_method_names on ScriptObject so that it forwards to
// the scripted method of the same name (NotImplementedError if absent).
// __str__ is special-cased: without a user-defined __str__ it falls back to
// the literal "ScriptObject", so any Object/Module can be printed.
using MagicMethodImplType = std::function<py::object(
    const Object& self, py::args args, py::kwargs kwargs)>;
std::unordered_map<std::string, MagicMethodImplType> special_magic_methods{
    {"__str__",
     [](const Object& self, py::args args, py::kwargs kwargs) -> py::object {
       auto user_str = self.find_method("__str__");
       if (user_str) {
         return invokeScriptMethodFromPython(
             *user_str,
             // NOLINTNEXTLINE(performance-move-const-arg)
             std::move(args),
             // NOLINTNEXTLINE(performance-move-const-arg)
             std::move(kwargs));
       }
       return py::str("ScriptObject");
     }}};
for (const char* mm_name : magic_method_names) {
  auto special_it = special_magic_methods.find(mm_name);
  if (special_it != special_magic_methods.end()) {
    object_class.def(mm_name, special_it->second);
    continue;
  }
  object_class.def(
      mm_name,
      [mm_name](const Object& self, py::args args, py::kwargs kwargs) {
        auto scripted = self.find_method(mm_name);
        if (!scripted) {
          throw NotImplementedError();
        }
        return invokeScriptMethodFromPython(
            *scripted,
            // NOLINTNEXTLINE(performance-move-const-arg)
            std::move(args),
            // NOLINTNEXTLINE(performance-move-const-arg)
            std::move(kwargs));
      });
}
// Opaque Python type "DeepCopyMemoTable"; no methods are bound on it here.
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<DeepCopyMemoTable>(m, "DeepCopyMemoTable");
// torch._C._UpgraderEntry: constructible from (int, str, str) and exposing
// bumped_at_version, upgrader_name and old_schema read-only. The constructor
// argument order presumably matches those fields; confirm against
// UpgraderEntry's definition.
py::class_<UpgraderEntry>(m, "_UpgraderEntry")
.def(py::init<int, std::string, std::string>())
.def_property_readonly(
"bumped_at_version",
[](const UpgraderEntry& self) { return self.bumped_at_version; })
.def_property_readonly(
"upgrader_name",
[](const UpgraderEntry& self) { return self.upgrader_name; })
.def_property_readonly("old_schema", [](const UpgraderEntry& self) {
return self.old_schema;
});
// ScriptObject.__deepcopy__(memo): deep-copy the wrapped ivalue::Object with
// pyIValueDeepcopy (sharing Python's `memo` dict) and wrap the copy again.
object_class.def(
    "__deepcopy__",
    [](const Object& self, const py::dict& memo) {
      IValue copied = pyIValueDeepcopy(IValue(self._ivalue()), memo);
      return Object(std::move(copied).toObject());
    });
// Used by torch.package to save ScriptModule objects in unified format.
// Constructed over a PyTorchStreamWriter. "serialize" maps to
// serialize_unified_format; "write_files" defaults code_dir to
// ".data/ts_code/code/". "storage_context" returns a reference with
// reference_internal, so the returned object keeps the serializer alive.
py::class_<ScriptModuleSerializer>(m, "ScriptModuleSerializer")
.def(py::init<caffe2::serialize::PyTorchStreamWriter&>())
.def("serialize", &ScriptModuleSerializer::serialize_unified_format)
.def(
"write_files",
&ScriptModuleSerializer::writeFiles,
py::arg("code_dir") = ".data/ts_code/code/")
.def(
"storage_context",
&ScriptModuleSerializer::storage_context,
pybind11::return_value_policy::reference_internal);
// Used by torch.package to coordinate sharing of storages between eager
// and ScriptModules.
// Held by shared_ptr; exposes has_storage / get_or_add_storage.
py::class_<
SerializationStorageContext,
std::shared_ptr<SerializationStorageContext>>(
m, "SerializationStorageContext")
.def("has_storage", &SerializationStorageContext::hasStorage)
.def("get_or_add_storage", &SerializationStorageContext::getOrAddStorage);
// torch.jit.ScriptModule is a subclass of this C++ object.
// Methods here are prefixed with _ since they should not be
// public.
// Registered with Object (ScriptObject) as its base; the __copy__, __hash__
// and __deepcopy__ bound below take Module and shadow the ScriptObject ones.
py::class_<Module, Object>(m, "ScriptModule")
// Constructor forwarding (str, CompilationUnit, bool) to Module; presumably
// (class name, cu, should_mangle). Confirm against module.h.
.def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
// Save to `filename`, passing `_extra_files` through to Module::save.
.def(
"save",
[](Module& m,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
m.save(filename, _extra_files);
},
py::arg("filename"),
py::arg("_extra_files") = ExtraFilesMap())
// Same as save(), but returns the serialized archive as Python bytes.
.def(
"save_to_buffer",
[](Module& m, const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
std::ostringstream buf;
m.save(buf, _extra_files);
return py::bytes(buf.str());
},
py::arg("_extra_files") = ExtraFilesMap())
// Save in the mobile (lite interpreter) format; debug info is opt-in.
.def(
"_save_for_mobile",
[](Module& m,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
bool _save_mobile_debug_info = false) {
m._save_for_mobile(filename, _extra_files, _save_mobile_debug_info);
},
py::arg("filename"),
py::arg("_extra_files") = ExtraFilesMap(),
py::arg("_save_mobile_debug_info") = false)
// Mobile-format save returned as Python bytes.
.def(
"_save_to_buffer_for_mobile",
[](Module& m,
const ExtraFilesMap& _extra_files = ExtraFilesMap(),
bool _save_mobile_debug_info = false) {
std::ostringstream buf;
m._save_for_mobile(buf, _extra_files, _save_mobile_debug_info);
return py::bytes(buf.str());
},
py::arg("_extra_files") = ExtraFilesMap(),
py::arg("_save_mobile_debug_info") = false)
.def("_set_optimized", &Module::set_optimized)
// dump / dump_to_str: code, attrs and params sections all default to true.
.def(
"dump",
&Module::dump,
py::arg("code") = true,
py::arg("attrs") = true,
py::arg("params") = true)
.def(
"dump_to_str",
&Module::dump_to_str,
py::arg("code") = true,
py::arg("attrs") = true,
py::arg("params") = true)
// New Module object with the same compilation unit and type whose slots
// are copies of this module's slot IValues (a shallow copy: slot values
// are copied, not deep-copied).
.def(
"_replicate_for_data_parallel",
[](Module& module) {
const ModulePtr& obj = module._ivalue();
auto copy = c10::ivalue::Object::create(
c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
obj->slots().size());
for (size_t i = 0; i < obj->slots().size(); ++i) {
copy->setSlot(i, obj->getSlot(i));
}
return Module(std::move(copy));
})
// Debug state of forward()'s executor; throws if there is no forward().
.def(
"get_debug_state",
[](Module& self) {
if (auto m = self.find_method("forward")) {
return m->get_executor().getDebugState();
}
throw std::runtime_error(
"Attempted to call get_debug_state on a Module without a compiled forward()");
})
// Compile `script` into the module's compilation unit under the module's
// type name, resolving names through `rcb`, then run the emit hook.
.def(
"_define",
[](Module& m,
std::shared_ptr<ConcreteModuleType> concreteType,
const std::string& script,
const ResolutionCallback& rcb) {
const auto self = ModuleSelf(std::move(concreteType));
m._ivalue()->compilation_unit()->define(
*m.type()->name(), script, pythonResolver(rcb), &self);
didFinishEmitModule(m);
})
// Register attribute `name` of `type`, converting `value` to that type.
.def(
"_register_attribute",
[](Module& m,
const std::string& name,
const TypePtr& type,
py::handle value) {
m.register_attribute(name, type, toIValue(value, type));
})
// Trace `func` on `input_tuple` (with this module as self) and add the
// resulting graph to the module's type as method `name`.
.def(
"_create_method_from_trace",
[](Module& self,
const std::string& name,
const py::function& func,
const py::tuple& input_tuple,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
const std::vector<std::string>& argument_names) {
// prereq: Module's buffers and parameters are unique
// this was ensured in python before calling this function
auto typed_inputs = toTraceableStack(input_tuple);
std::shared_ptr<Graph> graph =
std::get<0>(tracer::createGraphByTracing(
func,
typed_inputs,
var_name_lookup_fn,
strict,
force_outplace,
&self,
argument_names));
const auto method_name = QualifiedName(*self.type()->name(), name);
auto fn = self._ivalue()->compilation_unit()->create_function(
method_name, graph);
self.type()->addMethod(fn);
didFinishEmitModule(self);
},
py::arg("name"),
py::arg("func"),
py::arg("input_tuple"),
py::arg("var_name_lookup_fn"),
py::arg("strict"),
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>())
// Forward hooks / pre-hooks of the module type, as ScriptFunctions.
.def(
"_get_forward_hooks",
[](const Module& m) {
std::vector<StrongFunctionPtr> funcs;
for (auto& hook : m.type()->getForwardHooks()) {
funcs.emplace_back(
StrongFunctionPtr(m.type()->compilation_unit(), hook));
}
return funcs;
})
.def(
"_get_forward_pre_hooks",
[](const Module& m) {
std::vector<StrongFunctionPtr> funcs;
for (auto& pre_hook : m.type()->getForwardPreHooks()) {
funcs.emplace_back(
StrongFunctionPtr(m.type()->compilation_unit(), pre_hook));
}
return funcs;
})
// TorchScript source of the module's type, as printed by PythonPrint.
.def_property_readonly(
"code",
[](Module& self) {
std::vector<at::IValue> constants;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
return pp.str();
})
// (code, constants) where constants maps "c0", "c1", ... to the constants
// collected while printing, in collection order.
.def_property_readonly(
"code_with_constants",
[](Module& self) {
std::vector<at::IValue> constants;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
std::map<std::string, at::IValue> consts;
int i = 0;
for (auto const& constant : constants) {
consts["c" + std::to_string(i)] = constant;
i += 1;
}
return std::make_tuple(pp.str(), consts);
})
.def("apply", &Module::apply)
.def("__copy__", &Module::copy)
// Identity hash / equality based on the underlying ivalue::Object address.
.def(
"__hash__",
[](const Module& self) {
// Similar to Tensor's `__hash__`, which is `id()`.
return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
})
.def(
"__eq__",
[](const Module& self, const py::object& other) {
// TODO: call UDF if it exists
if (!py::isinstance<Module>(other)) {
return false;
}
return self._ivalue().get() ==
py::cast<Module>(other)._ivalue().get();
})
.def(
"__deepcopy__",
[](const Module& self, const py::dict& memo) {
return Module(
pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
})
.def("children", &Module::children)
.def_property_readonly("qualified_name", [](const Module& self) {
return self.type()->name()->qualifiedName();
});
py::class_<mobile::Module>(m, "LiteScriptModule")
.def(py::init<
c10::intrusive_ptr<c10::ivalue::Object>,
std::shared_ptr<mobile::CompilationUnit>>())
.def(
"find_method",
[](mobile::Module& m, const std::string& method_name) {
auto method = m.find_method(method_name);
return method != c10::nullopt;
},
py::arg("method_name"))
.def(
"run_method",
[](mobile::Module& m,
const std::string& method_name,
const py::tuple& input_tuple) {
Stack stack;
for (auto& input : input_tuple) {
stack.push_back(toTypeInferredIValue(input));
}
return m.get_method(method_name)(stack);
},
py::arg("method_name"),
py::arg("input_tuple"))
.def(
"forward",
[](mobile::Module& m, const py::tuple& input_tuple) {
Stack stack;
for (auto& input : input_tuple) {
stack.push_back(toTypeInferredIValue(input));
}
return m.get_method("forward")(stack);
},
py::arg("input_tuple"));
// Dict-like views over a module's parameters, buffers and submodules.
slot_dict_impl<detail::ParameterPolicy>::bind(m, "ParameterDict");
slot_dict_impl<detail::BufferPolicy>::bind(m, "BufferDict");
slot_dict_impl<detail::ModulePolicy>::bind(m, "ModuleDict");
// ErrorReport: constructed from a SourceRange; `what()` gives the message and
// the static `call_stack` binds ErrorReport::current_call_stack.
py::class_<ErrorReport, std::shared_ptr<ErrorReport>>(m, "ErrorReport")
.def(py::init<SourceRange>())
.def("what", &ErrorReport::what)
.def_static("call_stack", ErrorReport::current_call_stack);
// torch._C.CompilationUnit, held by shared_ptr. Returned functions are
// StrongFunctionPtrs that keep the CU alive.
py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
m, "CompilationUnit")
// Create an empty CU. A non-empty `lang` (source text) is defined right
// away with a null resolution callback and the given `_frames_up`.
.def(
py::init([](const std::string& lang, const uint32_t _frames_up) {
auto cu = std::make_shared<CompilationUnit>();
if (lang.size() > 0) {
pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up);
}
return cu;
}),
py::arg("lang") = "",
py::arg("_frames_up") = 0)
// Returns None when no function with that qualified name exists.
.def(
"find_function",
[](std::shared_ptr<CompilationUnit> self, const std::string& name) {
auto fn = self->find_function(QualifiedName(name));
if (fn) {
return c10::optional<StrongFunctionPtr>(
StrongFunctionPtr(std::move(self), fn));
} else {
return c10::optional<StrongFunctionPtr>(c10::nullopt);
}
})
// Attribute access resolves to a function; raises AttributeError otherwise.
.def(
"__getattr__",
[](std::shared_ptr<CompilationUnit> self, const std::string& name) {
auto fn = self->find_function(QualifiedName(name));
if (fn) {
return StrongFunctionPtr(std::move(self), fn);
} else {
throw AttributeError(
"'CompilationUnit' has no attribute '%s'", name.c_str());
}
})
// All non-null functions in the CU.
.def(
"get_functions",
[](const std::shared_ptr<CompilationUnit>& self) {
auto raw_functions = self->get_functions();
std::vector<StrongFunctionPtr> functions;
functions.reserve(raw_functions.size());
for (auto fn : raw_functions) {
if (fn) {
functions.emplace_back(self, fn);
}
}
return functions;
})
.def("set_optimized", &CompilationUnit::set_optimized)
.def(
"define",
pyCompilationUnitDefine,
py::arg("src"),
py::arg("rcb") = nullptr,
py::arg("_frames_up") = 0)
// Create a function from `graph`; name mangling is off by default.
.def(
"create_function",
[](std::shared_ptr<CompilationUnit>& self,
const std::string& qualified_name,
std::shared_ptr<Graph> graph,
bool should_mangle) {
Function* fn = self->create_function(
qualified_name, std::move(graph), should_mangle);
return StrongFunctionPtr(std::move(self), fn);
},
py::arg("qualified_name"),
py::arg("graph"),
py::arg("should_mangle") = false)
.def(
"get_interface",
[](const std::shared_ptr<CompilationUnit>& self,
const std::string& name) { return self->get_interface(name); })
.def(
"get_class",
[](const std::shared_ptr<CompilationUnit>& self,
const std::string& name) { return self->get_class(name); });
// torch._C.ScriptFunction wraps a StrongFunctionPtr. py::dynamic_attr()
// allows arbitrary Python attributes to be set on instances.
py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
// Call the function. args[0] is the ScriptFunction itself; the remaining
// positional args and kwargs are forwarded to the script function.
.def(
"__call__",
[](py::args args, py::kwargs kwargs) {
HANDLE_TH_ERRORS
// see: [pybind11 varargs]
auto strongPtr = py::cast<StrongFunctionPtr>(args[0]);
Function& callee = *strongPtr.function_;
py::object result = invokeScriptFunctionFromPython(
callee,
// NOLINTNEXTLINE(performance-move-const-arg)
tuple_slice(std::move(args), 1),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(kwargs));
return result;
END_HANDLE_TH_ERRORS_PYBIND
})
// Save the function by wrapping it in a placeholder module.
.def(
"save",
[](const StrongFunctionPtr& self,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
Module module("__torch__.PlaceholderModule");
// [issue 27343]
// Modules have 'training' attributes by default, but due to
// https://github.com/pytorch/pytorch/issues/27343, functions end
// up having a training attribute when they are loaded. This adds
// a fake 'training' attribute that shouldn't be used, but prevents
// jitter on saving and loading. Once that issue is fixed this can
// be deleted.
module.register_attribute("training", BoolType::get(), true);
addFunctionToModule(module, self);
module.save(filename, _extra_files);
},
py::arg("filename"),
py::arg("_extra_files") = ExtraFilesMap())
// Same as save(), but returns the archive as Python bytes.
.def(
"save_to_buffer",
[](const StrongFunctionPtr& self,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
std::ostringstream buf;
Module module("__torch__.PlaceholderModule");
// see [issue 27343]
module.register_attribute("training", BoolType::get(), true);
addFunctionToModule(module, self);
module.save(buf, _extra_files);
return py::bytes(buf.str());
},
py::arg("_extra_files") = ExtraFilesMap())
.def_property_readonly(
"graph",
[](const StrongFunctionPtr& self) {
return toGraphFunction(*self.function_).graph();
})
// Copy of the graph with calls inlined; the original graph is untouched.
.def_property_readonly(
"inlined_graph",
[](const StrongFunctionPtr& self) {
auto g = toGraphFunction(*self.function_).graph()->copy();
Inline(*g);
return g;
})
.def_property_readonly(
"schema",
[](const StrongFunctionPtr& self) {
return self.function_->getSchema();
})
// TorchScript source of the function, as printed by PythonPrint.
.def_property_readonly(
"code",
[](const StrongFunctionPtr& self) {
std::vector<at::IValue> constants;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printFunction(*self.function_);
return pp.str();
})
.def(
"get_debug_state",
[](const StrongFunctionPtr& self) {
return toGraphFunction(*self.function_)
.get_executor()
.getDebugState();
})
.def(
"_debug_flush_compilation_cache",
[](const StrongFunctionPtr& self) {
toGraphFunction(*self.function_)
.get_executor()
.debugFlushCompilationCache();
})
.def_property_readonly(
"name",
[](const StrongFunctionPtr& self) { return self.function_->name(); })
.def_property_readonly(
"qualified_name",
[](const StrongFunctionPtr& self) {
return self.function_->qualname().qualifiedName();
})
.def_property_readonly("__doc__", [](const StrongFunctionPtr& self) {
return self.function_->doc_string();
});
// torch._C.ScriptMethod wraps a Method (a function bound to an owner object).
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
// Call the method. args[0] is the ScriptMethod itself; the remaining
// positional args and kwargs are forwarded.
.def(
"__call__",
[](py::args args, py::kwargs kwargs) {
// see: [pybind11 varargs]
HANDLE_TH_ERRORS
Method& method = py::cast<Method&>(args[0]);
return invokeScriptMethodFromPython(
method,
// NOLINTNEXTLINE(performance-move-const-arg)
tuple_slice(std::move(args), 1),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(kwargs));
END_HANDLE_TH_ERRORS_PYBIND
})
.def_property_readonly("graph", &Method::graph)
// Copy of the graph with calls inlined; the original graph is untouched.
.def_property_readonly(
"inlined_graph",
[](const Method& self) {
auto g = toGraphFunction(self.function()).graph()->copy();
Inline(*g);
return g;
})
.def_property_readonly(
"schema", [](Method& m) { return m.function().getSchema(); })
.def_property_readonly("name", &Method::name)
.def_property_readonly(
"code",
[](Method& self) {
std::vector<at::IValue> constants;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
return pp.str();
})
.def(
"_debug_flush_compilation_cache",
[](Method& self) {
return self.get_executor().debugFlushCompilationCache();
})
// (code, constants) where constants maps "c0", "c1", ... to the constants
// collected while printing, in collection order.
.def_property_readonly(
"code_with_constants",
[](Method& self) {
std::vector<at::IValue> constants;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
std::map<std::string, at::IValue> consts;
int i = 0;
for (auto const& constant : constants) {
consts["c" + std::to_string(i)] = constant;
i += 1;
}
return std::make_tuple(pp.str(), consts);
})
.def_property_readonly("owner", &Method::owner);
// Compile a single function definition. Internal-asserts that the last
// component of `qualname` matches the Def's own name.
m.def(
"_jit_script_compile",
[](const std::string& qualname,
const Def& def,
const ResolutionCallback& rcb,
const FunctionDefaults& defaults) {
C10_LOG_API_USAGE_ONCE("torch.script.compile");
const auto name = c10::QualifiedName(qualname);
TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
return script_compile_function(name, def, defaults, rcb);
});
// Compile an @overload: the overload declaration paired with the shared
// implementation body.
m.def(
"_jit_script_compile_overload",
[](const std::string& qualname,
const Decl& overload_decl,
const Def& implementation_def,
const ResolutionCallback& rcb,
const FunctionDefaults& implementation_defaults,
const py::object& signature) {
const auto name = c10::QualifiedName(qualname);
return script_compile_overloaded_function(
name,
overload_decl,
implementation_def,
rcb,
implementation_defaults,
signature);
});
// Check the overload declaration against the implementation's declaration,
// then return the implementation Def with that declaration and `new_name`.
m.def(
"_replace_overloaded_method_decl",
[](const Decl& overload_decl,
const Def& implementation_def,
const std::string& new_name) {
checkOverloadDecl(overload_decl, implementation_def.decl());
return implementation_def.withDecl(overload_decl).withName(new_name);
});
// Trace a free function (no self) and register the graph in the global
// Python CU, mangling the name; the emit hook runs afterwards.
m.def(
"_create_function_from_trace",
[](const std::string& qualname,
const py::function& func,
const py::tuple& input_tuple,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
const std::vector<std::string>& argument_names) {
auto typed_inputs = toTraceableStack(input_tuple);
std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
func,
typed_inputs,
var_name_lookup_fn,
strict,
force_outplace,
/*self=*/nullptr,
argument_names));
auto cu = get_python_cu();
auto name = c10::QualifiedName(qualname);
auto result = cu->create_function(
std::move(name), std::move(graph), /*shouldMangle=*/true);
StrongFunctionPtr ret(std::move(cu), result);
didFinishEmitFunction(ret);
return ret;
},
py::arg("name"),
py::arg("func"),
py::arg("input_tuple"),
py::arg("var_name_lookup_fn"),
py::arg("strict"),
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>());
m.def("_generate_upgraders_bytecode", &generate_bytecode_list);
// _jit_script_class_compile(qualified_name, class_def, defaults, rcb):
// compile a Python class into a TorchScript ClassType, register it in the
// global Python compilation unit, and return the type. Inheritance is
// rejected, and the class body may contain only method definitions.
m.def(
    "_jit_script_class_compile",
    [](const std::string& qualifiedName,
       const ClassDef& classDef,
       const ClassMethodDefaults& defaults,
       const ResolutionCallback& rcb) {
      C10_LOG_API_USAGE_ONCE("torch.script.class");
      if (classDef.superclass().present()) {
        throw ErrorReport(classDef.range())
            << "Torchscript does not support class inheritance.";
      }
      auto cu = get_python_cu();
      auto classname = c10::QualifiedName(qualifiedName);
      // Mangle the name if a type with it is already registered.
      if (cu->get_type(classname) != nullptr) {
        classname = cu->mangle(classname);
      }
      auto classType = ClassType::create(
          classname,
          cu,
          /* is_module = */ false,
          /* doc_string = */ "",
          getUnresolvedClassAttributes(classDef));
      cu->register_type(classType);
      std::vector<ResolverPtr> methodRcbs, propRcbs;
      std::vector<Def> methodDefs;
      std::vector<Property> props;
      for (const auto& def : classDef.body()) {
        if (def.kind() != TK_DEF) {
          throw ErrorReport(def.range())
              << "Currently class bodies can only contain method "
                 "definitions. File an issue on Github if you want "
                 "something else!";
        }
        methodDefs.emplace_back(Def(def));
        methodRcbs.push_back(
            pythonResolver(rcb, classDef.name().name(), classType));
      }
      // Gather definitions for property getters and setters as well as
      // corresponding resolution callbacks.
      if (classDef.properties().present()) {
        for (const auto& prop : classDef.properties().get()) {
          props.emplace_back(prop);
          propRcbs.push_back(
              pythonResolver(rcb, classDef.name().name(), classType));
        }
      }
      const auto self = SimpleSelf(classType);
      cu->define(classname, props, propRcbs, methodDefs, methodRcbs, &self);
      // Stitch in default arguments for methods. Properties don't need to be
      // considered since there is no way to invoke setters without passing in
      // a value.
      // A range-for advances on every iteration, so a method that is absent
      // from `defaults` is simply skipped instead of being revisited forever.
      for (const auto& method_def : methodDefs) {
        const auto def_name = method_def.name().name();
        // If the method is not in the defaults map, assume there are
        // no default arguments for it.
        const auto default_it = defaults.find(def_name);
        if (default_it == defaults.end()) {
          continue;
        }
        const auto method_name = QualifiedName(classname, def_name);
        auto& method = cu->get_function(method_name);
        method.setSchema(getSchemaWithNameAndDefaults(
            method_def.range(),
            method.getSchema(),
            at::nullopt,
            default_it->second));
      }
      return classType;
    });
// _jit_script_interface_compile(qualified_name, class_def, rcb, is_module):
// define an interface type in the global Python compilation unit and return
// its qualified name, which is mangled if that name was already taken.
m.def(
    "_jit_script_interface_compile",
    [](const std::string& qualifiedName,
       const ClassDef& classDef,
       const ResolutionCallback& rcb,
       bool is_module) {
      auto cu = get_python_cu();
      auto className = c10::QualifiedName(qualifiedName);
      if (cu->get_type(className) != nullptr) {
        className = cu->mangle(className);
      }
      // Define on the same CU that was used for the collision check above,
      // instead of fetching the global CU a second time.
      cu->define_interface(
          className, classDef, pythonResolver(rcb), is_module);
      return className.qualifiedName();
    });
// ErrorReport::CallStack, constructible from (str, SourceRange); it accepts
// dynamic Python attributes.
py::class_<torch::jit::ErrorReport::CallStack>(
m, "CallStack", py::dynamic_attr())
.def(py::init<const std::string&, const SourceRange&>());
// Parse `src` as a single function definition (parsed with is_method=true).
m.def("_parse_source_def", [](const std::string& src) {
Parser p(std::make_shared<Source>(src));
return Def(p.parseFunction(/*is_method=*/true));
});
// Parse a type comment into a Decl.
m.def("parse_type_comment", [](const std::string& comment) {
Parser p(std::make_shared<Source>(comment));
return Decl(p.parseTypeComment());
});
// Operator upgrader / version-map utilities; the _test_only_* entries are
// test hooks.
m.def("_is_upgraders_enabled", &is_upgraders_enabled);
m.def("_get_upgraders_map_size", &get_upgraders_map_size);
m.def("_dump_upgraders_map", &dump_upgraders_map);
m.def("_test_only_populate_upgraders", &test_only_populate_upgraders);
m.def("_test_only_remove_upgraders", &test_only_remove_upgraders);
m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
m.def("_get_operator_version_map", &get_operator_version_map);
m.def("_test_only_add_entry_to_op_version_map", &test_only_add_entry);
m.def("_test_only_remove_entry_to_op_version_map", &test_only_remove_entry);
m.def(
"import_ir_module",
[](std::shared_ptr<CompilationUnit> cu,
const std::string& filename,
py::object map_location,
const py::dict& extra_files) {
c10::optional<at::Device> optional_device;
if (!map_location.is(py::none())) {
AT_ASSERT(THPDevice_Check(map_location.ptr()));
optional_device =
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
}
ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
auto ret = import_ir_module(
std::move(cu), filename, optional_device, extra_files_map);
extra_files_to_python(extra_files_map, extra_files);
return ret;
});
m.def(
"_import_ir_module_from_package",
[](std::shared_ptr<CompilationUnit> cu,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
std::shared_ptr<torch::jit::DeserializationStorageContext>
storage_context,
py::object map_location,
std::string ts_id) {
c10::optional<at::Device> optional_device;
if (!map_location.is(py::none())) {
AT_ASSERT(THPDevice_Check(map_location.ptr()));
optional_device =
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
}
return import_ir_module(
std::move(cu),
std::move(reader),
std::move(storage_context),
optional_device,
std::move(ts_id));
});
m.def(
"import_ir_module_from_buffer",
[](std::shared_ptr<CompilationUnit> cu,
const std::string& buffer,
py::object map_location,
const py::dict& extra_files) {
std::istringstream in(buffer);
c10::optional<at::Device> optional_device;
if (!map_location.is(py::none())) {
AT_ASSERT(THPDevice_Check(map_location.ptr()));
optional_device =
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
}
ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
auto ret = import_ir_module(
std::move(cu), in, optional_device, extra_files_map);
extra_files_to_python(extra_files_map, extra_files);
return ret;
});
m.def(
"_load_for_lite_interpreter",
[](const std::string& filename, py::object map_location) {
c10::optional<at::Device> optional_device;
if (!map_location.is(py::none())) {
AT_ASSERT(THPDevice_Check(map_location.ptr()));
optional_device =
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
}
return _load_for_mobile(filename, optional_device);
});
m.def(
"_load_for_lite_interpreter_from_buffer",
[](const std::string& buffer, py::object map_location) {
std::istringstream in(buffer);
c10::optional<at::Device> optional_device;
if (!map_location.is(py::none())) {
AT_ASSERT(THPDevice_Check(map_location.ptr()));
optional_device =
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
}
return _load_for_mobile(in, optional_device);
});
// Backport a mobile model to bytecode `version`. The file/buffer inputs and
// outputs vary per binding. The *_to_buffer variants return the backported
// archive as bytes, or empty bytes when _backport_for_mobile reports failure.
m.def(
    "_backport_for_mobile",
    [](const std::string& filename_input,
       const std::string& filename_output,
       const int64_t version) {
      return _backport_for_mobile(filename_input, filename_output, version);
    });
m.def(
    "_backport_for_mobile_from_buffer",
    [](const std::string& buffer_input,
       const std::string& filename_output,
       const int64_t version) {
      std::istringstream source_stream(buffer_input);
      return _backport_for_mobile(source_stream, filename_output, version);
    });
m.def(
    "_backport_for_mobile_to_buffer",
    [](const std::string& filename_input, const int64_t version) {
      std::ostringstream sink_stream;
      if (!_backport_for_mobile(filename_input, sink_stream, version)) {
        return py::bytes("");
      }
      return py::bytes(sink_stream.str());
    });
m.def(
    "_backport_for_mobile_from_buffer_to_buffer",
    [](const std::string& buffer_input, const int64_t version) {
      std::istringstream source_stream(buffer_input);
      std::ostringstream sink_stream;
      if (!_backport_for_mobile(source_stream, sink_stream, version)) {
        return py::bytes("");
      }
      return py::bytes(sink_stream.str());
    });
// Model inspection helpers. Each has a file-path variant and a *_from_buffer
// variant that wraps the bytes in an istringstream.
m.def("_get_model_bytecode_version", [](const std::string& filename) {
return _get_model_bytecode_version(filename);
});
m.def(
"_get_model_bytecode_version_from_buffer", [](const std::string& buffer) {
std::istringstream in(buffer);
return _get_model_bytecode_version(in);
});
m.def("_get_mobile_model_contained_types", [](const std::string& filename) {
return _get_mobile_model_contained_types(filename);
});
m.def(
"_get_mobile_model_contained_types_from_buffer",
[](const std::string& buffer) {
std::istringstream in(buffer);
return _get_mobile_model_contained_types(in);
});
// Per-operator info; only num_schema_args is exposed (read-only).
py::class_<OperatorInfo>(m, "OperatorInfo")
.def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
m.def("_get_model_ops_and_info", [](const std::string& filename) {
return _get_model_ops_and_info(filename);
});
m.def("_get_model_ops_and_info_from_buffer", [](const std::string& buffer) {
std::istringstream in(buffer);
return _get_model_ops_and_info(in);
});
// Operator names used by a mobile module, converted via debugMakeSet.
m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
});
m.def("_jit_set_emit_hooks", setEmitHooks);
m.def("_jit_get_emit_hooks", getEmitHooks);
m.def("_jit_clear_class_registry", []() {
get_python_cu()->_clear_python_cu();
});
m.def(
"_debug_set_autodiff_subgraph_inlining",
debugSetAutodiffSubgraphInlining);
m.def("_debug_set_fusion_group_inlining", debugSetFusionGroupInlining);
m.def("_debug_get_fusion_group_inlining", getFusionGroupInlining);
m.def("_propagate_shapes", _propagate_shapes);
m.def(
"_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes);
m.def(
"_last_executed_optimized_graph",
[]() { return lastExecutedOptimizedGraph(); },
"Retrieve the optimized graph that was run the last time the graph executor ran on this thread");
// Wraps an existing Graph into a Function owned by a fresh, private
// CompilationUnit. The returned StrongFunctionPtr keeps that unit alive.
m.def(
    "_create_function_from_graph",
    [](const std::string& qualname, std::shared_ptr<Graph> graph) {
      // TODO this should go in the global Python CU
      auto unit = std::make_shared<CompilationUnit>();
      auto* fn = unit->create_function(
          c10::QualifiedName(qualname), std::move(graph));
      return StrongFunctionPtr(std::move(unit), fn);
    });
m.def("_ivalue_tags_match", ivalue_tags_match);
// Round-trips a Python object through an IValue so callers can observe the
// reference-count effect of the conversions:
//  * toIValue takes a reference (the IValue holds it via py::object);
//  * toPyObject borrows that reference and increments it again;
//  * the IValue is released when the lambda returns, dropping its reference.
// Net result: the returned object has the original refcount + 1.
m.def("_ivalue_debug_python_object", [](py::object py_obj) {
  const IValue held = toIValue(std::move(py_obj), PyObjectType::get());
  return toPyObject(held);
});
m.def("_jit_debug_module_iterators", _jit_debug_module_iterators);
py::class_<testing::FileCheck>(m, "FileCheck")
.def(py::init<>())
.def("check", &testing::FileCheck::check)
.def("check_not", &testing::FileCheck::check_not)
.def("check_same", &testing::FileCheck::check_same)
.def("check_next", &testing::FileCheck::check_next)
.def("check_count", &testing::FileCheck::check_count)
.def("check_dag", &testing::FileCheck::check_dag)
.def(
"check_source_highlighted",
&testing::FileCheck::check_source_highlighted)
.def(
"check_count",
[](testing::FileCheck& f,
const std::string& str,
size_t count,
bool exactly) { return f.check_count(str, count, exactly); },
"Check Count",
py::arg("str"),
py::arg("count"),
py::arg("exactly") = false)
.def(
"run",
[](testing::FileCheck& f, const std::string& str) {
return f.run(str);
})
.def(
"run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); })
.def(
"run",
[](testing::FileCheck& f,
const std::string& input,
const std::string& output) { return f.run(input, output); },
"Run",
py::arg("checks_file"),
py::arg("test_file"))
.def(
"run",
[](testing::FileCheck& f, const std::string& input, const Graph& g) {
return f.run(input, g);
},
"Run",
py::arg("checks_file"),
py::arg("graph"));
// The logger object is owned on the C++ side, so Python receives a plain
// (non-owning) reference to it.
m.def(
    "_logging_set_logger",
    [](logging::LoggerBase* logger) { return logging::setLogger(logger); },
    py::return_value_policy::reference);
m.def("_set_graph_executor_optimize", [](bool enabled) {
  setGraphExecutorOptimize(enabled);
});
m.def("_get_graph_executor_optimize", &torch::jit::getGraphExecutorOptimize);
m.def(
    "_enable_mobile_interface_call_export",
    &torch::jit::enableMobileInterfaceCallExport);
// Factories that instantiate a Module / Object of a given class type inside
// the global Python compilation unit.
m.def("_create_module_with_type", [](const ClassTypePtr& type) {
  return Module(get_python_cu(), type);
});
m.def("_create_object_with_type", [](const ClassTypePtr& type) {
  return Object(get_python_cu(), type);
});
m.def("_export_opnames", [](Module& scripted) {
  return debugMakeList(torch::jit::export_opnames(scripted));
});
// Python binding for ConcreteModuleTypeBuilder, held by shared_ptr. Most
// methods forward directly to the C++ builder; the lambdas adapt argument
// passing or fix the IterableModuleKind value.
py::class_<
    ConcreteModuleTypeBuilder,
    std::shared_ptr<ConcreteModuleTypeBuilder>>(m, "ConcreteModuleTypeBuilder")
    .def(py::init<py::object>())
    .def(
        "add_constant",
        [](ConcreteModuleTypeBuilder& builder,
           std::string constantName,
           py::object constantValue) {
          builder.addConstant(
              std::move(constantName), std::move(constantValue));
        })
    .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
    .def(
        "add_function_attribute",
        &ConcreteModuleTypeBuilder::addFunctionAttribute)
    .def(
        "add_builtin_function",
        &ConcreteModuleTypeBuilder::addBuiltinFunction)
    .def("add_forward_hook", &ConcreteModuleTypeBuilder::addForwardHook)
    .def(
        "add_forward_pre_hook",
        &ConcreteModuleTypeBuilder::addForwardPreHook)
    .def("add_module", &ConcreteModuleTypeBuilder::addModule)
    .def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
    .def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
    .def(
        "add_failed_attribute",
        &ConcreteModuleTypeBuilder::addFailedAttribute)
    .def(
        "add_ignored_attribute",
        &ConcreteModuleTypeBuilder::addIgnoredAttribute)
    // Bulk form of add_ignored_attribute: registers each name in order.
    .def(
        "add_ignored_attributes",
        [](ConcreteModuleTypeBuilder& builder,
           const std::vector<std::string>& attributeNames) {
          for (const std::string& attributeName : attributeNames) {
            builder.addIgnoredAttribute(attributeName);
          }
        })
    .def(
        "set_module_dict",
        [](ConcreteModuleTypeBuilder& builder) {
          builder.setIterableModuleKind(IterableModuleKind::DICT);
        })
    .def("build", &ConcreteModuleTypeBuilder::build)
    .def(
        "equals",
        [](const ConcreteModuleTypeBuilder& lhs,
           const ConcreteModuleTypeBuilder& rhs) { return lhs.equals(rhs); })
    .def("set_module_list", [](ConcreteModuleTypeBuilder& builder) {
      builder.setIterableModuleKind(IterableModuleKind::LIST);
    });
// Python binding for ConcreteModuleType (held by shared_ptr), including the
// two entry points that compile a module's methods/properties and its
// forward hooks into the class type's compilation unit.
py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
    m, "ConcreteModuleType")
    .def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
    .def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
    .def_static("from_jit_type", &ConcreteModuleType::fromJitType)
    .def("get_constants", &ConcreteModuleType::getConstantsPy)
    .def("get_attributes", &ConcreteModuleType::getAttributesPy)
    .def("get_modules", &ConcreteModuleType::getModulesPy)
    .def("dump", &ConcreteModuleType::dump)
    .def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute)
    .def(
        "equals",
        [](const ConcreteModuleType& self, const ConcreteModuleType& other) {
          return self.equals(other);
        })
    .def(
        "equals",
        [](const ConcreteModuleType& self,
           const ConcreteModuleTypeBuilder& other) {
          return self.equals(other);
        })
    // Compiles `methodDefs` and `properties` into the compilation unit of
    // `concreteType`'s class type, then attaches each method's Python
    // default arguments to its schema.
    //   properties / propertyRcbs : parallel vectors, one callback per property
    //   methodDefs / methodRcbs / defaults : parallel vectors, one entry each
    .def(
        "_create_methods_and_properties",
        [](std::shared_ptr<ConcreteModuleType> concreteType,
           const std::vector<Property>& properties,
           const std::vector<ResolutionCallback>& propertyRcbs,
           const std::vector<Def>& methodDefs,
           const std::vector<ResolutionCallback>& methodRcbs,
           const std::vector<FunctionDefaults>& defaults) {
          TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
          // `defaults` is indexed in lockstep with `methodDefs` below; check
          // it up front so a short vector can neither be read past its end
          // nor fail after the compilation unit has already been mutated.
          TORCH_INTERNAL_ASSERT(methodDefs.size() == defaults.size());
          TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());

          std::vector<ResolverPtr> methodResolvers;
          methodResolvers.reserve(methodRcbs.size());
          for (const auto& callback : methodRcbs) {
            methodResolvers.push_back(pythonResolver(callback));
          }
          std::vector<ResolverPtr> propertyResolvers;
          propertyResolvers.reserve(propertyRcbs.size());
          for (const auto& callback : propertyRcbs) {
            propertyResolvers.push_back(pythonResolver(callback));
          }

          // Held by value: both are read after `concreteType` is moved from.
          const auto selfType =
              concreteType->getJitType()->expect<ClassType>();
          const auto prefix = selfType->name().value();
          const auto self = ModuleSelf(std::move(concreteType));
          auto cu = selfType->compilation_unit();
          cu->define(
              prefix,
              properties,
              propertyResolvers,
              methodDefs,
              methodResolvers,
              &self);

          // Stitch in default arguments for each Def.
          for (size_t i = 0; i < methodDefs.size(); ++i) {
            const Def& methodDef = methodDefs[i];
            const auto method_name =
                QualifiedName(prefix, methodDef.name().name());
            auto& method = cu->get_function(method_name);
            method.setSchema(getSchemaWithNameAndDefaults(
                methodDef.range(),
                method.getSchema(),
                at::nullopt,
                defaults[i]));
          }
        })
    // Compiles forward hooks and forward pre-hooks into the compilation unit
    // of `concreteType`'s class type. Each Def has a matching callback.
    .def(
        "_create_hooks",
        [](std::shared_ptr<ConcreteModuleType> concreteType,
           const std::vector<Def>& hookDefs,
           const std::vector<ResolutionCallback>& hookRcbs,
           const std::vector<Def>& preHookDefs,
           const std::vector<ResolutionCallback>& preHookRcbs) {
          TORCH_INTERNAL_ASSERT(hookDefs.size() == hookRcbs.size());
          TORCH_INTERNAL_ASSERT(preHookDefs.size() == preHookRcbs.size());

          std::vector<ResolverPtr> hookResolvers;
          hookResolvers.reserve(hookRcbs.size());
          for (const auto& callback : hookRcbs) {
            hookResolvers.push_back(pythonResolver(callback));
          }
          std::vector<ResolverPtr> preHookResolvers;
          preHookResolvers.reserve(preHookRcbs.size());
          for (const auto& callback : preHookRcbs) {
            preHookResolvers.push_back(pythonResolver(callback));
          }

          // Held by value: both are read after `concreteType` is moved from.
          const auto selfType =
              concreteType->getJitType()->expect<ClassType>();
          const auto prefix = selfType->name().value();
          const auto self = ModuleSelf(std::move(concreteType));
          auto cu = selfType->compilation_unit();
          cu->define_hooks(
              prefix,
              hookDefs,
              hookResolvers,
              preHookDefs,
              preHookResolvers,
              &self);
        });
// Type-resolution helpers. Each builds a Python resolver from the supplied
// resolution callback and asks it to resolve either a type name or a Python
// object into a JIT type.
m.def(
    "_resolve_type",
    [](const std::string& name,
       const SourceRange& range,
       const ResolutionCallback& rcb) {
      const auto resolver = pythonResolver(rcb);
      return resolver->resolveType(name, range);
    });
m.def(
    "_resolve_type_from_object",
    [](const py::object& obj,
       const SourceRange& range,
       const ResolutionCallback& rcb) {
      const auto resolver = pythonResolver(rcb);
      return resolver->resolveTypeFromObject(obj, range);
    });
m.def("_run_emit_module_hook", [](const Module& emitted) {
  didFinishEmitModule(emitted);
});
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
m, "LoggerBase");
py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
.value("SUM", logging::LockingLogger::AggregationType::SUM)
.value("AVG", logging::LockingLogger::AggregationType::AVG)
.export_values();
py::class_<
logging::LockingLogger,
logging::LoggerBase,
std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
.def(py::init<>())
.def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
.def("get_counter_val", &logging::LockingLogger::getCounterValue);
py::class_<
logging::NoopLogger,
logging::LoggerBase,
std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
.def(py::init<>());
m.def(
    "_check_onnx_proto",
    [](const std::string& serialized) { check_onnx_proto(serialized); },
    py::arg("proto_string"));
// True when `candidate` is an instance of the bound ScriptObject class.
m.def("_jit_is_script_object", [](const py::object& candidate) {
  return py::isinstance<Object>(candidate);
});
// Delegate to the dict- and list-specific binding initializers, registering
// onto the same Python module.
initScriptDictBindings(module);
initScriptListBindings(module);
}