void BuildProfilerSubmodule()

in torch_xla/csrc/init_python_bindings.cpp [681:747]

Registers the "profiler" submodule on the extension module, exposing the
xla::profiler::ProfilerServer, the remote trace capture entry point, the
TraceMe annotation context manager, and the lazy-IR ScopePusher to Python.

void BuildProfilerSubmodule(py::module* m) {
  py::module profiler = m->def_submodule("profiler", "Profiler integration");
  py::class_<xla::profiler::ProfilerServer,
             std::unique_ptr<xla::profiler::ProfilerServer>>
      profiler_server_class(profiler, "ProfilerServer");
  profiler.def("start_server",
               [](int port) -> std::unique_ptr<xla::profiler::ProfilerServer> {
                 auto server =
                     absl::make_unique<xla::profiler::ProfilerServer>();
                 server->Start(port);
                 return server;
               },
               py::arg("port"));

  profiler.def("trace",
               [](const char* service_addr, const char* logdir, int duration_ms,
                  int num_tracing_attempts, int timeout_s, int interval_s,
                  py::dict options) {
                 absl::flat_hash_map<std::string, absl::variant<int>> opts =
                     ConvertDictToMap(options);
                 std::chrono::seconds sleep_s(interval_s);
                 tensorflow::Status status;
                 {
                   NoGilSection nogil;
                   for (int i = 0; i <= timeout_s / interval_s; i++) {
                     status = tensorflow::profiler::pywrap::Trace(
                         service_addr, logdir, /*worker_list=*/"",
                         /*include_dataset_ops=*/false, duration_ms,
                         num_tracing_attempts, opts);
                     if (status.ok()) {
                       return;
                     }
                     std::this_thread::sleep_for(sleep_s);
                   }
                 }
                 if (!status.ok()) {
                    PyErr_SetString(PyExc_RuntimeError,
                                    status.error_message().c_str());
                   throw py::error_already_set();
                 }
               },
               py::arg("service_addr"), py::arg("logdir"),
               py::arg("duration_ms") = 1000,
               py::arg("num_tracing_attempts") = 3, py::arg("timeout_s") = 120,
               py::arg("interval_s") = 5, py::arg("options"));

  py::class_<tensorflow::profiler::TraceMeWrapper> traceme_class(
      profiler, "TraceMe", py::module_local());
  traceme_class.def(py::init<py::str, py::kwargs>())
      .def("__enter__", [](py::object self) -> py::object { return self; })
      .def("__exit__",
           [](py::object self, const py::object& ex_type,
              const py::object& ex_value,
              const py::object& traceback) -> py::object {
             py::cast<tensorflow::profiler::TraceMeWrapper*>(self)->Stop();
             return py::none();
           })
      .def("set_metadata", &tensorflow::profiler::TraceMeWrapper::SetMetadata)
      .def_static("is_enabled",
                  &tensorflow::profiler::TraceMeWrapper::IsEnabled);

  py::class_<ir::ScopePusher, std::unique_ptr<ir::ScopePusher>>
      scope_pusher_class(profiler, "ScopePusher");
  profiler.def("scope_pusher",
               [](const std::string& name) -> std::unique_ptr<ir::ScopePusher> {
                 return absl::make_unique<ir::ScopePusher>(name);
               });
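
  // scope_pusher returns an RAII handle that pushes `name` onto the lazy-IR
  // scope stack; a hedged usage sketch (the wrapper path is an assumption):
  //
  //   scope = torch_xla._XLAC.profiler.scope_pusher("layer1")
  //   loss = model(x)   # IR nodes created here are tagged with the scope
  //   del scope         # the scope is popped when the handle is destroyed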
}