src/node/serialization.cc (591 lines of code) (raw):

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file node/serialization.cc
 * \brief Utilities to serialize TVM AST/IR objects.
 */
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <tvm/ir/attrs.h>
#include <tvm/node/reflection.h>
#include <tvm/node/serialization.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <cctype>
#include <limits>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "../runtime/object_internal.h"
#include "../support/base64.h"

namespace tvm {

/*! \brief Convert a DataType to its string form via runtime::DLDataTypeToString. */
inline std::string Type2String(const DataType& t) { return runtime::DLDataTypeToString(t); }

/*! \brief Convert a string to a DataType via runtime::StringToDLDataType. */
inline DataType String2Type(std::string s) { return DataType(runtime::StringToDLDataType(s)); }

/*!
 * \brief Decode a base64 string.
 * \param s The base64 encoded text.
 * \return The string read back from the base64 input stream.
 */
inline std::string Base64Decode(std::string s) {
  dmlc::MemoryStringStream mstrm(&s);
  support::Base64InStream b64strm(&mstrm);
  std::string output;
  b64strm.InitPosition();
  // Go through the dmlc::Stream interface to use its std::string overload of Read.
  dmlc::Stream* strm = &b64strm;
  strm->Read(&output);
  return output;
}

/*!
 * \brief Encode a string as base64.
 * \param s The bytes to encode.
 * \return The content written by the base64 output stream.
 */
inline std::string Base64Encode(std::string s) {
  std::string blob;
  dmlc::MemoryStringStream mstrm(&blob);
  support::Base64OutStream b64strm(&mstrm);
  // Go through the dmlc::Stream interface to use its std::string overload of Write.
  dmlc::Stream* strm = &b64strm;
  strm->Write(s);
  b64strm.Finish();
  return blob;
}

// indexer to index all the nodes
class NodeIndexer : public AttrVisitor {
 public:
  /*! \brief Map from value to its position in node_list_; the null value is pre-registered as 0. */
  std::unordered_map<Any, size_t, ffi::AnyHash, ffi::AnyEqual> node_index_{{Any(nullptr), 0}};
  /*! \brief Indexed values, in order of first visit; entry 0 is the null value. */
  std::vector<Any> node_list_{Any(nullptr)};
  /*! \brief Map from tensor pointer to its position in tensor_list_. */
  std::unordered_map<DLTensor*, size_t> tensor_index_;
  /*! \brief Indexed tensors, in order of first visit. */
  std::vector<DLTensor*> tensor_list_;
  ReflectionVTable* reflection_ = ReflectionVTable::Global();

  // Plain-value attributes reference no other node, so there is nothing to index.
  void Visit(const char* key, double* value) final {}
  void Visit(const char* key, int64_t* value) final {}
  void Visit(const char* key, uint64_t* value) final {}
  void Visit(const char* key, int* value) final {}
  void Visit(const char* key, bool* value) final {}
  void Visit(const char* key, std::string* value) final {}
  void Visit(const char* key, void** value) final {}
  void Visit(const char* key, DataType* value) final {}

  // Register the tensor the first time it is seen.
  void Visit(const char* key, runtime::NDArray* value) final {
    DLTensor* ptr = const_cast<ffi::NDArrayObj*>((*value).operator->());
    if (tensor_index_.count(ptr)) return;
    ICHECK_EQ(tensor_index_.size(), tensor_list_.size());
    tensor_index_[ptr] = tensor_list_.size();
    tensor_list_.push_back(ptr);
  }

  void Visit(const char* key, Optional<double>* value) final {}
  void Visit(const char* key, Optional<int64_t>* value) final {}

  void Visit(const char* key, ObjectRef* value) final { MakeIndex(Any(*value)); }

  /*! \brief Assign the next free index to node unless it is null or already indexed. */
  void MakeNodeIndex(Any node) {
    if (node == nullptr) return;
    if (node_index_.count(node)) {
      return;
    }
    ICHECK_EQ(node_index_.size(), node_list_.size());
    node_index_[node] = node_list_.size();
    node_list_.push_back(node);
  }

  // make index of all the children of node
  void MakeIndex(Any node) {
    if (node == nullptr) return;
    if (node_index_.count(node)) {
      return;
    }
    MakeNodeIndex(node);

    if (auto opt_array = node.as<const ArrayObj*>()) {
      const ArrayObj* n = opt_array.value();
      for (const auto& elem : *n) {
        MakeIndex(elem);
      }
    } else if (auto opt_map = node.as<const MapObj*>()) {
      const MapObj* n = opt_map.value();
      bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) {
        return v.first.template as<const ffi::StringObj*>().has_value();
      });
      if (is_str_map) {
        // Keys of an all-string map are not indexed; only the values are.
        for (const auto& kv : *n) {
          MakeIndex(kv.second);
        }
      } else {
        for (const auto& kv : *n) {
          MakeIndex(kv.first);
          MakeIndex(kv.second);
        }
      }
    } else if (auto opt_object = node.as<const Object*>()) {
      Object* n = const_cast<Object*>(opt_object.value());
      // if the node already have repr bytes, no need to visit Attrs.
      if (!reflection_->GetReprBytes(n, nullptr)) {
        reflection_->VisitAttrs(n, this);
      }
    }
  }
};

// use map so attributes are ordered.
using AttrMap = std::map<std::string, std::string>;

/*! \brief Node structure for json format. */
struct JSONNode {
  /*! \brief The type of key of the object. */
  std::string type_key;
  /*! \brief The str repr representation. */
  std::string repr_bytes;
  /*! \brief the attributes */
  AttrMap attrs;
  /*! \brief keys of a map. */
  std::vector<std::string> keys;
  /*! \brief values of a map or array. */
  std::vector<size_t> data;
  /*!
   * \brief field member dependency.
   * NOTE: This is an auxiliary data structure for loading, and it won't be serialized to json.
   */
  std::vector<size_t> fields;

  /*!
   * \brief Write this node as a JSON object.
   *
   * "type_key" is always written; "repr_str"/"repr_b64", "attrs", "keys"
   * and "data" are written only when non-empty.
   */
  void Save(dmlc::JSONWriter* writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("type_key", type_key);
    if (repr_bytes.size() != 0) {
      // choose to use str representation or base64, based on whether
      // the byte representation is printable.
      // std::isprint requires a value representable as unsigned char (or EOF),
      // so bytes must be cast before the call to stay well-defined when char is signed.
      const bool printable = std::all_of(repr_bytes.begin(), repr_bytes.end(), [](char ch) {
        return std::isprint(static_cast<unsigned char>(ch)) != 0;
      });
      if (printable) {
        writer->WriteObjectKeyValue("repr_str", repr_bytes);
      } else {
        writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes));
      }
    }
    if (attrs.size() != 0) {
      writer->WriteObjectKeyValue("attrs", attrs);
    }
    if (keys.size() != 0) {
      writer->WriteObjectKeyValue("keys", keys);
    }
    if (data.size() != 0) {
      writer->WriteObjectKeyValue("data", data);
    }
    writer->EndObject();
  }

  /*!
   * \brief Read this node from a JSON object.
   *
   * All fields are optional. "repr_str" takes precedence over "repr_b64";
   * having both non-empty is a check failure.
   */
  void Load(dmlc::JSONReader* reader) {
    // Reset every member so that optional fields missing from the input
    // do not leave stale content behind.
    attrs.clear();
    keys.clear();
    data.clear();
    fields.clear();
    repr_bytes.clear();
    type_key.clear();
    std::string repr_b64, repr_str;
    dmlc::JSONObjectReadHelper helper;
    helper.DeclareOptionalField("type_key", &type_key);
    helper.DeclareOptionalField("repr_b64", &repr_b64);
    helper.DeclareOptionalField("repr_str", &repr_str);
    helper.DeclareOptionalField("attrs", &attrs);
    helper.DeclareOptionalField("keys", &keys);
    helper.DeclareOptionalField("data", &data);
    helper.ReadAllFields(reader);

    if (repr_str.size() != 0) {
      ICHECK_EQ(repr_b64.size(), 0U);
      repr_bytes = std::move(repr_str);
    } else if (repr_b64.size() != 0) {
      repr_bytes = Base64Decode(repr_b64);
    }
  }
};

// Helper class to populate the json node
// using the existing index.
class JSONAttrGetter : public AttrVisitor { public: const std::unordered_map<Any, size_t, ffi::AnyHash, ffi::AnyEqual>* node_index_; const std::unordered_map<DLTensor*, size_t>* tensor_index_; JSONNode* node_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); void Visit(const char* key, double* value) final { std::ostringstream s; // Save 17 decimal digits for type <double> to avoid precision loss during loading JSON s.precision(17); s << (*value); node_->attrs[key] = s.str(); } void Visit(const char* key, int64_t* value) final { node_->attrs[key] = std::to_string(*value); } void Visit(const char* key, uint64_t* value) final { node_->attrs[key] = std::to_string(*value); } void Visit(const char* key, int* value) final { node_->attrs[key] = std::to_string(*value); } void Visit(const char* key, bool* value) final { node_->attrs[key] = std::to_string(*value); } void Visit(const char* key, std::string* value) final { node_->attrs[key] = *value; } void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to serialize a pointer"; } void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); } void Visit(const char* key, runtime::NDArray* value) final { node_->attrs[key] = std::to_string(tensor_index_->at(const_cast<ffi::NDArrayObj*>((*value).operator->()))); } void Visit(const char* key, Optional<int64_t>* value) final { if (value->has_value()) { node_->attrs[key] = std::to_string(value->value()); } else { node_->attrs[key] = "null"; } } void Visit(const char* key, Optional<double>* value) final { if (value->has_value()) { double val = **value; Visit(key, &val); } else { node_->attrs[key] = "null"; } } void Visit(const char* key, ObjectRef* value) final { node_->attrs[key] = std::to_string(node_index_->at(Any(*value))); } // Get the node void Get(Any node) { if (node == nullptr) { node_->type_key.clear(); return; } node_->type_key = node.GetTypeKey(); // populates the fields. 
node_->attrs.clear(); node_->data.clear(); if (auto opt_array = node.as<const ArrayObj*>()) { const ArrayObj* n = opt_array.value(); for (size_t i = 0; i < n->size(); ++i) { node_->data.push_back(node_index_->at(n->at(i))); } } else if (auto opt_map = node.as<const MapObj*>()) { const MapObj* n = opt_map.value(); bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { return v.first.template as<const ffi::StringObj*>().has_value(); }); if (is_str_map) { for (const auto& kv : *n) { node_->keys.push_back(kv.first.cast<String>()); node_->data.push_back(node_index_->at(kv.second)); } } else { for (const auto& kv : *n) { node_->data.push_back(node_index_->at(kv.first)); node_->data.push_back(node_index_->at(kv.second)); } } } else if (auto opt_object = node.as<const Object*>()) { Object* n = const_cast<Object*>(opt_object.value()); // do not need to print additional things once we have repr bytes. if (!reflection_->GetReprBytes(n, &(node_->repr_bytes))) { // recursively index normal object. 
reflection_->VisitAttrs(n, this); } } else { // handling primitive types // use switch since it is faster than if-else switch (node.type_index()) { case ffi::TypeIndex::kTVMFFIBool: case ffi::TypeIndex::kTVMFFIInt: { node_->attrs["v_int64"] = std::to_string(node.cast<int64_t>()); break; } case ffi::TypeIndex::kTVMFFIFloat: { node_->attrs["v_float64"] = std::to_string(node.cast<double>()); break; } case ffi::TypeIndex::kTVMFFIDataType: { node_->attrs["v_type"] = Type2String(DataType(node.cast<DLDataType>())); break; } case ffi::TypeIndex::kTVMFFIDevice: { DLDevice dev = node.cast<DLDevice>(); node_->attrs["v_device_type"] = std::to_string(dev.device_type); node_->attrs["v_device_id"] = std::to_string(dev.device_id); break; } default: { LOG(FATAL) << "Unsupported type: " << node.GetTypeKey(); } } } } }; class FieldDependencyFinder : public AttrVisitor { public: JSONNode* jnode_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } template <typename T> void ParseValue(const char* key, T* value) const { std::istringstream is(GetValue(key)); is >> *value; if (is.fail()) { LOG(FATAL) << "Wrong value format for field " << key; } } void Visit(const char* key, double* value) final {} void Visit(const char* key, int64_t* value) final {} void Visit(const char* key, uint64_t* value) final {} void Visit(const char* key, int* value) final {} void Visit(const char* key, bool* value) final {} void Visit(const char* key, std::string* value) final {} void Visit(const char* key, void** value) final {} void Visit(const char* key, DataType* value) final {} void Visit(const char* key, runtime::NDArray* value) final {} void Visit(const char* key, Optional<int64_t>* value) final {} void Visit(const char* key, Optional<double>* value) final {} void Visit(const char* key, ObjectRef* value) 
final { size_t index; ParseValue(key, &index); jnode_->fields.push_back(index); } void Find(Any node, JSONNode* jnode) { // Skip None if (node == nullptr) { return; } if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { return; } // Skip the objects that have their own string repr if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node.cast<const Object*>(), nullptr)) { return; } // Skip containers if (jnode->type_key == ArrayObj::_type_key || jnode->type_key == MapObj::_type_key) { return; } jnode_ = jnode; if (auto opt_object = node.as<const Object*>()) { Object* n = const_cast<Object*>(opt_object.value()); reflection_->VisitAttrs(n, this); } } }; // Helper class to set the attributes of a node // from given json node. class JSONAttrSetter : public AttrVisitor { public: const std::vector<Any>* node_list_; const std::vector<runtime::NDArray>* tensor_list_; JSONNode* jnode_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } void ParseDouble(const char* key, double* value) const { std::istringstream is(GetValue(key)); if (is.str() == "inf") { *value = std::numeric_limits<double>::infinity(); } else if (is.str() == "-inf") { *value = -std::numeric_limits<double>::infinity(); } else { is >> *value; if (is.fail()) { LOG(FATAL) << "Wrong value format for field " << key; } } } template <typename T> void ParseValue(const char* key, T* value) const { std::istringstream is(GetValue(key)); is >> *value; if (is.fail()) { LOG(FATAL) << "Wrong value format for field " << key; } } template <typename T, typename Fallback> void ParseOptionalValue(const char* key, Optional<T>* value, Fallback fallback) const { if (GetValue(key) == "null") { *value = std::nullopt; } else { T temp; fallback(key, &temp); *value = temp; } } void Visit(const char* key, double* 
value) final { ParseDouble(key, value); } void Visit(const char* key, int64_t* value) final { ParseValue(key, value); } void Visit(const char* key, uint64_t* value) final { ParseValue(key, value); } void Visit(const char* key, int* value) final { ParseValue(key, value); } void Visit(const char* key, bool* value) final { ParseValue(key, value); } void Visit(const char* key, std::string* value) final { *value = GetValue(key); } void Visit(const char* key, Optional<double>* value) final { ParseOptionalValue<double>(key, value, [this](const char* key, double* value) { ParseDouble(key, value); }); } void Visit(const char* key, Optional<int64_t>* value) final { ParseOptionalValue<int64_t>( key, value, [this](const char* key, int64_t* value) { ParseValue(key, value); }); } void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to deserialize a pointer"; } void Visit(const char* key, DataType* value) final { std::string stype = GetValue(key); *value = String2Type(stype); } void Visit(const char* key, runtime::NDArray* value) final { size_t index; ParseValue(key, &index); ICHECK_LE(index, tensor_list_->size()); *value = tensor_list_->at(index); } void Visit(const char* key, ObjectRef* value) final { size_t index; ParseValue(key, &index); ICHECK_LE(index, node_list_->size()); *value = node_list_->at(index).cast<ObjectRef>(); } static Any CreateInitAny(ReflectionVTable* reflection, JSONNode* jnode) { JSONAttrSetter setter; setter.jnode_ = jnode; if (jnode->type_key == ffi::StaticTypeKey::kTVMFFINone || jnode->type_key.empty()) { // empty key type means None in current implementation return Any(); } if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBool) { int64_t value; setter.ParseValue("v_int64", &value); return Any(static_cast<bool>(value)); } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIInt) { int64_t value; setter.ParseValue("v_int64", &value); return Any(value); } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIFloat) { double 
value; setter.ParseValue("v_float64", &value); return Any(value); } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIDataType) { std::string value; setter.ParseValue("v_type", &value); return Any(String2Type(value).operator DLDataType()); } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIDevice) { int32_t device_type; int32_t device_id; setter.ParseValue("v_device_type", &device_type); setter.ParseValue("v_device_id", &device_id); return Any(DLDevice{static_cast<DLDeviceType>(device_type), device_id}); } else { return ObjectRef(reflection->CreateInitObject(jnode->type_key, jnode->repr_bytes)); } } // set node to be current JSONNode void SetAttrs(Any* node, JSONNode* jnode) { jnode_ = jnode; // handling Array if (jnode->type_key == ArrayObj::_type_key) { Array<Any> result; for (auto index : jnode->data) { result.push_back(node_list_->at(index)); } *node = result; } else if (jnode->type_key == MapObj::_type_key) { Map<Any, Any> result; if (jnode->keys.empty()) { ICHECK_EQ(jnode->data.size() % 2, 0U); for (size_t i = 0; i < jnode->data.size(); i += 2) { result.Set(node_list_->at(jnode->data[i]), node_list_->at(jnode->data[i + 1])); } } else { ICHECK_EQ(jnode->data.size(), jnode->keys.size()); for (size_t i = 0; i < jnode->data.size(); ++i) { result.Set(String(jnode->keys[i]), node_list_->at(jnode->data[i])); } } *node = result; } else if (auto opt_object = node->as<const Object*>()) { Object* n = const_cast<Object*>(opt_object.value()); if (n == nullptr) return; // Skip the objects that have their own string repr if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(n, nullptr)) { return; } reflection_->VisitAttrs(n, this); } } }; // json graph structure to store node struct JSONGraph { // the root of the graph size_t root; // the nodes of the graph std::vector<JSONNode> nodes; // base64 b64ndarrays of arrays std::vector<std::string> b64ndarrays; // global attributes AttrMap attrs; void Save(dmlc::JSONWriter* writer) const { 
writer->BeginObject(); writer->WriteObjectKeyValue("root", root); writer->WriteObjectKeyValue("nodes", nodes); writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays); if (attrs.size() != 0) { writer->WriteObjectKeyValue("attrs", attrs); } writer->EndObject(); } void Load(dmlc::JSONReader* reader) { attrs.clear(); dmlc::JSONObjectReadHelper helper; helper.DeclareField("root", &root); helper.DeclareField("nodes", &nodes); helper.DeclareOptionalField("b64ndarrays", &b64ndarrays); helper.DeclareOptionalField("attrs", &attrs); helper.ReadAllFields(reader); } static JSONGraph Create(Any root) { JSONGraph g; NodeIndexer indexer; indexer.MakeIndex(root); JSONAttrGetter getter; getter.node_index_ = &indexer.node_index_; getter.tensor_index_ = &indexer.tensor_index_; for (Any n : indexer.node_list_) { JSONNode jnode; getter.node_ = &jnode; getter.Get(n); g.nodes.emplace_back(std::move(jnode)); } g.attrs["tvm_version"] = TVM_VERSION; g.root = indexer.node_index_.at(root); // serialize tensor for (DLTensor* tensor : indexer.tensor_list_) { std::string blob; dmlc::MemoryStringStream mstrm(&blob); support::Base64OutStream b64strm(&mstrm); runtime::SaveDLTensor(&b64strm, tensor); b64strm.Finish(); g.b64ndarrays.emplace_back(std::move(blob)); } return g; } std::vector<size_t> TopoSort() const { size_t n_nodes = nodes.size(); std::vector<size_t> topo_order; std::vector<size_t> in_degree(n_nodes, 0); for (const JSONNode& jnode : nodes) { for (size_t i : jnode.data) { ++in_degree[i]; } for (size_t i : jnode.fields) { ++in_degree[i]; } } for (size_t i = 0; i < n_nodes; ++i) { if (in_degree[i] == 0) { topo_order.push_back(i); } } for (size_t p = 0; p < topo_order.size(); ++p) { const JSONNode& jnode = nodes[topo_order[p]]; for (size_t i : jnode.data) { if (--in_degree[i] == 0) { topo_order.push_back(i); } } for (size_t i : jnode.fields) { if (--in_degree[i] == 0) { topo_order.push_back(i); } } } ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; 
std::reverse(std::begin(topo_order), std::end(topo_order)); return topo_order; } }; std::string SaveJSON(Any n) { auto jgraph = JSONGraph::Create(n); std::ostringstream os; dmlc::JSONWriter writer(&os); jgraph.Save(&writer); return os.str(); } Any LoadJSON(std::string json_str) { ReflectionVTable* reflection = ReflectionVTable::Global(); JSONGraph jgraph; { // load in json graph. std::istringstream is(json_str); dmlc::JSONReader reader(&is); jgraph.Load(&reader); } size_t n_nodes = jgraph.nodes.size(); std::vector<runtime::NDArray> tensors; { // load in tensors for (const std::string& blob : jgraph.b64ndarrays) { dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob)); support::Base64InStream b64strm(&mstrm); b64strm.InitPosition(); runtime::NDArray temp; ICHECK(temp.Load(&b64strm)); tensors.emplace_back(std::move(temp)); } } // Pass 1: create all non-container objects std::vector<Any> nodes(n_nodes, nullptr); for (size_t i = 0; i < n_nodes; ++i) { nodes[i] = JSONAttrSetter::CreateInitAny(reflection, &(jgraph.nodes[i])); } // Pass 2: figure out all field dependency { FieldDependencyFinder dep_finder; for (size_t i = 0; i < n_nodes; ++i) { dep_finder.Find(nodes[i], &jgraph.nodes[i]); } } // Pass 3: topo sort std::vector<size_t> topo_order = jgraph.TopoSort(); // Pass 4: set all values { JSONAttrSetter setter; setter.node_list_ = &nodes; setter.tensor_list_ = &tensors; for (size_t i : topo_order) { setter.SetAttrs(&nodes[i], &jgraph.nodes[i]); } } return nodes.at(jgraph.root); } TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON); TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON); } // namespace tvm