in orttraining/orttraining/python/orttraining_pybind_state.cc [348:791]
void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) {
py::class_<std::vector<OrtValue>>(m, "OrtValueVector")
.def(py::init<>())
.def("push_back", [](std::vector<OrtValue>* v, const OrtValue& ortvalue) {
v->push_back(ortvalue);
})
.def("push_back", [](std::vector<OrtValue>* v, py::object dlpack_tensor, const bool is_bool_tensor) {
v->push_back(FromDlpack(dlpack_tensor.ptr(), is_bool_tensor));
})
.def("reserve", [](std::vector<OrtValue>* v, const size_t len) { v->reserve(len); })
.def("shrink_to_fit", [](std::vector<OrtValue>* v) { v->shrink_to_fit(); })
.def("__len__", [](const std::vector<OrtValue>& v) { return v.size(); })
.def("__iter__", [](const std::vector<OrtValue>& v) {
return py::make_iterator(v.cbegin(), v.cend());
},
py::keep_alive<0, 1>())
.def("__getitem__", [](const std::vector<OrtValue>& v, const size_t idx) {
return v.at(idx);
})
.def("dlpack_at", [](std::vector<OrtValue>* v, const size_t idx) {
return py::reinterpret_steal<py::object>(ToDlpack(v->at(idx)));
});
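// Example (a minimal sketch, not part of the bindings): typical Python-side use of the
// OrtValueVector class bound above. The import path and variable names are assumptions and
// may differ by build/package layout.
//
//   from onnxruntime.capi import _pybind_state as C  # assumed import path
//   vec = C.OrtValueVector()
//   vec.reserve(2)
//   vec.push_back(existing_ort_value)       # overload taking an OrtValue
//   vec.push_back(dlpack_capsule, False)    # overload taking a DLPack capsule; True for bool tensors
//   first = vec[0]                          # __getitem__
//   capsule = vec.dlpack_at(0)              # export element 0 back to DLPack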
py::class_<OrtValueCache, OrtValueCachePtr>(m, "OrtValueCache")
.def(py::init<>())
.def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) {
cache_ptr->emplace(node_arg_name, value);
})
.def("keys", [](const OrtValueCachePtr& cache_ptr) {
py::list keys;
for (const auto& kv : *cache_ptr) {
keys.append(kv.first);
}
return keys;
})
.def("clear", [](const OrtValueCachePtr& cache_ptr) {
cache_ptr->clear();
})
.def("count", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name) {
return cache_ptr->count(node_arg_name);
})
.def("remove", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name) {
const auto num_entries_erased = cache_ptr->erase(node_arg_name);
ORT_ENFORCE(num_entries_erased == 1, "NodeArg not found in cache: ", node_arg_name);
});
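// Example (sketch): Python-side use of the OrtValueCache binding above; `C` is the assumed
// pybind module from the OrtValueVector sketch, and the node-arg name is illustrative.
//
//   cache = C.OrtValueCache()
//   cache.insert("onnx::Add_3", some_ort_value)
//   assert cache.count("onnx::Add_3") == 1
//   print(list(cache.keys()))
//   cache.remove("onnx::Add_3")  # throws if the entry is missing
//   cache.clear()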
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
parameters.def(py::init())
.def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
.def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
.def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)
.def_readwrite("weights_to_train", &TrainingParameters::weights_to_train)
.def_readwrite("sliced_tensor_names", &TrainingParameters::sliced_tensor_names)
.def_readwrite("training_optimizer_name", &TrainingParameters::training_optimizer_name)
.def_readwrite("lr_params_feed_name", &TrainingParameters::lr_params_feed_name)
.def_readwrite("optimizer_attributes_map", &TrainingParameters::optimizer_attributes_map)
.def_readwrite("optimizer_int_attributes_map", &TrainingParameters::optimizer_int_attributes_map)
.def_readwrite("sliced_schema", &TrainingParameters::sliced_schema)
.def_readwrite("sliced_axes", &TrainingParameters::sliced_axes)
.def_readwrite("use_fp16_moments", &TrainingParameters::use_fp16_moments)
.def_readwrite("use_mixed_precision", &TrainingParameters::use_mixed_precision)
.def_readwrite("allreduce_post_accumulation", &TrainingParameters::allreduce_post_accumulation)
.def_readwrite("loss_scale", &TrainingParameters::loss_scale)
.def_readwrite("world_rank", &TrainingParameters::world_rank)
.def_readwrite("world_size", &TrainingParameters::world_size)
.def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size)
.def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size)
.def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size)
.def_readwrite("pipeline_cut_info_string", &TrainingParameters::pipeline_cut_info_string)
.def_readwrite("num_pipeline_micro_batches", &TrainingParameters::num_pipeline_micro_batches)
.def_readwrite("gradient_accumulation_steps", &TrainingParameters::gradient_accumulation_steps)
.def_readwrite("deepspeed_zero_stage", &TrainingParameters::deepspeed_zero_stage)
.def_readwrite("enable_grad_norm_clip", &TrainingParameters::enable_grad_norm_clip)
.def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs)
.def_readwrite("use_memory_efficient_gradient", &TrainingParameters::use_memory_efficient_gradient)
.def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute)
.def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute)
.def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers)
.def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size)
.def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size)
.def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size)
.def("set_optimizer_initial_state",
[](TrainingParameters& parameters, const std::unordered_map<std::string, std::unordered_map<std::string, py::object>>& py_state) -> void {
onnxruntime::training::TrainingSession::OptimizerState optim_state;
for (const auto& weight_it : py_state) {
auto state = weight_it.second;
NameMLValMap state_tensors;
for (auto& initializer : state) {
OrtValue ml_value;
// The input def list is null because the parameters haven't been tied to the session yet.
// Likewise, there is no need to specify the name (the name is only used to look up the def list).
CreateGenericMLValue(nullptr, GetAllocator(), "", initializer.second, &ml_value, true);
ThrowIfPyErrOccured();
state_tensors.emplace(initializer.first, ml_value);
}
optim_state.emplace(weight_it.first, state_tensors);
}
parameters.optimizer_initial_state = optim_state;
})
.def_readwrite("model_after_graph_transforms_path", &TrainingParameters::model_after_graph_transforms_path)
.def_readwrite("model_with_gradient_graph_path", &TrainingParameters::model_with_gradient_graph_path)
.def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
.def_readwrite("enable_adasum", &TrainingParameters::enable_adasum)
.def_readwrite("propagate_cast_ops_level", &TrainingParameters::propagate_cast_ops_level)
.def_readwrite("propagate_cast_ops_allow", &TrainingParameters::propagate_cast_ops_allow)
.def_readwrite("allow_layer_norm_mod_precision", &TrainingParameters::allow_layer_norm_mod_precision);
#if defined(USE_MPI)
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });
m.def("get_mpi_context_local_size", []() -> int { return MPIContext::GetInstance().GetLocalSize(); });
m.def("get_mpi_context_world_rank", []() -> int { return MPIContext::GetInstance().GetWorldRank(); });
m.def("get_mpi_context_world_size", []() -> int { return MPIContext::GetInstance().GetWorldSize(); });
#endif
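// Example (sketch): the MPI helpers above are plain module-level functions, only present
// in builds where USE_MPI is defined.
//
//   local_rank = C.get_mpi_context_local_rank()
//   world_size = C.get_mpi_context_world_size()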
m.def("register_aten_op_executor",
[](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
size_t is_tensor_argument_address_int, aten_op_executor_address_int;
ORT_THROW_IF_ERROR(
ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int));
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
});
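// Example (sketch): both arguments are function-pointer addresses serialized as decimal
// strings; on the Python side they typically come from a torch C extension. The variable
// names below are hypothetical.
//
//   C.register_aten_op_executor(str(is_tensor_argument_address), str(aten_op_executor_address))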
m.def("register_forward_runner", [](py::object obj) -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.RegisterForwardRunner(obj.ptr());
#else
ORT_UNUSED_PARAMETER(obj);
#endif
});
m.def("register_backward_runner", [](py::object obj) -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.RegisterBackwardRunner(obj.ptr());
#else
ORT_UNUSED_PARAMETER(obj);
#endif
});
m.def("register_torch_autograd_function", [](std::string key, py::object obj) -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.RegisterTorchAutogradFunction(key, obj.ptr());
#else
ORT_UNUSED_PARAMETER(key);
ORT_UNUSED_PARAMETER(obj);
#endif
});
m.def("unregister_python_functions", []() -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
// Release all custom python functions registered.
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.UnRegisterFunctions();
#endif
});
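// Example (sketch): registering ORTModule's Python-side callbacks. The runner objects and
// the autograd function are placeholders; when ENABLE_TRAINING_TORCH_INTEROP is not defined
// these calls are no-ops.
//
//   C.register_forward_runner(forward_runner)
//   C.register_backward_runner(backward_runner)
//   C.register_torch_autograd_function("MyAutogradFunction", MyAutogradFunction)
//   ...
//   C.unregister_python_functions()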
py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", R"pbdoc(Configuration result for training.)pbdoc");
config_result.def(py::init())
.def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object {
if (result.loss_scale_input_name.has_value()) {
return py::str{result.loss_scale_input_name.value()};
}
return py::none();
});
// Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
struct PyTrainingSession : public PyInferenceSession {
PyTrainingSession(Environment& env, const PySessionOptions& so)
: PyInferenceSession(std::make_unique<PipelineTrainingSession>(so, env)) {
}
};
py::class_<PyTrainingSession, PyInferenceSession> training_session(m, "TrainingSession");
training_session
.def(py::init([](const PySessionOptions& so) {
Environment& env = GetTrainingORTEnv();
return std::make_unique<PyTrainingSession>(env, so);
}))
.def(py::init([]() {
Environment& env = GetTrainingORTEnv();
return std::make_unique<PyTrainingSession>(env, GetDefaultCPUSessionOptions());
}))
.def("finalize", [](py::object) {
#if defined(USE_MPI)
#ifdef _WIN32
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// shutdown_mpi() is not called within the MPIContext destructor because of DllMain's restrictions;
// call shutdown_mpi() here instead.
MPIContext::shutdown_mpi();
#endif
#endif
})
.def("load_model", [ep_registration_fn](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));
#if defined(USE_MPI)
bool use_nccl = parameters.allreduce_post_accumulation;
if (!use_nccl && parameters.world_size > 1)
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
ProviderOptionsVector merged_options;
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
return config_result;
})
.def("read_bytes", [ep_registration_fn](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
std::istringstream buffer(serialized_model);
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));
#if defined(USE_MPI)
bool use_nccl = parameters.allreduce_post_accumulation;
if (!use_nccl && parameters.world_size > 1)
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
ProviderOptionsVector merged_options;
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
return config_result;
})
.def("get_state", [](PyTrainingSession* sess) {
NameMLValMap state_tensors;
ORT_THROW_IF_ERROR(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->GetStateTensors(state_tensors));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
// Convert the state tensors to numpy arrays.
std::map<std::string, py::object> rmap;
for (auto& kv : state_tensors) {
if (kv.second.IsTensor()) {
py::object obj;
const Tensor& rtensor = kv.second.Get<Tensor>();
GetPyObjFromTensor(rtensor, obj, &data_transfer_manager);
rmap.insert({kv.first, obj});
} else {
throw std::runtime_error("Non tensor type in session state tensors is not expected.");
}
}
return rmap;
})
.def("get_model_state", [](PyTrainingSession* sess, bool include_mixed_precision_weights) {
std::unordered_map<std::string, NameMLValMap> model_state_tensors;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetModelState(model_state_tensors, include_mixed_precision_weights));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
return ConvertORTTensorMapToNumpy(model_state_tensors, data_transfer_manager);
})
.def("get_optimizer_state", [](PyTrainingSession* sess) {
std::unordered_map<std::string, NameMLValMap> opt_state_tensors;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetOptimizerState(opt_state_tensors));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
return ConvertORTTensorMapToNumpy(opt_state_tensors, data_transfer_manager);
})
.def("get_partition_info_map", [](PyTrainingSession* sess) {
std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> part_info_map;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetPartitionInfoMap(part_info_map));
return part_info_map;
})
.def("load_state", [](PyTrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
NameMLValMap state_tensors;
for (const auto& initializer : state) {
OrtValue ml_value;
auto px = sess->GetSessionHandle()->GetModelInputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}
CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value);
ThrowIfPyErrOccured();
state_tensors.insert(std::make_pair(initializer.first, ml_value));
}
ORT_THROW_IF_ERROR(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict));
})
.def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) {
return static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
});
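// Example (sketch): driving the TrainingSession binding from Python. The model path,
// provider list, and options are illustrative; load_model/read_bytes return a
// TrainingConfigurationResult.
//
//   sess = C.TrainingSession()                    # default CPU session options
//   result = sess.load_model("model.onnx", params,
//                            ["CPUExecutionProvider"], [{}])
//   print(result.loss_scale_input_name)           # None unless mixed precision is enabled
//   state = sess.get_state()                      # {tensor_name: numpy array}
//   sess.load_state(state, False)                 # second argument is `strict`
//   sess.finalize()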
py::class_<PartialGraphExecutionState>(m, "PartialGraphExecutionState")
.def(py::init([]() {
return std::make_unique<PartialGraphExecutionState>();
}));
py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class used to run an ORTModule model.)pbdoc")
.def(py::init([](PyInferenceSession* session, const std::vector<std::string>& fw_feed_names,
const std::vector<OrtDevice>& fw_outputs_device_info,
const std::vector<std::string>& bw_fetches_names,
const std::vector<OrtDevice>& bw_outputs_device_info) {
return std::make_unique<TrainingAgent>(*session->GetSessionHandle(), fw_feed_names, fw_outputs_device_info,
bw_fetches_names, bw_outputs_device_info);
}))
.def("run_forward", [](TrainingAgent* agent, const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void {
Status status = agent->RunForward(feeds, fetches, *state, cache);
if (!status.IsOK()) {
throw std::runtime_error("Error in forward pass execution: " + status.ErrorMessage());
}
})
.def("run_backward", [](TrainingAgent* agent, const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches, PartialGraphExecutionState* state) -> void {
Status status = agent->RunBackward(feeds, fetches, *state);
if (!status.IsOK()) {
throw std::runtime_error("Error in backward pass execution: " + status.ErrorMessage());
}
});
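// Example (sketch): one forward/backward step through the TrainingAgent binding. The
// session, feed/fetch names, and device info would come from ORTModule's graph execution
// manager in practice; all names below are placeholders.
//
//   state = C.PartialGraphExecutionState()
//   cache = C.OrtValueCache()
//   feeds, fetches = C.OrtValueVector(), C.OrtValueVector()
//   # ... push inputs into `feeds` ...
//   agent = C.TrainingAgent(inference_session, fw_feed_names, fw_outputs_device_info,
//                           bw_fetches_names, bw_outputs_device_info)
//   agent.run_forward(feeds, fetches, state, cache)
//   agent.run_backward(backward_feeds, backward_fetches, state)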
py::enum_<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(m, "PropagateCastOpsStrategy", py::module_local(), py::arithmetic{})
.value("NONE", GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::None)
.value("INSERT_AND_REDUCE", GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::InsertAndReduce)
.value("FLOOD_FILL", GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::FloodFill)
.def("__or__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator|))
.def("__and__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator&))
.def("__eq__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator==))
.def("__neq__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator!=));
py::class_<GraphTransformerConfiguration::PropagateCastOpsConfiguration>
propagate_cast_ops_config(
m, "PropagateCastOpsConfiguration",
R"pbdoc(Propagate cast ops configuration.)pbdoc");
propagate_cast_ops_config.def(py::init())
.def_readwrite("strategy", &GraphTransformerConfiguration::PropagateCastOpsConfiguration::strategy)
.def_readwrite("level", &GraphTransformerConfiguration::PropagateCastOpsConfiguration::level)
.def_readwrite("allow", &GraphTransformerConfiguration::PropagateCastOpsConfiguration::allow);
py::class_<GraphTransformerConfiguration> graph_transformer_config(
m, "GraphTransformerConfiguration",
R"pbdoc(Graph transformer configuration.)pbdoc");
graph_transformer_config.def(py::init())
.def_readwrite("propagate_cast_ops_config", &GraphTransformerConfiguration::propagate_cast_ops_config);
py::class_<TrainingGraphTransformerConfiguration, GraphTransformerConfiguration> training_graph_transformer_config(
m, "TrainingGraphTransformerConfiguration",
R"pbdoc(Training Graph transformer configuration.)pbdoc");
training_graph_transformer_config.def(py::init())
.def_readwrite("enable_gelu_approximation", &TrainingGraphTransformerConfiguration::enable_gelu_approximation)
.def_readwrite("attn_dropout_recompute", &TrainingGraphTransformerConfiguration::attn_dropout_recompute)
.def_readwrite("gelu_recompute", &TrainingGraphTransformerConfiguration::gelu_recompute)
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("allow_layer_norm_mod_precision", &TrainingGraphTransformerConfiguration::allow_layer_norm_mod_precision)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(
m, "OrtModuleGraphBuilderConfiguration",
R"pbdoc(Configuration information for module graph builder.)pbdoc");
py::enum_<Severity>(m, "Severity", py::arithmetic(), py::module_local())
.value("VERBOSE", logging::Severity::kVERBOSE)
.value("INFO", logging::Severity::kINFO)
.value("WARNING", logging::Severity::kWARNING)
.value("ERROR", logging::Severity::kERROR)
.value("FATAL", logging::Severity::kFATAL);
module_graph_builder_config.def(py::init())
.def_readwrite("initializer_names", &OrtModuleGraphBuilderConfiguration::initializer_names)
.def_readwrite("initializer_names_to_train", &OrtModuleGraphBuilderConfiguration::initializer_names_to_train)
.def_readwrite("input_names_require_grad", &OrtModuleGraphBuilderConfiguration::input_names_require_grad)
.def_readwrite("use_memory_efficient_gradient",
&OrtModuleGraphBuilderConfiguration::use_memory_efficient_gradient)
.def_readwrite("build_gradient_graph", &OrtModuleGraphBuilderConfiguration::build_gradient_graph)
.def_readwrite("graph_transformer_config", &OrtModuleGraphBuilderConfiguration::graph_transformer_config)
.def_readwrite("enable_caching", &OrtModuleGraphBuilderConfiguration::enable_caching)
.def_readwrite("loglevel", &OrtModuleGraphBuilderConfiguration::loglevel);
py::class_<GraphInfo> graph_info(m, "GraphInfo",
R"pbdoc(The information of split graphs for frontend.)pbdoc");
graph_info.def(py::init())
.def_readwrite("user_input_names", &GraphInfo::user_input_names)
.def_readwrite("user_input_grad_names", &GraphInfo::user_input_grad_names)
.def_readwrite("initializer_names", &GraphInfo::initializer_names)
.def_readwrite("initializer_names_to_train", &GraphInfo::initializer_names_to_train)
.def_readwrite("initializer_grad_names_to_train", &GraphInfo::initializer_grad_names_to_train)
.def_readwrite("user_output_names", &GraphInfo::user_output_names)
.def_readwrite("output_grad_indices_non_differentiable", &GraphInfo::output_grad_indices_non_differentiable)
.def_readwrite("output_grad_indices_require_full_shape", &GraphInfo::output_grad_indices_require_full_shape)
.def_readwrite("module_output_indices_requires_save_for_backward", &GraphInfo::module_output_indices_requires_save_for_backward)
.def_readwrite("frontier_node_arg_map", &GraphInfo::frontier_node_arg_map)
.def_readwrite("cached_node_arg_names", &GraphInfo::cached_node_arg_names)
.def_readwrite("module_output_gradient_name", &GraphInfo::module_output_gradient_name);
py::class_<OrtModuleGraphBuilder> ortmodule_graph_builder(m, "OrtModuleGraphBuilder");
ortmodule_graph_builder.def(py::init([]() { return std::make_unique<OrtModuleGraphBuilder>(); }))
.def("initialize",
[](OrtModuleGraphBuilder* ortmodule_graph_builder, const py::bytes& serialized_model,
const OrtModuleGraphBuilderConfiguration& config) {
std::istringstream buffer(serialized_model);
ORT_THROW_IF_ERROR(ortmodule_graph_builder->Initialize(buffer, config));
})
.def("build",
[](OrtModuleGraphBuilder* ortmodule_graph_builder) {
ORT_THROW_IF_ERROR(ortmodule_graph_builder->Build());
})
.def("build",
[](OrtModuleGraphBuilder* ortmodule_graph_builder,
const std::vector<std::vector<int64_t>>& input_shapes) {
ORT_THROW_IF_ERROR(ortmodule_graph_builder->Build(&input_shapes));
})
.def("get_model",
[](OrtModuleGraphBuilder* ortmodule_graph_builder) {
return py::bytes(ortmodule_graph_builder->GetModel());
})
.def("get_inference_optimized_model",
[](OrtModuleGraphBuilder* ortmodule_graph_builder) {
return py::bytes(ortmodule_graph_builder->GetInferenceOptimizedModel());
})
.def("get_graph_info", [](OrtModuleGraphBuilder* ortmodule_graph_builder) {
return ortmodule_graph_builder->GetGraphInfo();
});
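// Example (sketch): the typical OrtModuleGraphBuilder flow; `serialized_model` is an ONNX
// model serialized to bytes and `config` is the configuration sketched above.
//
//   builder = C.OrtModuleGraphBuilder()
//   builder.initialize(serialized_model, config)
//   builder.build()                                   # or builder.build(input_shapes)
//   gradient_model = builder.get_model()
//   inference_model = builder.get_inference_optimized_model()
//   info = builder.get_graph_info()
//   print(info.user_input_names, info.user_output_names)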
py::class_<GradientNodeAttributeDefinition> gradient_node_attribute_definition(
m, "GradientNodeAttributeDefinition", R"pbdoc(Attribute definition for gradient graph nodes.)pbdoc");
gradient_node_attribute_definition.def(py::init())
.def_readwrite("name", &GradientNodeAttributeDefinition::name)
.def_readwrite("value_json", &GradientNodeAttributeDefinition::value_json)
.def_readwrite("dtype", &GradientNodeAttributeDefinition::dtype)
.def_readwrite("is_tensor", &GradientNodeAttributeDefinition::is_tensor);
py::class_<GradientNodeDefinition> gradient_node_definition(m, "GradientNodeDefinition",
R"pbdoc(Definition for gradient graph nodes.)pbdoc");
gradient_node_definition.def(py::init())
.def_readwrite("op_type", &GradientNodeDefinition::op_type)
.def_readwrite("domain", &GradientNodeDefinition::domain)
.def_readwrite("inputs", &GradientNodeDefinition::inputs)
.def_readwrite("outputs", &GradientNodeDefinition::outputs)
.def_readwrite("attributes", &GradientNodeDefinition::attributes);
m.def("register_gradient_definition",
[](const std::string& key, const std::vector<GradientNodeDefinition>& gradient_def) -> void {
GradientDefinitionRegistry::Instance().Register(key, gradient_def);
});
m.def("register_custom_stop_gradient_edges",
[](const std::string& key, const std::unordered_set<size_t>& edges) -> void {
GradientDefinitionRegistry::Instance().SetStopGradientEdgesForNode(key, edges);
});
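// Example (sketch): registering a custom gradient definition. The op type, key, and the
// "GO(0)"/"I(0)"/"GI(0)" input/output tokens are illustrative placeholders.
//
//   attr = C.GradientNodeAttributeDefinition()
//   attr.name, attr.value_json, attr.dtype, attr.is_tensor = "alpha", "1.0", "float", False
//   node = C.GradientNodeDefinition()
//   node.op_type, node.domain = "MyOpGrad", "com.example"
//   node.inputs, node.outputs = ["GO(0)", "I(0)"], ["GI(0)"]
//   node.attributes = [attr]
//   C.register_gradient_definition("com.example::MyOp", [node])
//   C.register_custom_stop_gradient_edges("com.example::MyOp", {1})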
}