void ComputationNetwork::SaveToDbnFile()

in Source/ComputationNetworkLib/ComputationNetwork.cpp [1226:1574]
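
Serializes the network in the legacy DBN.exe binary model format. The routine orders the graph depth-first from the output node, pattern-matches Times -> Plus -> (optional) Sigmoid chains into DBN layers, recovers the feature mean/stddev from the PerDimMeanVarNormalization node, and writes everything out in the tag-delimited layout (BDBN ... EDBN) that DBN.exe expects.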


template <class ElemType>
void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wstring& fileName) const
{
    // Helper methods
    auto VerifyTypeAll = [](const std::vector<ComputationNodeBasePtr>& nodes, const std::wstring& typeValue) -> bool
    {
        return std::all_of(nodes.begin(), nodes.end(), [&typeValue](ComputationNodeBasePtr node)->bool { return node->OperationName() == typeValue; });
    };
    auto GetNodeConsumers = [&net](const ComputationNodeBasePtr node) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> consumers;
        for (auto& item : net->GetAllNodes())
        {
            for (auto& input : item->GetInputs())
            {
                if (input == node)
                {
                    consumers.push_back(item);
                    break;
                }
            }
        }

        return consumers;
    };
    auto GetFirstDifferentNode = [](const std::vector<ComputationNodeBasePtr>& list, const ComputationNodeBasePtr node) -> ComputationNodeBasePtr
    {
        auto foundNode = std::find_if(list.begin(), list.end(), [&node](ComputationNodeBasePtr item)->bool { return item != node; });
        return foundNode == list.end() ? nullptr : *foundNode;
    };
    auto GetFirstNodeWithDifferentType = [](const std::vector<ComputationNodeBasePtr>& list, const std::wstring& type) -> ComputationNodeBasePtr
    {
        auto foundNode = std::find_if(list.begin(), list.end(), [&type](ComputationNodeBasePtr item)->bool { return item->OperationName() != type; });
        return foundNode == list.end() ? nullptr : *foundNode;
    };
    auto WhereNode = [](const std::vector<ComputationNodeBasePtr>& nodes, const function<bool(ComputationNodeBasePtr)>& predicate) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> results;

        for (auto& node : nodes)
        {
            if (predicate(node))
            {
                results.push_back(node);
            }
        }

        return results;
    };
    auto GetNodesWithType = [](const std::vector<ComputationNodeBasePtr>& list, const std::wstring& type) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> results;

        for (auto& node : list)
        {
            if (node->OperationName() == type)
            {
                results.push_back(node);
            }
        }

        return results;
    };
    auto GetAllPriorNodes = [](ComputationNodeBasePtr node)->bool
    {
        std::wstring lowerName = node->GetName();
        std::transform(lowerName.begin(), lowerName.end(), lowerName.begin(), [](wchar_t v) { return (wchar_t)::towlower(v); });

        return node->OperationName() == OperationNameOf(LearnableParameter) && (lowerName.find(L"prior") != wstring::npos);
    };
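    // The stored mean/invstd vectors may be a base vector tiled over an odd
    // replication context (presumably 2*w + 1 input frames for a symmetric
    // window of up to +/-25 frames). Detect the largest such factor so the
    // vectors can be truncated back to a single frame; 1 means no tiling.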
    auto FindReplicationContext = [](std::vector<ElemType>& arr)->int
    {
        for (int i = 25; i >= 1; i--)
        {
            int ctx = i * 2 + 1;
            if (arr.size() % ctx != 0)
                continue;

            int baseLen = (int)(arr.size() / ctx);
            bool matched = true;

            for (int k = 1; k < ctx && matched; k++)
            {
                for (int j = 0; j < baseLen; j++)
                {
                    if (arr[j] != arr[k * baseLen + j])
                    {
                        matched = false;
                        break;
                    }
                }
            }

            if (matched)
                return ctx;
        }

        return 1;
    };

    // Get output node
    std::list<ComputationNodeBasePtr> outputNodes = net->GetNodesWithType(OperationNameOf(ClassificationErrorNode));
    if (outputNodes.empty())
    {
        RuntimeError("Model does not contain a node with the '%ls' operation.", OperationNameOf(ClassificationErrorNode).c_str());
    }

    ComputationNodeBasePtr outputNode = GetFirstNodeWithDifferentType(outputNodes.front()->GetInputs(), OperationNameOf(InputValue));

    if (outputNode == nullptr)
    {
        RuntimeError("Cannot find output node");
    }

    std::list<ComputationNodeBasePtr> orderList;
    std::stack<ComputationNodeBasePtr> nodeStack;

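    // Depth-first traversal from the output node; after the final reversal,
    // orderList enumerates every node before the nodes that consume it.
    // A node reached a second time is skipped (and reported as cyclic).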
    nodeStack.push(outputNode);

    while (!nodeStack.empty())
    {
        auto node = nodeStack.top();
        nodeStack.pop();
        auto nodeInputs = node->GetInputs();
        for (auto& input : nodeInputs)
        {
            bool cyclic = false;
            for (auto& item : orderList)
            {
                if (item == input)
                {
                    Warning("Cyclic dependency on node '%ls'\n", item->GetName().c_str());
                    cyclic = true;
                }
            }

            if (!cyclic)
                nodeStack.push(input);
        }
        orderList.push_back(node);
    }

    orderList.reverse();

    // All multiplication nodes that multiply a symbolic variable
    std::list<ComputationNodeBasePtr> multNodes;
    typedef shared_ptr<DbnLayer> DbnLayerPtr;
    std::list<DbnLayerPtr> dbnLayers;

    for (auto& item : orderList)
    {
        if (item->OperationName() == OperationNameOf(TimesNode) && !VerifyTypeAll(item->GetInputs(), OperationNameOf(LearnableParameter)))
        {
            multNodes.push_back(item);
        }
    }

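    // Map each multiplication onto a DBN layer. The expected pattern is
    // Times (weight) -> Plus (bias) -> optional Sigmoid; only Times nodes
    // with exactly one consumer are converted.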
    for (auto& node : multNodes)
    {
        std::vector<ComputationNodeBasePtr> consumers = GetNodeConsumers(node);
        if (consumers.size() == 1)
        {
            bool sigmoided = false;
            std::wstring layerId(node->GetName());

            ComputationNodeBasePtr firstConsumer = consumers.front();

            if (firstConsumer->OperationName() != OperationNameOf(PlusNode))
            {
                RuntimeError("Expected a plus node to consume the times node.");
            }

            ComputationNodeBasePtr bias = GetFirstDifferentNode(firstConsumer->GetInputs(), node);

            auto consumer2 = GetNodeConsumers(firstConsumer).front();
            if (consumer2->OperationName() == L"Sigmoid")
            {
                sigmoided = true;
                layerId = consumer2->GetName();
            }
            else
            {
                layerId = firstConsumer->GetName();
            }

            // If one of its inputs was itself a multiplication node (i.e. a factored
            // weight), split it out into two DBN-style layers.
            std::vector<ComputationNodeBasePtr> aggTimes = GetNodesWithType(node->GetInputs(), OperationNameOf(TimesNode));
            if (aggTimes.size() > 0)
            {
                ComputationNodeBasePtr multNode = aggTimes.front();
                DbnLayerPtr l1 = make_shared<DbnLayer>();
                DbnLayerPtr l2 = make_shared<DbnLayer>();

                auto firstInput = multNode->GetInputs()[0];
                auto secondInput = multNode->GetInputs()[1];
                l2->Bias = bias;
                l2->Node = firstInput;

                l1->Bias = nullptr;
                l1->Node = secondInput;

                l1->Sigmoided = false;
                l2->Sigmoided = sigmoided;

                dbnLayers.push_back(l1);
                dbnLayers.push_back(l2);
            }
            else
            {
                auto paramNode = GetNodesWithType(node->GetInputs(), OperationNameOf(LearnableParameter)).front();
                DbnLayerPtr l1 = make_shared<DbnLayer>();
                l1->Bias = bias;
                l1->Node = paramNode;
                l1->Sigmoided = sigmoided;

                dbnLayers.push_back(l1);
            }
        }
    }

    // Write the layers to the output 
    // DBN wants std not invstd, so need to invert each element
    std::vector<ComputationNodeBasePtr> normalizationNodes = GetNodesWithType(net->GetAllNodes(), OperationNameOf(PerDimMeanVarNormalizationNode));
    if (normalizationNodes.size() == 0)
    {
        RuntimeError("Model does not contain at least one node with the '%ls' operation.", OperationNameOf(PerDimMeanVarNormalizationNode).c_str());
    }

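    // Input 1 of PerDimMeanVarNormalization is the mean vector and input 2 the
    // inverse standard deviation; element-inverting the latter yields the
    // stddev the DBN format stores. (Note that invStdNodeMatrix below therefore
    // holds the standard deviation itself, despite its name.)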
    ComputationNodeBasePtr meanNode = normalizationNodes.front()->GetInputs()[1];
    ComputationNodeBasePtr stdNode = normalizationNodes.front()->GetInputs()[2];
    
    Matrix<ElemType> meanNodeMatrix = meanNode->As<ComputationNode<ElemType>>()->Value().DeepClone();
    Matrix<ElemType> invStdNodeMatrix(std::move(stdNode->As<ComputationNode<ElemType>>()->Value().DeepClone().ElementInverse()));
    std::vector<ElemType> arr(invStdNodeMatrix.GetNumElements());
    ElemType* refArr = &arr[0];
    size_t arrSize = arr.size();
    invStdNodeMatrix.CopyToArray(refArr, arrSize);

    int ctx = FindReplicationContext(arr);
    std::vector<ComputationNodeBasePtr> priorNodes = WhereNode(net->GetAllNodes(), GetAllPriorNodes);
    if (priorNodes.size() != 1)
    {
        Warning("Could not reliably determine the prior node!");
    }

    // =================
    // Write to the file
    // =================
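    // Layout written below:
    //   "DBN" BDBN version(0) #layers gmean gstddev BNET
    //   per layer: type string, W (transposed weight), a (bias), b (legacy, empty)
    //   ENET Pu (priors) EDBN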
    File fstream(fileName, FileOptions::fileOptionsBinary | FileOptions::fileOptionsWrite);

    // local helper functions for writing stuff in DBN.exe-expected format
    auto PutTag = [&fstream](const char * tag) { while (*tag) fstream << *tag++; };
    auto PutString = [&fstream](const char * string) { fstream.WriteString(string, 0); };
    auto PutInt = [&fstream](int val) { fstream << val; };

    // write a DBN matrix object, optionally applying a function
    auto PutMatrixConverted = [&](const Matrix<ElemType> * m, size_t maxelem, const char * name, float(*f)(float))
    {
        PutTag("BMAT");
        PutString(name);
        size_t numRows = m->GetNumRows();
        size_t numCols = m->GetNumCols();

        if (maxelem == SIZE_MAX)
        {
            PutInt((int)numRows);
            PutInt((int)numCols);
        }
        else    // this allows shortening a vector, as needed for mean/invstd
        {
            PutInt((int)maxelem);
            PutInt(1);
        }

        // this code transposes the matrix on the fly, and outputs at most maxelem floating point numbers to the stream
        size_t k = 0;
        for (size_t j = 0; j < numCols && k < maxelem; j++)
            for (size_t i = 0; i < numRows && k < maxelem; i++, k++)
                fstream << f((float)(*m)(i, j));

        PutTag("EMAT");
    };
    auto PutMatrix = [&](const Matrix<ElemType> * m, const char * name) { PutMatrixConverted(m, SIZE_MAX, name, [](float v) { return v; }); };

    // write out the data
    // Dump DBN header
    PutString("DBN");
    PutTag("BDBN");
    PutInt(0);                                                                              // a version number
    PutInt(static_cast<int>(dbnLayers.size()));                                             // number of layers

    // Dump feature norm
    PutMatrixConverted(&meanNodeMatrix, meanNodeMatrix.GetNumRows() / ctx, "gmean", [](float v) { return v; });
    PutMatrixConverted(&invStdNodeMatrix, invStdNodeMatrix.GetNumRows() / ctx, "gstddev", [](float v) { return v; });

    PutTag("BNET");
    auto lastOne = dbnLayers.end();
    --lastOne;
    for (auto ii = dbnLayers.begin(), e = dbnLayers.end(); ii != e; ++ii)
    {
        DbnLayerPtr& layer = *ii;

        if (ii == dbnLayers.begin())
        {
            PutString("rbmgaussbernoulli");
        }
        else if (ii == lastOne)
        {
            PutString("perceptron");
        }
        else if (layer->Sigmoided)
        {
            PutString("rbmbernoullibernoulli");
        }
        else
        {
            PutString("rbmisalinearbernoulli");
        }

        // Write out the main weight matrix
        auto weight = (layer->Node->As<ComputationNode<ElemType>>()->Value().DeepClone());
        auto transpose = weight.Transpose();
        PutMatrix(&transpose, "W");

        // Write out the bias vector.
        // It is mandatory, so pad with zeroes if not given.
        auto rows = layer->Node->GetAsMatrixNumRows();
        if (layer->Bias == nullptr)
        {
            auto zeros = Matrix<ElemType>::Zeros(rows, 1, CPUDEVICE);
            PutMatrixConverted(&zeros, rows, "a", [](float v) { return v; });
        }
        else
        {
            PutMatrixConverted(&(layer->Bias->As<ComputationNode<ElemType>>()->Value()), rows, "a", [](float v) { return v; });
        }

        // Legacy vector that the format requires but that is not used; write it out empty
        auto zeros = Matrix<ElemType>::Zeros(0, 0, CPUDEVICE);
        PutMatrix(&(zeros), "b");
    }

    // Dump the priors
    PutTag("ENET");
    if (priorNodes.size() > 0)
    {
        PutMatrix(&(priorNodes.front()->As<ComputationNode<ElemType>>()->Value()), "Pu");
    }
    else
    {
        Warning("No prior node(s) found!");
    }
    PutTag("EDBN");
}
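
The replication-context search above is self-contained enough to exercise in isolation. Below is a minimal standalone sketch of the same logic, assuming float elements; the test harness and its names are illustrative and not part of the CNTK sources:

#include <cstdio>
#include <vector>

// Same search as the FindReplicationContext lambda above: find the largest odd
// factor ctx = 2*i + 1 (i in [1, 25]) such that arr is a base vector tiled
// ctx times; return 1 if the vector is not replicated.
static int FindReplicationContext(const std::vector<float>& arr)
{
    for (int i = 25; i >= 1; i--)
    {
        int ctx = i * 2 + 1;
        if (arr.size() % ctx != 0)
            continue;

        int baseLen = (int)(arr.size() / ctx);
        bool matched = true;

        for (int k = 1; k < ctx && matched; k++)
        {
            for (int j = 0; j < baseLen; j++)
            {
                if (arr[j] != arr[k * baseLen + j])
                {
                    matched = false;
                    break;
                }
            }
        }

        if (matched)
            return ctx;
    }

    return 1;
}

int main()
{
    std::vector<float> base = { 0.5f, 1.0f, 2.0f };
    std::vector<float> tiled;
    for (int k = 0; k < 3; k++) // tile the base vector over a 3-frame context
        tiled.insert(tiled.end(), base.begin(), base.end());

    printf("ctx = %d\n", FindReplicationContext(tiled)); // prints "ctx = 3"
    return 0;
}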