void InitXlaModuleBindings(py::module m)

Defined in torch_xla/csrc/init_python_bindings.cpp, lines 749-1185.


// Registers torch_xla's private Python API on module `m`: the `_xla_*` /
// `_get_*` free functions plus the opaque classes IrValue, RecordReader,
// TfRdFile, TfWrFile, XlaBuilder, XlaOp and XlaComputation. Finally it attaches
// the profiler submodule via BuildProfilerSubmodule().
//
// Many bindings run their C++ work inside a NoGilSection scope (presumably an
// RAII release of the Python GIL -- confirm against its definition). Inputs
// that are Python objects (py::list groups/pairs, py::dict args) are converted
// to C++ values BEFORE entering that scope, and results are wrapped back into
// Python objects only AFTER leaving it.
void InitXlaModuleBindings(py::module m) {
  m.def("_prepare_to_exit", []() { PrepareToExit(); });
  m.def("_get_git_revs", []() { return GetRevisions(); });
  m.def("_get_xla_tensor_dimension_size",
        [](const at::Tensor& tensor, int dim) {
          return GetXlaTensorDimensionSize(tensor, dim);
        });
  m.def("_xla_nms", [](const at::Tensor& boxes, const at::Tensor& scores,
                       const at::Tensor& score_threshold,
                       const at::Tensor& iou_threshold,
                       xla::int64_t output_size) {
    return XlaNms(boxes, scores, score_threshold, iou_threshold, output_size);
  });
  m.def("_xla_user_computation",
        [](const std::string& opname, const std::vector<at::Tensor>& inputs,
           const ComputationPtr& computation) {
          std::vector<at::Tensor> results;
          {
            NoGilSection nogil;
            results = XlaUserComputation(opname, inputs, computation);
          }
          return results;
        });

  // Graph dumps of the given tensors: DOT and text via ir::DumpUtil, and the
  // HLO form via GetTensorsHloGraph().
  m.def("_get_xla_tensors_dot",
        [](const std::vector<at::Tensor>& tensors) -> std::string {
          auto converter = [](absl::Span<const ir::Node* const> nodes) {
            return ir::DumpUtil::ToDot(nodes);
          };
          return GetTensorsDump(tensors, converter);
        });
  m.def("_get_xla_tensors_text",
        [](const std::vector<at::Tensor>& tensors) -> std::string {
          auto converter = [](absl::Span<const ir::Node* const> nodes) {
            return ir::DumpUtil::ToText(nodes);
          };
          return GetTensorsDump(tensors, converter);
        });
  m.def("_get_xla_tensors_hlo",
        [](const std::vector<at::Tensor>& tensors) -> std::string {
          return GetTensorsHloGraph(tensors);
        });

  // Tensor transfers. Each output is re-wrapped as an autograd variable whose
  // requires_grad flag is copied from the input tensor at the same index.
  m.def("_xla_tensors_from_aten", [](const std::vector<at::Tensor>& tensors,
                                     const std::vector<std::string>& devices) {
    std::vector<at::Tensor> result;
    {
      NoGilSection nogil;
      std::vector<at::Tensor> xla_tensors =
          GetXlaTensorsFromAten(tensors, devices);
      result.reserve(xla_tensors.size());
      for (size_t i = 0; i < xla_tensors.size(); ++i) {
        result.push_back(torch::autograd::make_variable(
            xla_tensors[i], /*requires_grad=*/tensors.at(i).requires_grad()));
      }
    }
    return result;
  });
  m.def("_xla_get_cpu_tensors", [](const std::vector<at::Tensor>& tensors) {
    std::vector<at::Tensor> result;
    {
      NoGilSection nogil;
      std::vector<at::Tensor> cpu_tensors =
          bridge::XlaCreateTensorList(tensors);
      result.reserve(cpu_tensors.size());
      for (size_t i = 0; i < cpu_tensors.size(); ++i) {
        result.push_back(torch::autograd::make_variable(
            cpu_tensors[i], /*requires_grad=*/tensors.at(i).requires_grad()));
      }
    }
    return result;
  });
  m.def("_xla_get_tensor_view_alias_id",
        [](const at::Tensor& tensor) { return GetTensorViewAliasId(tensor); });
  m.def("_xla_get_tensor_id",
        [](const at::Tensor& tensor) { return GetTensorId(tensor); });

  // Device discovery and replication configuration.
  m.def("_xla_get_devices",
        []() { return xla::ComputationClient::Get()->GetLocalDevices(); });
  m.def("_xla_get_all_devices",
        []() { return xla::ComputationClient::Get()->GetAllDevices(); });
  m.def("_xla_real_devices", [](const std::vector<std::string>& devices) {
    std::vector<std::string> xla_devices;
    {
      NoGilSection nogil;
      xla_devices = GetXlaDevices(devices);
    }
    return xla_devices;
  });
  m.def("_xla_set_replication_devices",
        [](const std::vector<std::string>& devices) {
          auto replication_devices =
              std::make_shared<std::vector<std::string>>(devices);
          xla::ComputationClient::Get()->SetReplicationDevices(
              std::move(replication_devices));
        });
  // An unset (null) replication list is reported as an empty list / count 0.
  m.def("_xla_get_replication_devices", []() {
    auto replication_devices =
        xla::ComputationClient::Get()->GetReplicationDevices();
    return replication_devices != nullptr ? *replication_devices
                                          : std::vector<std::string>();
  });
  m.def("_xla_get_replication_devices_count", []() {
    auto replication_devices =
        xla::ComputationClient::Get()->GetReplicationDevices();
    return replication_devices != nullptr ? replication_devices->size() : 0;
  });
  m.def("_xla_rendezvous",
        [](int ordinal, const std::string& tag, const std::string& payload,
           const std::vector<xla::int64_t>& replicas) {
          return Rendezvous(ordinal, tag, payload, replicas);
        });

  // Collective ops. Each takes an ordering token (an opaque IrValue) and
  // returns the token produced by the op. The tensor-valued collectives return
  // a Python tuple (result_tensor, new_token).
  py::class_<ir::Value, std::shared_ptr<ir::Value>>(m, "IrValue");
  m.def("_xla_create_token",
        [](const std::string& device) { return CreateToken(device); });
  m.def("_xla_all_reduce_inplace", [](const std::string& reduce_type,
                                      const std::vector<at::Tensor>& tensors,
                                      const std::shared_ptr<ir::Value>& token,
                                      double scale, const py::list& groups) {
    std::vector<std::vector<xla::int64_t>> replica_groups =
        CreateReduceGroups(groups);
    std::shared_ptr<ir::Value> new_token;
    {
      NoGilSection nogil;
      new_token =
          AllReduceInPlace(reduce_type, tensors, token, scale, replica_groups);
    }
    return new_token;
  });
  m.def("_xla_all_reduce",
        [](const std::string& reduce_type, const at::Tensor& input,
           const std::shared_ptr<ir::Value>& token, double scale,
           const py::list& groups) {
          std::vector<std::vector<xla::int64_t>> replica_groups =
              CreateReduceGroups(groups);
          at::Tensor result;
          std::shared_ptr<ir::Value> new_token;
          {
            NoGilSection nogil;
            std::tie(result, new_token) =
                AllReduce(reduce_type, input, token, scale, replica_groups);
          }
          return py::make_tuple(
              torch::autograd::make_variable(
                  result, /*requires_grad=*/input.requires_grad()),
              new_token);
        });
  m.def("_xla_all_to_all",
        [](const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
           xla::int64_t split_dimension, xla::int64_t concat_dimension,
           xla::int64_t split_count, const py::list& groups) {
          std::vector<std::vector<xla::int64_t>> replica_groups =
              CreateReduceGroups(groups);
          at::Tensor result;
          std::shared_ptr<ir::Value> new_token;
          {
            NoGilSection nogil;
            std::tie(result, new_token) =
                AllToAll(input, token, split_dimension, concat_dimension,
                         split_count, replica_groups);
          }
          return py::make_tuple(
              torch::autograd::make_variable(
                  result, /*requires_grad=*/input.requires_grad()),
              new_token);
        });
  m.def("_xla_all_gather",
        [](const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
           xla::int64_t dim, xla::int64_t shard_count, const py::list& groups) {
          std::vector<std::vector<xla::int64_t>> replica_groups =
              CreateReduceGroups(groups);
          at::Tensor result;
          std::shared_ptr<ir::Value> new_token;
          {
            NoGilSection nogil;
            std::tie(result, new_token) =
                AllGather(input, token, dim, shard_count, replica_groups);
          }
          return py::make_tuple(
              torch::autograd::make_variable(
                  result, /*requires_grad=*/input.requires_grad()),
              new_token);
        });
  m.def("_xla_collective_permute", [](const at::Tensor& input,
                                      const std::shared_ptr<ir::Value>& token,
                                      const py::list& pairs) {
    std::vector<std::pair<xla::int64_t, xla::int64_t>> source_target_pairs =
        CreateSourceTargetPairs(pairs);
    at::Tensor result;
    std::shared_ptr<ir::Value> new_token;
    {
      NoGilSection nogil;
      std::tie(result, new_token) =
          CollectivePermute(input, token, source_target_pairs);
    }
    return py::make_tuple(
        torch::autograd::make_variable(
            result, /*requires_grad=*/input.requires_grad()),
        new_token);
  });
  m.def("_xla_reduce_scatter",
        [](const std::string& reduce_type, const at::Tensor& input,
           const std::shared_ptr<ir::Value>& token, double scale,
           xla::int64_t scatter_dim, xla::int64_t shard_count,
           const py::list& groups) {
          std::vector<std::vector<xla::int64_t>> replica_groups =
              CreateReduceGroups(groups);
          at::Tensor result;
          std::shared_ptr<ir::Value> new_token;
          {
            NoGilSection nogil;
            std::tie(result, new_token) =
                ReduceScatter(reduce_type, input, token, scale, scatter_dim,
                              shard_count, replica_groups);
          }
          return py::make_tuple(
              torch::autograd::make_variable(
                  result, /*requires_grad=*/input.requires_grad()),
              new_token);
        });

  // Per-thread default device, RNG seeding and synchronization.
  m.def("_xla_set_default_device", [](const std::string& device) {
    return SetCurrentThreadDevice(device);
  });
  m.def("_xla_get_default_device", []() { return GetCurrentThreadDevice(); });
  m.def("_xla_set_rng_seed",
        [](xla::uint64 seed, const std::string& device) {
          SetRngSeed(seed, device);
        },
        py::arg("seed") = 101, py::arg("device") = "");
  m.def("_xla_get_rng_seed",
        [](const std::string& device) { return GetRngSeed(device); },
        py::arg("device") = "");
  m.def("_xla_sync_multi",
        [](const std::vector<at::Tensor>& tensors,
           const std::vector<std::string>& devices, bool wait,
           bool sync_xla_data) {
          NoGilSection nogil;
          SyncTensors(tensors, devices, wait, sync_xla_data);
        },
        py::arg("tensors"), py::arg("devices"), py::arg("wait") = true,
        py::arg("sync_xla_data") = true);
  // NOTE(review): `devices` has no default but follows the defaulted
  // `device`, so from Python it must effectively be passed by keyword (or
  // all preceding arguments given positionally).
  m.def("_xla_sync_live_tensors",
        [](const std::string& device, const std::vector<std::string>& devices,
           bool wait) {
          NoGilSection nogil;
          SyncLiveTensors(device, devices, wait);
        },
        py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
  m.def("_xla_step_marker",
        [](const std::string& device, const std::vector<std::string>& devices,
           bool wait) {
          NoGilSection nogil;
          StepMarker(device, devices, wait);
        },
        py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
  m.def("_xla_wait_device_ops",
        [](const std::vector<std::string>& devices) {
          NoGilSection nogil;
          XLATensor::WaitDeviceOps(devices);
        },
        py::arg("devices"));

  // Metrics, counters and reports. Unknown counter names yield None.
  m.def("_xla_counter_names", []() { return xla::metrics::GetCounterNames(); });
  m.def("_xla_counter_value", [](const std::string& name) -> py::object {
    xla::metrics::CounterData* data = xla::metrics::GetCounter(name);
    return data != nullptr ? py::cast<int64_t>(data->Value()) : py::none();
  });
  m.def("_xla_metric_names", []() { return xla::metrics::GetMetricNames(); });
  m.def("_xla_metric_data", [](const std::string& name) -> py::object {
    return GetMetricData(name);
  });
  m.def("_xla_metrics_report",
        []() { return xla::metrics_reader::CreateMetricReport(); });
  m.def("_xla_tensors_report",
        [](size_t nodes_threshold, const std::string& device) {
          return GetLiveTensorsReport(nodes_threshold, device);
        },
        py::arg("nodes_threshold") = 100, py::arg("device") = "");
  m.def("_xla_memory_info", [](const std::string& device) -> py::object {
    return GetMemoryInfo(device);
  });
  // true selects PrecisionConfig::HIGHEST for matmuls, false selects DEFAULT.
  m.def("_xla_set_use_full_mat_mul_precision",
        [](bool use_full_mat_mul_precision) {
          XlaHelpers::set_mat_mul_precision(
              use_full_mat_mul_precision ? xla::PrecisionConfig::HIGHEST
                                         : xla::PrecisionConfig::DEFAULT);
        },
        py::arg("use_full_mat_mul_precision") = true);

  // TFRecord reading. _xla_tfrecord_read returns None when RecordRead()
  // reports no record, otherwise the raw record bytes.
  py::class_<xla::util::RecordReader, std::shared_ptr<xla::util::RecordReader>>(
      m, "RecordReader");
  m.def("_xla_create_tfrecord_reader",
        [](const std::string& path, const std::string& compression,
           xla::int64_t buffer_size) {
          NoGilSection nogil;
          return CreateRecordReader(path, compression, buffer_size);
        },
        py::arg("path"), py::arg("compression") = "",
        py::arg("buffer_size") = 16 * 1024 * 1024);
  m.def(
      "_xla_tfrecord_read",
      [](const std::shared_ptr<xla::util::RecordReader>& reader) -> py::object {
        xla::util::RecordReader::Data record;
        if (!RecordRead(reader, &record)) {
          return py::none();
        }
        return py::bytes(record.data(), record.size());
      });
  m.def("_xla_tfexample_read",
        [](const std::shared_ptr<xla::util::RecordReader>& reader) {
          return RecordReadExample(reader);
        });

  // TensorFlow filesystem access. Both file classes use pybind11's default
  // std::unique_ptr holder, so the opened/created file is returned as a
  // unique_ptr and handed straight to Python. Ownership never passes through
  // a bare raw pointer, so nothing leaks if the conversion fails.
  py::class_<tensorflow::RandomAccessFile>(m, "TfRdFile");
  m.def("_xla_tffile_open", [](const std::string& path) {
    std::unique_ptr<tensorflow::RandomAccessFile> file;
    {
      NoGilSection nogil;
      file = OpenTfFile(path);
    }
    return file;
  });
  m.def("_xla_tffile_stat",
        [](const std::string& path) { return StatTfFile(path); });
  m.def("_xla_tffile_read",
        [](tensorflow::RandomAccessFile* file, uint64_t offset, size_t size) {
          return ReadTfFile(file, offset, size);
        });

  py::class_<tensorflow::WritableFile>(m, "TfWrFile");
  m.def("_xla_tffile_create", [](const std::string& path) {
    std::unique_ptr<tensorflow::WritableFile> file;
    {
      NoGilSection nogil;
      file = CreateTfFile(path);
    }
    return file;
  });
  m.def("_xla_tffile_write",
        [](tensorflow::WritableFile* file, const std::string& data) {
          NoGilSection nogil;
          WriteTfFile(file, data);
        });
  m.def("_xla_tffile_flush", [](tensorflow::WritableFile* file) {
    NoGilSection nogil;
    FlushTfFile(file);
  });

  m.def("_xla_tffs_list",
        [](const std::string& pattern) { return ListTfFs(pattern); });
  m.def("_xla_tffs_remove", [](const std::string& path) {
    NoGilSection nogil;
    RemoveTfFile(path);
  });

  // Python-side XLA op builder: builders, ops and computations are exposed
  // as opaque handles held by their respective smart-pointer types.
  py::class_<xla::XlaBuilder, op_builder::BuilderPtr>(m, "XlaBuilder");
  py::class_<op_builder::Op, op_builder::OpPtr>(m, "XlaOp");
  py::class_<Computation, ComputationPtr>(m, "XlaComputation");
  m.def("_xla_op_create_builder", [](const std::string& name) {
    return std::make_shared<xla::XlaBuilder>(name);
  });
  m.def("_xla_op_tensor_shape",
        [](const at::Tensor& tensor, const std::string& device) {
          xla::Shape tensor_shape = GetTensorShape(tensor, device);
          return op_builder::ShapeToPyShape(tensor_shape);
        });
  // Parameter number `param_no` is named "p<param_no>" in the built graph.
  m.def("_xla_op_param", [](op_builder::BuilderPtr builder,
                            xla::int64_t param_no, py::object py_shape) {
    xla::Shape shape = op_builder::PyShapeToShape(py_shape);
    xla::XlaOp param = xla::Parameter(builder.get(), param_no, shape,
                                      absl::StrCat("p", param_no));
    return std::make_shared<op_builder::Op>(std::move(builder),
                                            std::move(param));
  });
  m.def("_xla_op_build", [](const std::string& name, op_builder::OpPtr root) {
    ComputationPtr computation;
    {
      NoGilSection nogil;
      computation = CreateComputation(name, root->op);
    }
    return computation;
  });
  m.def("_xla_op_computation_from_module_proto",
        [](const std::string& name, const std::string& module_proto) {
          ComputationPtr computation;
          {
            NoGilSection nogil;
            computation = CreateComputationFromProto(name, module_proto);
          }
          return computation;
        });
  m.def("_xla_computation_text", [](const ComputationPtr& computation) {
    std::string hlo_text;
    {
      NoGilSection nogil;
      hlo_text = ConsumeValue(
          xla::util::GetComputationHloText(computation->computation()));
    }
    return hlo_text;
  });
  m.def("_xla_op_shape", [](op_builder::OpPtr op) {
    const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(op->op);
    return op_builder::ShapeToPyShape(shape);
  });
  m.def("_xla_op_builder", [](op_builder::OpPtr op) { return op->builder; });
  m.def("_xla_op_create",
        [](op_builder::BuilderPtr builder, const std::string& opname,
           const std::vector<op_builder::OpPtr>& operands, py::dict args) {
          return op_builder::CreateOp(builder, opname, operands, args);
        });
  m.def("_run_xrt_local_service", [](xla::uint64 service_port) {
    xla::ComputationClient::RunLocalService(service_port);
  });

  // Fused SGD step over XLA tensors. step/param/buf are updated through
  // XLATensor::sgd_optimizer_step_; found_inf and d_p are read-only inputs.
  m.def("_xla_sgd_optimizer_step_",
        [](at::Tensor& step, at::Tensor& param, at::Tensor& buf,
           const at::Tensor& found_inf, const at::Tensor& d_p,
           double weight_decay, double momentum, double lr, double dampening,
           bool nesterov) {
          {
            NoGilSection nogil;
            XLATensor found_inf_xla = bridge::GetXlaTensor(found_inf);
            XLATensor step_xla = bridge::GetXlaTensor(step);
            XLATensor param_xla = bridge::GetXlaTensor(param);
            XLATensor d_p_xla = bridge::GetXlaTensor(d_p);
            XLATensor buf_xla = bridge::GetXlaTensor(buf);
            XLATensor::sgd_optimizer_step_(step_xla, param_xla, buf_xla,
                                           found_inf_xla, d_p_xla, weight_decay,
                                           momentum, lr, dampening, nesterov);
          }
        });

  BuildProfilerSubmodule(&m);
}