in Source/ComputationNetworkLib/ComputationNetwork.cpp [1226:1574]
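// Exports the network in the legacy DBN.exe binary model format. The layout
// written below is, in order:
//   "DBN" name string; "BDBN" tag; int version (0); int layer count
//   "gmean"/"gstddev" matrices: feature mean and stddev, one context frame's worth
//   "BNET" tag; then per layer: a layer-kind string, transposed weight matrix "W",
//     bias vector "a", and an empty legacy vector "b"
//   "ENET" tag; prior vector "Pu"; "EDBN" tag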
template <class ElemType>
void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wstring& fileName) const
{
    // Helper methods
    // True iff every node in 'nodes' has operation name 'typeValue'.
    auto VerifyTypeAll = [](const std::vector<ComputationNodeBasePtr>& nodes, const std::wstring& typeValue) -> bool
    {
        return std::find_if(nodes.begin(), nodes.end(), [&typeValue](ComputationNodeBasePtr node) -> bool { return node->OperationName() != typeValue; }) == nodes.end();
    };
    // All nodes in the network that consume 'node' as one of their inputs.
    auto GetNodeConsumers = [&net](const ComputationNodeBasePtr node) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> consumers;
        for (auto& item : net->GetAllNodes())
        {
            for (auto& input : item->GetInputs())
            {
                if (input == node)
                {
                    consumers.push_back(item);
                    break;
                }
            }
        }
        return consumers;
    };
    // First element of 'list' that is not 'node', or nullptr if there is none.
    auto GetFirstDifferentNode = [](const std::vector<ComputationNodeBasePtr>& list, const ComputationNodeBasePtr node) -> ComputationNodeBasePtr
    {
        auto foundNode = std::find_if(list.begin(), list.end(), [&node](ComputationNodeBasePtr item) -> bool { return item != node; });
        return foundNode == list.end() ? nullptr : *foundNode;
    };
    // First element of 'list' whose operation name differs from 'type', or nullptr if there is none.
    auto GetFirstNodeWithDifferentType = [](const std::vector<ComputationNodeBasePtr>& list, const std::wstring& type) -> ComputationNodeBasePtr
    {
        auto foundNode = std::find_if(list.begin(), list.end(), [&type](ComputationNodeBasePtr item) -> bool { return item->OperationName() != type; });
        return foundNode == list.end() ? nullptr : *foundNode;
    };
    // All elements of 'nodes' satisfying 'predicate'.
    auto WhereNode = [](const std::vector<ComputationNodeBasePtr>& nodes, const function<bool(ComputationNodeBasePtr)>& predicate) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> results;
        for (auto& node : nodes)
        {
            if (predicate(node))
            {
                results.push_back(node);
            }
        }
        return results;
    };
    // All elements of 'list' with operation name 'type'.
    auto GetNodesWithType = [](const std::vector<ComputationNodeBasePtr>& list, const std::wstring& type) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> results;
        for (auto& node : list)
        {
            if (node->OperationName() == type)
            {
                results.push_back(node);
            }
        }
        return results;
    };
    // Predicate matching LearnableParameter nodes whose name contains "prior" (case-insensitive).
    auto GetAllPriorNodes = [](ComputationNodeBasePtr node) -> bool
    {
        std::wstring lowerName = node->GetName();
        std::transform(lowerName.begin(), lowerName.end(), lowerName.begin(), [](wchar_t v) { return (wchar_t) ::towlower(v); });
        return node->OperationName() == OperationNameOf(LearnableParameter) && (lowerName.find(L"prior") != wstring::npos);
    };
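    // Feature mean/stddev vectors in these models are often replicated across a
    // window of context frames, i.e. the stored vector is one frame's vector
    // repeated ctx = 2 * i + 1 times. FindReplicationContext detects the largest
    // such odd replication factor (up to 51) so that only one frame's worth of
    // normalization values needs to be written out. For example, 40-dimensional
    // features with an 11-frame context yield a 440-element vector made of 11
    // identical 40-element blocks, so ctx == 11.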
    // Returns the factor by which 'arr' is a whole-number replication of its leading block, or 1 if it is not replicated.
    auto FindReplicationContext = [](std::vector<ElemType>& arr) -> int
    {
        for (int i = 25; i >= 1; i--)
        {
            int ctx = i * 2 + 1; // test odd context sizes 51, 49, ..., 3
            if (arr.size() % ctx != 0)
                continue;

            int baseLen = (int) (arr.size() / ctx);
            bool matched = true;
            // all ctx blocks must be identical to the first block
            for (int k = 1; k < ctx && matched; k++)
            {
                for (int j = 0; j < baseLen; j++)
                {
                    if (arr[j] != arr[k * baseLen + j])
                    {
                        matched = false;
                        break;
                    }
                }
            }

            if (matched)
                return ctx;
        }

        return 1;
    };
    // Get output node
    std::list<ComputationNodeBasePtr> outputNodes = net->GetNodesWithType(OperationNameOf(ClassificationErrorNode));
    if (outputNodes.empty())
    {
        RuntimeError("Model does not contain a node with the '%ls' operation.", OperationNameOf(ClassificationErrorNode).c_str());
    }

    ComputationNodeBasePtr outputNode = GetFirstNodeWithDifferentType(outputNodes.front()->GetInputs(), OperationNameOf(InputValue));
    if (outputNode == nullptr)
    {
        RuntimeError("Cannot find output node");
    }
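    // Collect every node reachable from the output with an explicit DFS stack,
    // recording each node as it is popped; inputs that already appear in the
    // list are not pushed again. Reversing the list afterwards puts inputs
    // before their consumers, so earlier layers are visited first below.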
    std::list<ComputationNodeBasePtr> orderList;
    std::stack<ComputationNodeBasePtr> nodeStack;
    nodeStack.push(outputNode);
    while (!nodeStack.empty())
    {
        auto node = nodeStack.top();
        nodeStack.pop();
        auto nodeInputs = node->GetInputs();
        for (auto& input : nodeInputs)
        {
            bool cyclic = false;
            for (auto& item : orderList)
            {
                if (item == input)
                {
                    Warning("Cyclic dependency on node '%ls'\n", item->GetName().c_str());
                    cyclic = true;
                }
            }

            if (!cyclic)
                nodeStack.push(input);
        }
        orderList.push_back(node);
    }

    orderList.reverse();
    // All multiplication nodes that multiply a symbolic variable (i.e. not every input is a LearnableParameter)
    std::list<ComputationNodeBasePtr> multNodes;
    typedef shared_ptr<DbnLayer> DbnLayerPtr;
    std::list<DbnLayerPtr> dbnLayers;
    for (auto& item : orderList)
    {
        if (item->OperationName() == OperationNameOf(TimesNode) && !VerifyTypeAll(item->GetInputs(), OperationNameOf(LearnableParameter)))
        {
            multNodes.push_back(item);
        }
    }
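    // Map each multiplication onto a DbnLayer. The pattern expected per layer is
    //   Plus(Times(W, x), bias) [-> Sigmoid]
    // i.e. the times node feeds exactly one plus node carrying the bias, which in
    // turn may feed a sigmoid that marks the layer as sigmoided.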
    for (auto& node : multNodes)
    {
        std::vector<ComputationNodeBasePtr> consumers = GetNodeConsumers(node);
        if (consumers.size() == 1)
        {
            bool sigmoided = false;
            std::wstring layerId(node->GetName());

            ComputationNodeBasePtr firstConsumer = consumers.front();
            if (firstConsumer->OperationName() != OperationNameOf(PlusNode))
            {
                RuntimeError("Expected a plus node to consume the times node.");
            }

            ComputationNodeBasePtr bias = GetFirstDifferentNode(firstConsumer->GetInputs(), node);
            auto consumer2 = GetNodeConsumers(consumers.front()).front();
            if (consumer2->OperationName() == OperationNameOf(SigmoidNode))
            {
                sigmoided = true;
                layerId = consumer2->GetName();
            }
            else
            {
                layerId = firstConsumer->GetName();
            }

            // If one of its inputs was itself a multiplication node, then split it out
            // into dbn-style: (A * B) * v + bias becomes a bias-free linear layer for B
            // followed by a (possibly sigmoided) layer for A carrying the bias.
            std::vector<ComputationNodeBasePtr> aggTimes = GetNodesWithType(node->GetInputs(), OperationNameOf(TimesNode));
            if (!aggTimes.empty())
            {
                ComputationNodeBasePtr multNode = aggTimes.front();
                DbnLayerPtr l1 = make_shared<DbnLayer>();
                DbnLayerPtr l2 = make_shared<DbnLayer>();

                auto firstInput = multNode->GetInputs()[0];
                auto secondInput = multNode->GetInputs()[1];
                l2->Bias = bias;
                l2->Node = firstInput;
                l1->Bias = nullptr;
                l1->Node = secondInput;
                l1->Sigmoided = false;
                l2->Sigmoided = sigmoided;
                dbnLayers.push_back(l1);
                dbnLayers.push_back(l2);
            }
            else
            {
                auto paramNode = GetNodesWithType(node->GetInputs(), OperationNameOf(LearnableParameter)).front();
                DbnLayerPtr l1 = make_shared<DbnLayer>();
                l1->Bias = bias;
                l1->Node = paramNode;
                l1->Sigmoided = sigmoided;
                dbnLayers.push_back(l1);
            }
        }
    }
    // Write the layers to the output
    // DBN wants stddev rather than invstd, so each element is inverted below
    std::vector<ComputationNodeBasePtr> normalizationNodes = GetNodesWithType(net->GetAllNodes(), OperationNameOf(PerDimMeanVarNormalizationNode));
    if (normalizationNodes.empty())
    {
        RuntimeError("Model does not contain at least one node with the '%ls' operation.", OperationNameOf(PerDimMeanVarNormalizationNode).c_str());
    }

    // Inputs of PerDimMeanVarNormalization are (feature, mean, invStdDev).
    ComputationNodeBasePtr meanNode = normalizationNodes.front()->GetInputs()[1];
    ComputationNodeBasePtr invStdDevNode = normalizationNodes.front()->GetInputs()[2];

    Matrix<ElemType> meanNodeMatrix = meanNode->As<ComputationNode<ElemType>>()->Value().DeepClone();
    Matrix<ElemType> stdDevMatrix(std::move(invStdDevNode->As<ComputationNode<ElemType>>()->Value().DeepClone().ElementInverse()));

    std::vector<ElemType> arr(stdDevMatrix.GetNumElements());
    ElemType* refArr = &arr[0];
    size_t arrSize = arr.size();
    stdDevMatrix.CopyToArray(refArr, arrSize);

    int ctx = FindReplicationContext(arr);
    std::vector<ComputationNodeBasePtr> priorNodes = WhereNode(net->GetAllNodes(), GetAllPriorNodes);
    if (priorNodes.size() != 1)
    {
        Warning("Could not reliably determine the prior node!");
    }
    // =================
    // Write to the file
    // =================
    File fstream(fileName, FileOptions::fileOptionsBinary | FileOptions::fileOptionsWrite);

    // local helper functions for writing stuff in DBN.exe-expected format
    auto PutTag = [&fstream](const char* tag) { while (*tag) fstream << *tag++; };
    auto PutString = [&fstream](const char* string) { fstream.WriteString(string, 0); };
    auto PutInt = [&fstream](int val) { fstream << val; };

    // write a DBN matrix object, optionally applying a function
    auto PutMatrixConverted = [&](const Matrix<ElemType>* m, size_t maxelem, const char* name, float (*f)(float))
    {
        PutTag("BMAT");
        PutString(name);
        size_t numRows = m->GetNumRows();
        size_t numCols = m->GetNumCols();
        if (maxelem == SIZE_MAX)
        {
            PutInt((int) numRows);
            PutInt((int) numCols);
        }
        else // allows shortening a vector, as needed for the mean/stddev vectors
        {
            PutInt((int) maxelem);
            PutInt(1);
        }
        // this code transposes the matrix on the fly, and outputs at most maxelem floating-point numbers to the stream
        size_t k = 0;
        for (size_t j = 0; j < numCols && k < maxelem; j++)
            for (size_t i = 0; i < numRows && k < maxelem; i++, k++)
                fstream << f((float) (*m)(i, j));
        PutTag("EMAT");
    };
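    // A matrix record is therefore: "BMAT" tag, name string, int rows, int cols,
    // rows*cols 32-bit floats in transposed element order, "EMAT" tag.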
    auto PutMatrix = [&](const Matrix<ElemType>* m, const char* name) { PutMatrixConverted(m, SIZE_MAX, name, [](float v) { return v; }); };
    // write out the data
    // Dump DBN header
    PutString("DBN");
    PutTag("BDBN");
    PutInt(0); // a version number
    PutInt(static_cast<int>(dbnLayers.size())); // number of layers

    // Dump feature normalization: emit only one context frame's worth of mean/stddev values
    PutMatrixConverted(&meanNodeMatrix, meanNodeMatrix.GetNumRows() / ctx, "gmean", [](float v) { return v; });
    PutMatrixConverted(&stdDevMatrix, stdDevMatrix.GetNumRows() / ctx, "gstddev", [](float v) { return v; });

    PutTag("BNET");
    auto lastOne = dbnLayers.end();
    --lastOne;
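    // Each layer record holds the layer-kind string (Gaussian-Bernoulli for the
    // input layer, perceptron for the last, Bernoulli-Bernoulli or linear-Bernoulli
    // for hidden layers depending on sigmoidedness), followed by the transposed
    // weights "W", the bias "a" (zero-filled if the graph had none), and an
    // empty legacy vector "b".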
    for (auto ii = dbnLayers.begin(), e = dbnLayers.end(); ii != e; ++ii)
    {
        DbnLayerPtr& layer = *ii;
        if (ii == dbnLayers.begin())
        {
            PutString("rbmgaussbernoulli");
        }
        else if (ii == lastOne)
        {
            PutString("perceptron");
        }
        else if (layer->Sigmoided)
        {
            PutString("rbmbernoullibernoulli");
        }
        else
        {
            PutString("rbmisalinearbernoulli");
        }

        // Write out the main weight matrix
        auto weight = (layer->Node->As<ComputationNode<ElemType>>()->Value().DeepClone());
        auto transpose = weight.Transpose();
        PutMatrix(&transpose, "W");

        // Write out the bias vector.
        // It is mandatory in the DBN format, so pad with zeroes if not given.
        auto rows = layer->Node->GetAsMatrixNumRows();
        if (layer->Bias == nullptr)
        {
            auto zeros = Matrix<ElemType>::Zeros(rows, 1, CPUDEVICE);
            PutMatrixConverted(&zeros, rows, "a", [](float v) { return v; });
        }
        else
        {
            PutMatrixConverted(&(layer->Bias->As<ComputationNode<ElemType>>()->Value()), rows, "a", [](float v) { return v; });
        }

        // Legacy vector that is no longer used; write it as an empty matrix.
        auto zeros = Matrix<ElemType>::Zeros(0, 0, CPUDEVICE);
        PutMatrix(&(zeros), "b");
    }
    // Dump the priors
    PutTag("ENET");
    if (!priorNodes.empty())
    {
        PutMatrix(&(priorNodes.front()->As<ComputationNode<ElemType>>()->Value()), "Pu");
    }
    else
    {
        Warning("No prior node(s) found!");
    }
    PutTag("EDBN");
}
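// Note: ElemType does not occur in the parameter list, so the template argument
// cannot be deduced and callers must supply it explicitly. A hypothetical call
// site (illustrative only, not taken from this file):
//
//     net->SaveToDbnFile<float>(net, L"model.dbn");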