in torch_xla/csrc/init_python_bindings.cpp [749:1185]
void InitXlaModuleBindings(py::module m) {
m.def("_prepare_to_exit", []() { PrepareToExit(); });
m.def("_get_git_revs", []() { return GetRevisions(); });
m.def("_get_xla_tensor_dimension_size",
[](const at::Tensor& tensor, int dim) {
return GetXlaTensorDimensionSize(tensor, dim);
});
m.def("_xla_nms", [](const at::Tensor& boxes, const at::Tensor& scores,
const at::Tensor& score_threshold,
const at::Tensor& iou_threshold,
xla::int64_t output_size) {
return XlaNms(boxes, scores, score_threshold, iou_threshold, output_size);
});
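  // Recurring pattern from here on: any work that may block on device
  // transfers or compilation runs inside a NoGilSection so that other Python
  // threads can make progress in the meantime.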
m.def("_xla_user_computation",
[](const std::string& opname, const std::vector<at::Tensor>& inputs,
const ComputationPtr& computation) {
std::vector<at::Tensor> results;
{
NoGilSection nogil;
results = XlaUserComputation(opname, inputs, computation);
}
return results;
});
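  // Graph dump helpers: the next three bindings render the lazy IR graph
  // rooted at the given tensors as GraphViz dot, plain text, or HLO.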
m.def("_get_xla_tensors_dot",
[](const std::vector<at::Tensor>& tensors) -> std::string {
          auto converter = [](absl::Span<const ir::Node* const> nodes) {
            return ir::DumpUtil::ToDot(nodes);
          };
          return GetTensorsDump(tensors, converter);
});
m.def("_get_xla_tensors_text",
[](const std::vector<at::Tensor>& tensors) -> std::string {
          auto converter = [](absl::Span<const ir::Node* const> nodes) {
            return ir::DumpUtil::ToText(nodes);
          };
          return GetTensorsDump(tensors, converter);
});
m.def("_get_xla_tensors_hlo",
[](const std::vector<at::Tensor>& tensors) -> std::string {
return GetTensorsHloGraph(tensors);
});
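  // A minimal debugging sketch from the Python side (assumes the standard
  // packaging where these bindings surface as torch_xla._XLAC):
  //   import torch
  //   import torch_xla
  //   import torch_xla.core.xla_model as xm
  //   t = torch.ones(2, 2, device=xm.xla_device()) + 1
  //   print(torch_xla._XLAC._get_xla_tensors_text([t]))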
m.def("_xla_tensors_from_aten", [](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices) {
std::vector<at::Tensor> result;
{
NoGilSection nogil;
std::vector<at::Tensor> xla_tensors =
GetXlaTensorsFromAten(tensors, devices);
result.reserve(xla_tensors.size());
for (size_t i = 0; i < xla_tensors.size(); ++i) {
result.push_back(torch::autograd::make_variable(
xla_tensors[i], /*requires_grad=*/tensors.at(i).requires_grad()));
}
}
return result;
});
m.def("_xla_get_cpu_tensors", [](const std::vector<at::Tensor>& tensors) {
std::vector<at::Tensor> result;
{
NoGilSection nogil;
std::vector<at::Tensor> cpu_tensors =
bridge::XlaCreateTensorList(tensors);
result.reserve(cpu_tensors.size());
for (size_t i = 0; i < cpu_tensors.size(); ++i) {
result.push_back(torch::autograd::make_variable(
cpu_tensors[i], /*requires_grad=*/tensors.at(i).requires_grad()));
}
}
return result;
});
m.def("_xla_get_tensor_view_alias_id",
[](const at::Tensor& tensor) { return GetTensorViewAliasId(tensor); });
m.def("_xla_get_tensor_id",
[](const at::Tensor& tensor) { return GetTensorId(tensor); });
m.def("_xla_get_devices",
[]() { return xla::ComputationClient::Get()->GetLocalDevices(); });
m.def("_xla_get_all_devices",
[]() { return xla::ComputationClient::Get()->GetAllDevices(); });
m.def("_xla_real_devices", [](const std::vector<std::string>& devices) {
std::vector<std::string> xla_devices;
{
NoGilSection nogil;
xla_devices = GetXlaDevices(devices);
}
return xla_devices;
});
m.def("_xla_set_replication_devices",
[](const std::vector<std::string>& devices) {
auto replication_devices =
std::make_shared<std::vector<std::string>>(devices);
xla::ComputationClient::Get()->SetReplicationDevices(
std::move(replication_devices));
});
m.def("_xla_get_replication_devices", []() {
auto replication_devices =
xla::ComputationClient::Get()->GetReplicationDevices();
return replication_devices != nullptr ? *replication_devices
: std::vector<std::string>();
});
m.def("_xla_get_replication_devices_count", []() {
auto replication_devices =
xla::ComputationClient::Get()->GetReplicationDevices();
return replication_devices != nullptr ? replication_devices->size() : 0;
});
m.def("_xla_rendezvous",
[](int ordinal, const std::string& tag, const std::string& payload,
const std::vector<xla::int64_t>& replicas) {
return Rendezvous(ordinal, tag, payload, replicas);
});
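  // The collective ops below thread an ir::Value token through successive
  // calls to enforce ordering; _xla_create_token mints the initial token for
  // a device.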
py::class_<ir::Value, std::shared_ptr<ir::Value>>(m, "IrValue");
m.def("_xla_create_token",
[](const std::string& device) { return CreateToken(device); });
m.def("_xla_all_reduce_inplace", [](const std::string& reduce_type,
const std::vector<at::Tensor>& tensors,
const std::shared_ptr<ir::Value>& token,
double scale, const py::list& groups) {
std::vector<std::vector<xla::int64_t>> replica_groups =
CreateReduceGroups(groups);
std::shared_ptr<ir::Value> new_token;
{
NoGilSection nogil;
new_token =
AllReduceInPlace(reduce_type, tensors, token, scale, replica_groups);
}
return new_token;
});
m.def("_xla_all_reduce",
[](const std::string& reduce_type, const at::Tensor& input,
const std::shared_ptr<ir::Value>& token, double scale,
const py::list& groups) {
std::vector<std::vector<xla::int64_t>> replica_groups =
CreateReduceGroups(groups);
at::Tensor result;
std::shared_ptr<ir::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) =
AllReduce(reduce_type, input, token, scale, replica_groups);
}
auto result_tuple = py::tuple(2);
result_tuple[0] = torch::autograd::make_variable(
result, /*requires_grad=*/input.requires_grad());
result_tuple[1] = new_token;
return result_tuple;
});
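  // Token-threading sketch from the Python side (t is assumed to be an XLA
  // tensor, device a device string, and "sum" a supported reduce_type):
  //   token = torch_xla._XLAC._xla_create_token(device)
  //   reduced, token = torch_xla._XLAC._xla_all_reduce(
  //       "sum", t, token, 1.0, [])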
m.def("_xla_all_to_all",
[](const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
xla::int64_t split_dimension, xla::int64_t concat_dimension,
xla::int64_t split_count, const py::list& groups) {
std::vector<std::vector<xla::int64_t>> replica_groups =
CreateReduceGroups(groups);
at::Tensor result;
std::shared_ptr<ir::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) =
AllToAll(input, token, split_dimension, concat_dimension,
split_count, replica_groups);
}
auto result_tuple = py::tuple(2);
result_tuple[0] = torch::autograd::make_variable(
result, /*requires_grad=*/input.requires_grad());
result_tuple[1] = new_token;
return result_tuple;
});
m.def("_xla_all_gather",
[](const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
xla::int64_t dim, xla::int64_t shard_count, const py::list& groups) {
std::vector<std::vector<xla::int64_t>> replica_groups =
CreateReduceGroups(groups);
at::Tensor result;
std::shared_ptr<ir::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) =
AllGather(input, token, dim, shard_count, replica_groups);
}
auto result_tuple = py::tuple(2);
result_tuple[0] = torch::autograd::make_variable(
result, /*requires_grad=*/input.requires_grad());
result_tuple[1] = new_token;
return result_tuple;
});
m.def("_xla_collective_permute", [](const at::Tensor& input,
const std::shared_ptr<ir::Value>& token,
const py::list& pairs) {
std::vector<std::pair<xla::int64_t, xla::int64_t>> source_target_pairs =
CreateSourceTargetPairs(pairs);
at::Tensor result;
std::shared_ptr<ir::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) =
CollectivePermute(input, token, source_target_pairs);
}
auto result_tuple = py::tuple(2);
result_tuple[0] = torch::autograd::make_variable(
result, /*requires_grad=*/input.requires_grad());
result_tuple[1] = new_token;
return result_tuple;
});
m.def("_xla_reduce_scatter",
[](const std::string& reduce_type, const at::Tensor& input,
const std::shared_ptr<ir::Value>& token, double scale,
xla::int64_t scatter_dim, xla::int64_t shard_count,
const py::list& groups) {
std::vector<std::vector<xla::int64_t>> replica_groups =
CreateReduceGroups(groups);
at::Tensor result;
std::shared_ptr<ir::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) =
ReduceScatter(reduce_type, input, token, scale, scatter_dim,
shard_count, replica_groups);
}
auto result_tuple = py::tuple(2);
result_tuple[0] = torch::autograd::make_variable(
result, /*requires_grad=*/input.requires_grad());
result_tuple[1] = new_token;
return result_tuple;
});
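  // Each collective above returns a (result, new_token) pair; callers are
  // expected to feed new_token into the next collective to keep the ops
  // ordered.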
m.def("_xla_set_default_device", [](const std::string& device) {
return SetCurrentThreadDevice(device);
});
m.def("_xla_get_default_device", []() { return GetCurrentThreadDevice(); });
m.def("_xla_set_rng_seed",
[](xla::uint64 seed, const std::string& device) {
SetRngSeed(seed, device);
},
py::arg("seed") = 101, py::arg("device") = "");
m.def("_xla_get_rng_seed",
[](const std::string& device) { return GetRngSeed(device); },
py::arg("device") = "");
m.def("_xla_sync_multi",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices, bool wait,
bool sync_xla_data) {
NoGilSection nogil;
SyncTensors(tensors, devices, wait, sync_xla_data);
},
py::arg("tensors"), py::arg("devices"), py::arg("wait") = true,
py::arg("sync_xla_data") = true);
m.def("_xla_sync_live_tensors",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait) {
NoGilSection nogil;
SyncLiveTensors(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
m.def("_xla_step_marker",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait) {
NoGilSection nogil;
StepMarker(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
m.def("_xla_wait_device_ops",
[](const std::vector<std::string>& devices) {
NoGilSection nogil;
XLATensor::WaitDeviceOps(devices);
},
py::arg("devices"));
m.def("_xla_counter_names", []() { return xla::metrics::GetCounterNames(); });
m.def("_xla_counter_value", [](const std::string& name) -> py::object {
xla::metrics::CounterData* data = xla::metrics::GetCounter(name);
return data != nullptr ? py::cast<int64_t>(data->Value()) : py::none();
});
m.def("_xla_metric_names", []() { return xla::metrics::GetMetricNames(); });
m.def("_xla_metric_data", [](const std::string& name) -> py::object {
return GetMetricData(name);
});
m.def("_xla_metrics_report",
[]() { return xla::metrics_reader::CreateMetricReport(); });
m.def("_xla_tensors_report",
[](size_t nodes_threshold, const std::string& device) {
return GetLiveTensorsReport(nodes_threshold, device);
},
py::arg("nodes_threshold") = 100, py::arg("device") = "");
m.def("_xla_memory_info", [](const std::string& device) -> py::object {
return GetMemoryInfo(device);
});
m.def("_xla_set_use_full_mat_mul_precision",
[](bool use_full_mat_mul_precision) {
XlaHelpers::set_mat_mul_precision(
use_full_mat_mul_precision ? xla::PrecisionConfig::HIGHEST
: xla::PrecisionConfig::DEFAULT);
},
py::arg("use_full_mat_mul_precision") = true);
py::class_<xla::util::RecordReader, std::shared_ptr<xla::util::RecordReader>>(
m, "RecordReader");
m.def("_xla_create_tfrecord_reader",
[](const std::string& path, const std::string& compression,
xla::int64_t buffer_size) {
NoGilSection nogil;
return CreateRecordReader(path, compression, buffer_size);
},
py::arg("path"), py::arg("compression") = "",
py::arg("buffer_size") = 16 * 1024 * 1024);
m.def(
"_xla_tfrecord_read",
[](const std::shared_ptr<xla::util::RecordReader>& reader) -> py::object {
xla::util::RecordReader::Data record;
if (!RecordRead(reader, &record)) {
return py::none();
}
return py::bytes(record.data(), record.size());
});
m.def("_xla_tfexample_read",
[](const std::shared_ptr<xla::util::RecordReader>& reader) {
return RecordReadExample(reader);
});
py::class_<tensorflow::RandomAccessFile>(m, "TfRdFile");
m.def("_xla_tffile_open", [](const std::string& path) {
std::unique_ptr<tensorflow::RandomAccessFile> file;
{
NoGilSection nogil;
file = OpenTfFile(path);
}
    return py::cast(file.release(),
                    py::return_value_policy::take_ownership);
});
m.def("_xla_tffile_stat",
[](const std::string& path) { return StatTfFile(path); });
m.def("_xla_tffile_read",
[](tensorflow::RandomAccessFile* file, uint64_t offset, size_t size) {
return ReadTfFile(file, offset, size);
});
py::class_<tensorflow::WritableFile>(m, "TfWrFile");
m.def("_xla_tffile_create", [](const std::string& path) {
std::unique_ptr<tensorflow::WritableFile> file;
{
NoGilSection nogil;
file = CreateTfFile(path);
}
    return py::cast(file.release(),
                    py::return_value_policy::take_ownership);
});
m.def("_xla_tffile_write",
[](tensorflow::WritableFile* file, const std::string& data) {
NoGilSection nogil;
WriteTfFile(file, data);
});
m.def("_xla_tffile_flush", [](tensorflow::WritableFile* file) {
NoGilSection nogil;
FlushTfFile(file);
});
m.def("_xla_tffs_list",
[](const std::string& pattern) { return ListTfFs(pattern); });
m.def("_xla_tffs_remove", [](const std::string& path) {
NoGilSection nogil;
RemoveTfFile(path);
});
py::class_<xla::XlaBuilder, op_builder::BuilderPtr>(m, "XlaBuilder");
py::class_<op_builder::Op, op_builder::OpPtr>(m, "XlaOp");
py::class_<Computation, ComputationPtr>(m, "XlaComputation");
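  // Low-level XLA builder API: XlaBuilder/XlaOp/Computation are exposed so
  // Python can assemble a computation op by op and execute it through
  // _xla_user_computation above.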
m.def("_xla_op_create_builder", [](const std::string& name) {
return std::make_shared<xla::XlaBuilder>(name);
});
m.def("_xla_op_tensor_shape",
[](const at::Tensor& tensor, const std::string& device) {
xla::Shape tensor_shape = GetTensorShape(tensor, device);
return op_builder::ShapeToPyShape(tensor_shape);
});
m.def("_xla_op_param", [](op_builder::BuilderPtr builder,
xla::int64_t param_no, py::object py_shape) {
xla::Shape shape = op_builder::PyShapeToShape(py_shape);
xla::XlaOp param = xla::Parameter(builder.get(), param_no, shape,
absl::StrCat("p", param_no));
return std::make_shared<op_builder::Op>(std::move(builder),
std::move(param));
});
m.def("_xla_op_build", [](const std::string& name, op_builder::OpPtr root) {
ComputationPtr computation;
{
NoGilSection nogil;
computation = CreateComputation(name, root->op);
}
return computation;
});
m.def("_xla_op_computation_from_module_proto",
[](const std::string& name, const std::string& module_proto) {
ComputationPtr computation;
{
NoGilSection nogil;
computation = CreateComputationFromProto(name, module_proto);
}
return computation;
});
m.def("_xla_computation_text", [](const ComputationPtr& computation) {
std::string hlo_text;
{
NoGilSection nogil;
hlo_text = ConsumeValue(
xla::util::GetComputationHloText(computation->computation()));
}
return hlo_text;
});
m.def("_xla_op_shape", [](op_builder::OpPtr op) {
const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(op->op);
return op_builder::ShapeToPyShape(shape);
});
m.def("_xla_op_builder", [](op_builder::OpPtr op) { return op->builder; });
m.def("_xla_op_create",
[](op_builder::BuilderPtr builder, const std::string& opname,
const std::vector<op_builder::OpPtr>& operands, py::dict args) {
return op_builder::CreateOp(builder, opname, operands, args);
});
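  // Minimal builder round trip from Python (a sketch; the opname "Add" is
  // assumed to be registered with op_builder::CreateOp):
  //   dev = torch_xla._XLAC._xla_get_default_device()
  //   builder = torch_xla._XLAC._xla_op_create_builder("add")
  //   shape = torch_xla._XLAC._xla_op_tensor_shape(t, dev)
  //   p0 = torch_xla._XLAC._xla_op_param(builder, 0, shape)
  //   p1 = torch_xla._XLAC._xla_op_param(builder, 1, shape)
  //   root = torch_xla._XLAC._xla_op_create(builder, "Add", [p0, p1], {})
  //   comp = torch_xla._XLAC._xla_op_build("add", root)
  //   out = torch_xla._XLAC._xla_user_computation("add", [t, t], comp)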
m.def("_run_xrt_local_service", [](xla::uint64 service_port) {
xla::ComputationClient::RunLocalService(service_port);
});
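  // Fused SGD update: the five tensors below are unpacked to XLATensors and
  // updated in place by a single lowered computation; found_inf comes from
  // AMP gradient scaling and gates whether the step is applied.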
m.def("_xla_sgd_optimizer_step_",
[](at::Tensor& step, at::Tensor& param, at::Tensor& buf,
const at::Tensor& found_inf, const at::Tensor& d_p,
double weight_decay, double momentum, double lr, double dampening,
bool nesterov) {
{
NoGilSection nogil;
XLATensor found_inf_xla = bridge::GetXlaTensor(found_inf);
XLATensor step_xla = bridge::GetXlaTensor(step);
XLATensor param_xla = bridge::GetXlaTensor(param);
XLATensor d_p_xla = bridge::GetXlaTensor(d_p);
XLATensor buf_xla = bridge::GetXlaTensor(buf);
XLATensor::sgd_optimizer_step_(step_xla, param_xla, buf_xla,
found_inf_xla, d_p_xla, weight_decay,
momentum, lr, dampening, nesterov);
}
});
BuildProfilerSubmodule(&m);
}