in torch_xla/csrc/init_python_bindings.cpp [681:747]
// Registers the `profiler` submodule on `m`, exposing:
//   - ProfilerServer / start_server(port): constructs an
//     xla::profiler::ProfilerServer, calls Start(port) on it, and returns it
//     to Python (held by std::unique_ptr).
//   - trace(...): requests a profile capture from `service_addr` through
//     tensorflow::profiler::pywrap::Trace, retrying on failure.
//   - TraceMe: a Python context manager around
//     tensorflow::profiler::TraceMeWrapper (Stop() is called on __exit__).
//   - ScopePusher / scope_pusher(name): returns a new ir::ScopePusher.
void BuildProfilerSubmodule(py::module* m) {
  py::module profiler = m->def_submodule("profiler", "Profiler integration");
  py::class_<xla::profiler::ProfilerServer,
             std::unique_ptr<xla::profiler::ProfilerServer>>
      profiler_server_class(profiler, "ProfilerServer");
  profiler.def("start_server",
               [](int port) -> std::unique_ptr<xla::profiler::ProfilerServer> {
                 auto server =
                     absl::make_unique<xla::profiler::ProfilerServer>();
                 server->Start(port);
                 return server;
               },
               py::arg("port"));
  // trace(service_addr, logdir, duration_ms=1000, num_tracing_attempts=3,
  //       timeout_s=120, interval_s=5, options)
  //
  // Calls pywrap::Trace up to (timeout_s / interval_s + 1) times. It sleeps
  // interval_s seconds between failed calls, never after the last one.
  // Raises:
  //   ValueError   if interval_s <= 0 or timeout_s < 0.
  //   RuntimeError with the last Status message if every attempt fails.
  profiler.def(
      "trace",
      [](const char* service_addr, const char* logdir, int duration_ms,
         int num_tracing_attempts, int timeout_s, int interval_s,
         py::dict options) {
        // Validate before dividing. interval_s == 0 would be an integer
        // division by zero. A negative quotient would skip every attempt
        // and report success, because a default-constructed Status is OK.
        if (interval_s <= 0) {
          PyErr_SetString(PyExc_ValueError, "interval_s must be positive");
          throw py::error_already_set();
        }
        if (timeout_s < 0) {
          PyErr_SetString(PyExc_ValueError, "timeout_s must be non-negative");
          throw py::error_already_set();
        }
        absl::flat_hash_map<std::string, absl::variant<int>> opts =
            ConvertDictToMap(options);
        // Number of retries after the first attempt.
        const int max_retries = timeout_s / interval_s;
        const std::chrono::seconds sleep_s(interval_s);
        tensorflow::Status status;
        {
          // NOTE(review): NoGilSection presumably releases the GIL for this
          // scope so other Python threads run while we trace/sleep.
          NoGilSection nogil;
          for (int attempt = 0;; ++attempt) {
            status = tensorflow::profiler::pywrap::Trace(
                service_addr, logdir, /*worker_list=*/"",
                /*include_dataset_ops=*/false, duration_ms,
                num_tracing_attempts, opts);
            if (status.ok() || attempt >= max_retries) {
              break;
            }
            // Only sleep when another attempt will follow.
            std::this_thread::sleep_for(sleep_s);
          }
        }
        if (!status.ok()) {
          // PyErr_SetString needs a C string; Status::error_message() is a
          // std::string.
          PyErr_SetString(PyExc_RuntimeError, status.error_message().c_str());
          throw py::error_already_set();
        }
      },
      py::arg("service_addr"), py::arg("logdir"),
      py::arg("duration_ms") = 1000, py::arg("num_tracing_attempts") = 3,
      py::arg("timeout_s") = 120, py::arg("interval_s") = 5,
      py::arg("options"));
  // NOTE(review): py::module_local() keeps this TraceMe binding private to
  // this extension module. Presumably this avoids clashing with another
  // module that binds the same TraceMeWrapper type; confirm.
  py::class_<tensorflow::profiler::TraceMeWrapper> traceme_class(
      profiler, "TraceMe", py::module_local());
  traceme_class.def(py::init<py::str, py::kwargs>())
      .def("__enter__", [](py::object self) -> py::object { return self; })
      .def("__exit__",
           [](py::object self, const py::object& ex_type,
              const py::object& ex_value,
              const py::object& traceback) -> py::object {
             py::cast<tensorflow::profiler::TraceMeWrapper*>(self)->Stop();
             // Returning None (falsy) lets any in-flight exception propagate.
             return py::none();
           })
      .def("set_metadata", &tensorflow::profiler::TraceMeWrapper::SetMetadata)
      .def_static("is_enabled",
                  &tensorflow::profiler::TraceMeWrapper::IsEnabled);
  py::class_<ir::ScopePusher, std::unique_ptr<ir::ScopePusher>>
      scope_pusher_class(profiler, "ScopePusher");
  profiler.def(
      "scope_pusher",
      [](const std::string& name) -> std::unique_ptr<ir::ScopePusher> {
        return absl::make_unique<ir::ScopePusher>(name);
      },
      py::arg("name"));
}