src/contrib/msc/core/ir/plugin.cc (268 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/contrib/msc/core/ir/plugin.cc
*/
#include "plugin.h"
#include <algorithm>
#include <map>
#include <queue>
#include <set>
#include <utility>
namespace tvm {
namespace contrib {
namespace msc {
/*!
 * \brief Build a PluginAttr from explicit field values.
 *
 * A new PluginAttrNode is allocated and every argument is stored in the
 * node field of the same name.
 *
 * \param name Value for the node's `name` field.
 * \param type Value for the node's `type` field.
 * \param default_value Value for the node's `default_value` field.
 * \param describe Value for the node's `describe` field.
 */
PluginAttr::PluginAttr(const String& name, const String& type, const String& default_value,
                       const String& describe) {
  auto node = make_object<PluginAttrNode>();
  // The arguments are const references, so each assignment is a copy.
  node->name = name;
  node->type = type;
  node->default_value = default_value;
  node->describe = describe;
  data_ = std::move(node);
}
/*!
 * \brief Build a PluginAttr from an already-parsed JSON structure.
 * \param j_attr The JSON structure whose fields are loaded into a new node.
 */
PluginAttr::PluginAttr(const JsonPluginAttr& j_attr) {
  auto node = make_object<PluginAttrNode>();
  node->FromJson(j_attr);
  data_ = std::move(node);
}
/*!
 * \brief Build a PluginAttr from JSON text.
 * \param json_str The JSON text, forwarded to PluginAttrNode::FromJson.
 */
PluginAttr::PluginAttr(const std::string& json_str) {
  auto node = make_object<PluginAttrNode>();
  node->FromJson(json_str);
  data_ = std::move(node);
}
/*!
 * \brief Export this attribute as a JsonPluginAttr.
 * \return A structure holding copies of name, type, default_value and describe.
 */
const JsonPluginAttr PluginAttrNode::ToJson() const {
  JsonPluginAttr exported;
  exported.name = name;
  exported.type = type;
  exported.default_value = default_value;
  exported.describe = describe;
  return exported;
}
void PluginAttrNode::FromJson(const JsonPluginAttr& j_attr) {
name = j_attr.name;
type = j_attr.type;
default_value = j_attr.default_value;
describe = j_attr.describe;
}
void PluginAttrNode::FromJson(const std::string& json_str) {
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JsonPluginAttr j_attr;
reader.Read(&j_attr);
FromJson(j_attr);
}
/*!
 * \brief Build a PluginTensor from explicit field values.
 *
 * A new PluginTensorNode is allocated and every argument is stored in the
 * node field of the same name.
 *
 * \param name Value for the node's `name` field.
 * \param dtype Value for the node's `dtype` field.
 * \param ndim Value for the node's `ndim` field.
 * \param device Value for the node's `device` field.
 * \param describe Value for the node's `describe` field.
 */
PluginTensor::PluginTensor(const String& name, const String& dtype, const Integer& ndim,
                           const String& device, const String& describe) {
  auto node = make_object<PluginTensorNode>();
  // The arguments are const references, so each assignment is a copy.
  node->name = name;
  node->dtype = dtype;
  node->ndim = ndim;
  node->device = device;
  node->describe = describe;
  data_ = std::move(node);
}
/*!
 * \brief Build a PluginTensor from an already-parsed JSON structure.
 * \param j_tensor The JSON structure whose fields are loaded into a new node.
 */
PluginTensor::PluginTensor(const JsonPluginTensor& j_tensor) {
  auto node = make_object<PluginTensorNode>();
  node->FromJson(j_tensor);
  data_ = std::move(node);
}
/*!
 * \brief Build a PluginTensor from JSON text.
 * \param json_str The JSON text, forwarded to PluginTensorNode::FromJson.
 */
PluginTensor::PluginTensor(const std::string& json_str) {
  auto node = make_object<PluginTensorNode>();
  node->FromJson(json_str);
  data_ = std::move(node);
}
/*!
 * \brief Export this tensor description as a JsonPluginTensor.
 * \return A structure holding copies of name, dtype, device and describe,
 *         plus the integer value held by ndim.
 */
const JsonPluginTensor PluginTensorNode::ToJson() const {
  JsonPluginTensor exported;
  exported.name = name;
  exported.dtype = dtype;
  // ndim is an Integer reference; the JSON structure stores its raw value.
  exported.ndim = ndim->value;
  exported.device = device;
  exported.describe = describe;
  return exported;
}
void PluginTensorNode::FromJson(const JsonPluginTensor& j_tensor) {
name = j_tensor.name;
dtype = j_tensor.dtype;
ndim = Integer(j_tensor.ndim);
device = j_tensor.device;
describe = j_tensor.describe;
}
void PluginTensorNode::FromJson(const std::string& json_str) {
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JsonPluginTensor j_tensor;
reader.Read(&j_tensor);
FromJson(j_tensor);
}
/*!
 * \brief Build a PluginExtern from explicit field values.
 *
 * A new PluginExternNode is allocated and every argument is stored in the
 * node field of the same name.
 *
 * \param name Value for the node's `name` field.
 * \param header Value for the node's `header` field.
 * \param source Value for the node's `source` field.
 * \param lib Value for the node's `lib` field.
 * \param describe Value for the node's `describe` field.
 */
PluginExtern::PluginExtern(const String& name, const String& header, const String& source,
                           const String& lib, const String& describe) {
  auto node = make_object<PluginExternNode>();
  // The arguments are const references, so each assignment is a copy.
  node->name = name;
  node->header = header;
  node->source = source;
  node->lib = lib;
  node->describe = describe;
  data_ = std::move(node);
}
/*!
 * \brief Build a PluginExtern from an already-parsed JSON structure.
 * \param j_extern The JSON structure whose fields are loaded into a new node.
 */
PluginExtern::PluginExtern(const JsonPluginExtern& j_extern) {
  auto node = make_object<PluginExternNode>();
  node->FromJson(j_extern);
  data_ = std::move(node);
}
/*!
 * \brief Build a PluginExtern from JSON text.
 * \param json_str The JSON text, forwarded to PluginExternNode::FromJson.
 */
PluginExtern::PluginExtern(const std::string& json_str) {
  auto node = make_object<PluginExternNode>();
  node->FromJson(json_str);
  data_ = std::move(node);
}
/*!
 * \brief Export this extern description as a JsonPluginExtern.
 * \return A structure holding copies of name, header, source, lib and describe.
 */
const JsonPluginExtern PluginExternNode::ToJson() const {
  JsonPluginExtern exported;
  exported.name = name;
  exported.header = header;
  exported.source = source;
  exported.lib = lib;
  exported.describe = describe;
  return exported;
}
void PluginExternNode::FromJson(const JsonPluginExtern& j_extern) {
name = j_extern.name;
header = j_extern.header;
source = j_extern.source;
lib = j_extern.lib;
describe = j_extern.describe;
}
void PluginExternNode::FromJson(const std::string& json_str) {
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JsonPluginExtern j_extern;
reader.Read(&j_extern);
FromJson(j_extern);
}
/*!
 * \brief Build a Plugin from explicit field values.
 *
 * A new PluginNode is allocated and every argument is stored in the node
 * field of the same name.
 *
 * \param name Value for the node's `name` field.
 * \param version Value for the node's `version` field.
 * \param describe Value for the node's `describe` field.
 * \param attrs Value for the node's `attrs` field.
 * \param inputs Value for the node's `inputs` field.
 * \param outputs Value for the node's `outputs` field.
 * \param buffers Value for the node's `buffers` field.
 * \param externs Value for the node's `externs` field.
 * \param support_dtypes Value for the node's `support_dtypes` field.
 * \param options Value for the node's `options` field.
 */
Plugin::Plugin(const String& name, const String& version, const String& describe,
               const Array<PluginAttr>& attrs, const Array<PluginTensor>& inputs,
               const Array<PluginTensor>& outputs, const Array<PluginTensor>& buffers,
               const Map<String, PluginExtern>& externs,
               const Map<String, Array<String>>& support_dtypes,
               const Map<String, String>& options) {
  auto node = make_object<PluginNode>();
  // The arguments are const references, so each assignment is a copy.
  node->name = name;
  node->version = version;
  node->describe = describe;
  node->attrs = attrs;
  node->inputs = inputs;
  node->outputs = outputs;
  node->buffers = buffers;
  node->externs = externs;
  node->support_dtypes = support_dtypes;
  node->options = options;
  data_ = std::move(node);
}
/*!
 * \brief Build a Plugin from an already-parsed JSON structure.
 * \param j_plugin The JSON structure whose fields are loaded into a new node.
 */
Plugin::Plugin(const JsonPlugin& j_plugin) {
  auto node = make_object<PluginNode>();
  node->FromJson(j_plugin);
  data_ = std::move(node);
}
/*!
 * \brief Build a Plugin from JSON text.
 * \param json_str The JSON text, forwarded to PluginNode::FromJson.
 */
Plugin::Plugin(const std::string& json_str) {
  auto node = make_object<PluginNode>();
  node->FromJson(json_str);
  data_ = std::move(node);
}
/*!
 * \brief Export this plugin description as a JsonPlugin.
 *
 * Scalar fields are copied directly. Attributes, tensors and externs are
 * converted through their own ToJson methods, and the dtype/option maps are
 * copied entry by entry.
 *
 * \return The JSON structure mirroring this node.
 */
const JsonPlugin PluginNode::ToJson() const {
  JsonPlugin j_plugin;
  j_plugin.name = name;
  j_plugin.version = version;
  j_plugin.describe = describe;
  for (const auto& a : attrs) {
    j_plugin.attrs.push_back(a->ToJson());
  }
  for (const auto& t : inputs) {
    j_plugin.inputs.push_back(t->ToJson());
  }
  // Outputs and buffers must land in their own lists: FromJson restores them
  // from j_plugin.outputs / j_plugin.buffers, so appending them to
  // j_plugin.inputs would lose them (and inflate the inputs) on a round trip.
  for (const auto& t : outputs) {
    j_plugin.outputs.push_back(t->ToJson());
  }
  for (const auto& t : buffers) {
    j_plugin.buffers.push_back(t->ToJson());
  }
  for (const auto& pair : externs) {
    j_plugin.externs[pair.first] = pair.second->ToJson();
  }
  for (const auto& pair : support_dtypes) {
    // Convert the Array<String> into a plain vector of std::string.
    std::vector<std::string> dtypes;
    for (const auto& d : pair.second) {
      dtypes.push_back(d);
    }
    j_plugin.support_dtypes[pair.first] = dtypes;
  }
  for (const auto& pair : options) {
    j_plugin.options[pair.first] = pair.second;
  }
  return j_plugin;
}
void PluginNode::FromJson(const JsonPlugin& j_plugin) {
name = j_plugin.name;
version = j_plugin.version;
describe = j_plugin.describe;
for (const auto& a : j_plugin.attrs) {
attrs.push_back(PluginAttr(a));
}
for (const auto& t : j_plugin.inputs) {
inputs.push_back(PluginTensor(t));
}
for (const auto& t : j_plugin.outputs) {
outputs.push_back(PluginTensor(t));
}
for (const auto& t : j_plugin.buffers) {
buffers.push_back(PluginTensor(t));
}
for (const auto& pair : j_plugin.externs) {
externs.Set(pair.first, PluginExtern(pair.second));
}
for (const auto& pair : j_plugin.support_dtypes) {
Array<String> dtypes;
for (const auto& d : pair.second) {
dtypes.push_back(d);
}
support_dtypes.Set(pair.first, dtypes);
}
for (const auto& pair : j_plugin.options) {
options.Set(pair.first, pair.second);
}
}
void PluginNode::FromJson(const std::string& json_str) {
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JsonPlugin j_plugin;
reader.Read(&j_plugin);
FromJson(j_plugin);
}
/*!
 * \brief Find the first input whose dtype equals the dtype of the given tensor.
 * \param tensor The tensor whose dtype is looked up.
 * \return The index of the first matching input, or -1 when no input matches.
 */
int PluginNode::FindDtypeRefIdx(const PluginTensor& tensor) const {
  const size_t input_count = inputs.size();
  for (size_t idx = 0; idx < input_count; ++idx) {
    const bool same_dtype = inputs[idx]->dtype == tensor->dtype;
    if (same_dtype) {
      return idx;
    }
  }
  return -1;
}
/*!
 * \brief Find the first input whose device equals the device of the given tensor.
 * \param tensor The tensor whose device is looked up.
 * \return The index of the first matching input, or -1 when no input matches.
 */
int PluginNode::FindDeviceRefIdx(const PluginTensor& tensor) const {
  const size_t input_count = inputs.size();
  for (size_t idx = 0; idx < input_count; ++idx) {
    const bool same_device = inputs[idx]->device == tensor->device;
    if (same_device) {
      return idx;
    }
  }
  return -1;
}
/*!
 * \brief List plugin names by forwarding to the global PluginRegistry.
 * \return The result of PluginRegistry::ListAllNames.
 */
const Array<String> ListPluginNames() {
  auto registry = PluginRegistry::Global();
  return registry->ListAllNames();
}
/*!
 * \brief Look up a plugin by name in the global PluginRegistry.
 * \param name The name passed to PluginRegistry::Get.
 * \return The plugin returned by the registry.
 */
const Plugin GetPlugin(const String& name) {
  auto registry = PluginRegistry::Global();
  return registry->Get(name);
}
/*!
 * \brief Ask the global PluginRegistry whether a name is registered.
 * \param name The name passed to PluginRegistry::Registered.
 * \return The registry's answer.
 */
bool IsPlugin(const String& name) {
  auto registry = PluginRegistry::Global();
  return registry->Registered(name);
}
/*! \brief Global function that registers a plugin from its name and JSON text. */
TVM_REGISTER_GLOBAL("msc.core.RegisterPlugin")
    .set_body_typed([](const String& plugin_name, const String& plugin_json) {
      auto registry = PluginRegistry::Global();
      registry->Register(plugin_name, plugin_json);
    });
/*! \brief Global function wrapping ListPluginNames. */
TVM_REGISTER_GLOBAL("msc.core.ListPluginNames").set_body_typed([]() -> Array<String> {
  Array<String> plugin_names = ListPluginNames();
  return plugin_names;
});
/*! \brief Global function wrapping GetPlugin. */
TVM_REGISTER_GLOBAL("msc.core.GetPlugin")
    .set_body_typed([](const String& plugin_name) -> Plugin {
      return GetPlugin(plugin_name);
    });
/*! \brief Global function wrapping IsPlugin; the result is boxed as a Bool. */
TVM_REGISTER_GLOBAL("msc.core.IsPlugin")
    .set_body_typed([](const String& plugin_name) -> Bool {
      const bool registered = IsPlugin(plugin_name);
      return Bool(registered);
    });
} // namespace msc
} // namespace contrib
} // namespace tvm