// CreateRNNConstantHelper — excerpt from
// Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp [864:1328]

// Creates CNTK Constant inputs from one ONNX tensor initializer that feeds a
// recurrent op (LSTM, GRU, or RNN).
//
// ONNX packs all gates and all directions of a recurrent op's weights/biases
// into a single tensor, while CNTK expects one tensor per direction, with a
// different gate ordering (e.g. LSTM: ONNX iofc vs. CNTK icfo) and with the
// input/hidden biases fused into one vector. This helper splits valueProto
// along the num_directions axis and permutes/fuses the gate blocks while
// copying element by element.
//
// parentNode    - the recurrent ONNX node this initializer feeds; only its op
//                 type ("LSTM"/"GRU"/"RNN") is consulted.
// index         - the ONNX input slot of the initializer (X/W/R/B/...), to be
//                 interpreted per the ONNX operator spec of the parent op.
// name          - base name for the generated constants.
// valueProto    - the ONNX tensor holding the packed data.
// computeDevice - device on which the constants are created.
//
// Returns one Constant per direction (two per direction for GRU's R input,
// which is split into H and H1 parts for CNTK's GRU implementation); returns
// an empty vector for the sequence-length input, which is treated as a free
// dimension rather than a constant. Raises LogicError for the X input (a
// recurrent op's data input must not be a constant) and for unknown indices.
std::vector<Variable> CreateRNNConstantHelper(
    const Node *parentNode, int index, const std::string &name, const onnx::TensorProto &valueProto, const DeviceDescriptor &computeDevice)
{
    std::vector<Variable> inputs;
    auto dataType = valueProto.data_type();

    // The tensor payload may arrive in raw_data instead of the typed repeated
    // field; materialize it so the CopyFromProto calls below can read it.
    switch (dataType)
    {
    case TensorProto_DataType_FLOAT:
    {
        if (valueProto.float_data().empty())
        {
            RetrieveRawDataAsFloat(valueProto);
        }
        // BUGFIX: was falling through into the DOUBLE case, which would then
        // attempt to re-read a float tensor's raw payload as doubles.
        break;
    }
    case TensorProto_DataType_DOUBLE:
    {
        if (valueProto.double_data().empty())
        {
            RetrieveRawDataAsDouble(valueProto);
        }
        break;
    }
    }

    string parentONNXOpName = parentNode->OpType();
    // index to LSTM inputs as specified in the ONNX document.
    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
    if (parentONNXOpName == "LSTM")
    {
        switch (index)
        {
        case LSTMInputIndexX:
            // X, should not come to here
            CNTK::LogicError("input to a recurrent node shall not be a constant");
        case LSTMInputIndexW:
        case LSTMInputIndexH:
            // W, R:
            {
                // see ONNX spec for the tensor shape:
                // W/R are [num_directions, 4*hidden_size, input_size]
                int num_directions = valueProto.dims(0);
                size_t rows = valueProto.dims(1);
                size_t cols = valueProto.dims(2);

                // CNTK cpp requires shape being (input_size, 4 * hidden_size)
                NDShape weightShape({rows, cols});

                int input_size = cols;
                int cell_size = rows / 4;

                for (int dir = 0; dir < num_directions; dir++)
                {
                    std::string nodeName = name + (index == 1 ? "_W_" : "_R_") + (char) ('0' + dir);
                    int totalSizePerDirection = rows * cols;

                    // TODO: what about double?
                    // NOTE(review): buffer is handed to CreateConstantWithRawData;
                    // presumably it takes ownership or copies — confirm, else this leaks.
                    DType *data = new DType[totalSizePerDirection];
                    for (size_t count = 0; count < totalSizePerDirection; count++)
                    {
                        int row = count / input_size;
                        int col = count % input_size;
                        int block = row / cell_size;

                        // gate reorder: ONNX iofc -> CNTK icfo (swap blocks 1 and 3)
                        if (block == 1)
                        {
                            // o
                            row += cell_size * 2;
                        }
                        else if (block == 3)
                        {
                            // c
                            row -= cell_size * 2;
                        }

                        // source is row major (ONNX), target is column major (CNTK)
                        int sourceIndex = dir * totalSizePerDirection + count;
                        int targetIndex = col * cell_size * 4 + row;
                        CopyFromProto(valueProto, data, sourceIndex, targetIndex);
                    }

                    Constant constant = CreateConstantWithRawData(&data[0], weightShape, nodeName, computeDevice);
                    inputs.push_back(constant);
                }
                return inputs;
            }
        case LSTMInputIndexB:
            // B
            {
                // see ONNX spec for the tensor shape
                int num_directions = valueProto.dims(0);
                int cell_size = valueProto.dims(1) / 8;
                // there is an ONNX spec issue with bias input. It states that
                // "This tensor has shape `[num_directions, 8*hidden_size]", which means
                // hidden and input are applied with bias separately after weight.
                // In CNTK, bias is be applied in fused form, after hidden and input
                // are element-wise added. In this case
                // the bias shape is [num_directions, 4*hidden_size]
                NDShape weightShape({(size_t)(4 * cell_size)});
                for (int dir = 0; dir < num_directions; dir++)
                {
                    std::string nodeName = name + std::string(1, (char) ('0' + dir)) + LSTMInputBiasNameHint;
                    int totalSizePerDirection = 4 * cell_size;
                    DType *data = new DType[totalSizePerDirection];
                    for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++)
                    {
                        int row = targetIndex;

                        // TODO: specific to LSTM. icfo (CNTK) to iofc(ONNX)
                        int block = row / cell_size;
                        if (block == 1)
                        {
                            // c
                            row += 2 * cell_size;
                        }
                        else if (block == 3)
                        {
                            // o
                            row -= 2 * cell_size;
                        }

                        // source is column major
                        int src_index = row;

                        // "fuse": sum the ONNX input bias (Wb) and hidden bias (Rb)
                        // for this gate element into one CNTK bias element.
                        vector<int> srcIndexRange = {
                            dir * 2 * totalSizePerDirection + src_index,
                            dir * 2 * totalSizePerDirection + totalSizePerDirection + src_index};

                        CopyFromProto(valueProto, data, srcIndexRange, targetIndex);
                    }

                    Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
                    inputs.push_back(constant);
                }
                return inputs;
            }
        case LSTMInputIndexSequenceLens:
            // sequence length is treated as free dimension
            return inputs;
        case LSTMInputIndexinitial_h:
        case LSTMInputIndexinitial_c:
        {
            // initial_h, initial_c: [num_directions, batch_size, hidden_size]
            int num_directions = valueProto.dims(0);

            // TODO: batch shall be one?
            // int batchSize = valueProto.dims(1);
            int cell_size = valueProto.dims(2);
            NDShape weightShape({(size_t)(cell_size)});
            for (int dir = 0; dir < num_directions; dir++)
            {
                std::string nodeName = name + std::string(1, (char) ('0' + dir));
                if (index == 5)
                    nodeName += LSTMInputInitialHNameHint;
                else
                    nodeName += LSTMInputInitialCNameHint;

                DType *data = new DType[cell_size];
                for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
                {
                    CopyFromProto(valueProto, data, dir * cell_size + targetIndex, targetIndex);
                }

                Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
                inputs.push_back(constant);
            }
            return inputs;
        }
        case LSTMInputIndexP:
            // P: peephole weights, [num_directions, 3*hidden_size] packed i, o, f
            {
                int num_directions = valueProto.dims(0);
                int cell_size = valueProto.dims(1) / 3;
                for (int dir = 0; dir < num_directions; dir++)
                    for (int i = 0; i < 3; i++)
                    {
                        std::string nodeName = name + ((i == 0) ? "_i" : ((i == 1) ? "_o" : "_f")) +
                                               std::string(1, (char) ('0' + dir)) + LSTMInputPeepholeNameHint;

                        DType *data = new DType[cell_size];
                        NDShape weightShape({(size_t)(cell_size)});
                        for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
                        {
                            CopyFromProto(valueProto, data, (dir * 3 + i) * cell_size + targetIndex, targetIndex);
                        }

                        Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
                        inputs.push_back(constant);
                    }
                return inputs;
            }
        default:
            CNTK::LogicError("CreateRNNConstant received unexpected index: %d", index);
        }
    }
    else if (parentONNXOpName == "GRU")
    {
        // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6
        switch (index)
        {
        case GRUInputIndexX:
            // X, should not come to here
            CNTK::LogicError("input to a recurrent node shall not be a constant");
        case GRUInputIndexW:
        {
            // see ONNX spec for the tensor shape:
            // W is [num_directions, 3*hidden_size, input_size]
            int num_directions = valueProto.dims(0);
            size_t rows = valueProto.dims(1);
            size_t cols = valueProto.dims(2);

            // CNTK cpp requires shape: (input_size, 3 * hidden_size)
            NDShape weightShape({rows, cols});

            int input_size = cols;
            int cell_size = rows / 3;

            for (int dir = 0; dir < num_directions; dir++)
            {
                std::string nodeName = name + "_W_" + (char) ('0' + dir);
                int totalSizePerDirection = rows * cols;

                // TODO: what about double?
                DType *data = new DType[totalSizePerDirection];
                for (size_t count = 0; count < totalSizePerDirection; count++)
                {
                    // transpose row-major (ONNX) into column-major (CNTK);
                    // gate order is unchanged for GRU W.
                    int row = count / input_size;
                    int col = count % input_size;
                    int sourceIndex = dir * totalSizePerDirection + count;
                    int targetIndex = col * cell_size * GRUWeightDimensionHiddenMultiplier + row;

                    CopyFromProto(valueProto, data, sourceIndex, targetIndex);
                }

                Constant constant = CreateConstantWithRawData(&data[0], weightShape, nodeName, computeDevice);
                inputs.push_back(constant);
            }
            return inputs;
        }
        case GRUInputIndexR:
        {
            // split into H (z and r gates) and H1 (h gate) for CNTK GRU implementation
            int num_directions = valueProto.dims(0);
            size_t rows = valueProto.dims(1);
            size_t cols = valueProto.dims(2);

            int input_size = cols;
            int cell_size = rows / 3;

            NDShape hShape({(size_t) cell_size * 2, (size_t) input_size});
            NDShape h1Shape({(size_t) cell_size, (size_t) input_size});

            // layout of the result: [H_0, ..., H_{n-1}, H1_0, ..., H1_{n-1}]
            inputs.resize(num_directions * 2);
            for (int dir = 0; dir < num_directions; dir++)
            {
                std::string hNodeName = name + "_H_" + (char) ('0' + dir);
                std::string h1NodeName = name + "_H1_" + (char) ('0' + dir);
                int totalSizePerDirection = rows * cols;

                DType *hData = new DType[hShape.TotalSize()];
                DType *h1Data = new DType[h1Shape.TotalSize()];
                for (size_t count = 0; count < totalSizePerDirection; count++)
                {
                    int row = count / input_size;
                    int col = count % input_size;
                    int block = row / cell_size;
                    int sourceIndex = dir * totalSizePerDirection + count;
                    if (block < CNTKGRUZRWeightMultiplier)
                    {
                        // z/r gate rows go to H
                        int targetIndex = col * cell_size * CNTKGRUZRWeightMultiplier + row;

                        CopyFromProto(valueProto, hData, sourceIndex, targetIndex);
                    }
                    else
                    {
                        // h gate rows go to H1 (rebased to start at row 0)
                        int targetIndex = col * cell_size + row - cell_size * CNTKGRUZRWeightMultiplier;

                        CopyFromProto(valueProto, h1Data, sourceIndex, targetIndex);
                    }
                }

                Constant constantH = CreateConstantWithRawData(&hData[0], hShape, hNodeName, computeDevice);
                Constant constantH1 = CreateConstantWithRawData(&h1Data[0], h1Shape, h1NodeName, computeDevice);
                inputs[dir] = constantH;
                inputs[dir + num_directions] = constantH1;
            }
            return inputs;
        }
        case GRUInputIndexB:
            // B
            {
                // see ONNX spec for the tensor shape
                int num_directions = valueProto.dims(0);
                int cell_size = valueProto.dims(1) / GRUBiasDimensionHiddenMultiplier;
                // shape size is divided by 2 so that it only applies to input (CNTK)
                // TODO: this incompatibility needs further investigation.
                NDShape weightShape({(size_t)(GRUBiasDimensionHiddenMultiplier / 2 * cell_size)});
                for (int dir = 0; dir < num_directions; dir++)
                {
                    std::string nodeName = name + std::string(1, '0' + dir) + LSTMInputBiasNameHint;
                    int totalSizePerDirection = GRUBiasDimensionHiddenMultiplier / 2 * cell_size;
                    DType *data = new DType[totalSizePerDirection];
                    for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++)
                    {
                        int row = targetIndex;
                        // source is column major
                        int src_index = row;
                        // "fuse": sum ONNX's separate input/hidden biases into one
                        vector<int> sourceIndexRange = {
                            dir * 2 * totalSizePerDirection + src_index,
                            dir * 2 * totalSizePerDirection + totalSizePerDirection + src_index};

                        CopyFromProto(valueProto, data, sourceIndexRange, targetIndex);
                    }

                    Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
                    inputs.push_back(constant);
                }
                return inputs;
            }
        case GRUInputIndexSequenceLens:
            // sequence length is treated as free dimension
            return inputs;
        case GRUInitialH:
        {
            // initial_h: [num_directions, batch_size, hidden_size]
            int num_directions = valueProto.dims(0);
            int cell_size = valueProto.dims(2);
            NDShape weightShape({(size_t)(cell_size)});
            for (int dir = 0; dir < num_directions; dir++)
            {
                std::string nodeName = name + std::string(1, (char) ('0' + dir)) + LSTMInputInitialHNameHint;

                DType *data = new DType[cell_size];
                for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
                {
                    CopyFromProto(valueProto, data, dir * cell_size + targetIndex, targetIndex);
                }

                Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
                inputs.push_back(constant);
            }
            return inputs;
        }
        default:
            CNTK::LogicError("CreateRNNConstant for GRU op received unexpected index: %d", index);
        }
    }
    else if (parentONNXOpName == "RNN")
    {
        // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6-1
        switch (index)
        {
        case RNNInputIndexX:
            // X, should not come to here
            CNTK::LogicError("input to a recurrent node shall not be a constant");
        case RNNInputIndexW:
        case RNNInputIndexR:
        {
            // see ONNX spec for the tensor shape:
            // W/R are [num_directions, hidden_size, input_size]
            int num_directions = valueProto.dims(0);
            size_t rows = valueProto.dims(1);
            size_t cols = valueProto.dims(2);

            NDShape weightShape({rows, cols});

            int input_size = cols;
            int cell_size = rows;

            for (int dir = 0; dir < num_directions; dir++)
            {
                std::string nodeName = name + (index == RNNInputIndexW ? "_W_" : "_R_") + (char) ('0' + dir);
                int totalSizePerDirection = rows * cols;

                // TODO: what about double?
                DType *data = new DType[totalSizePerDirection];
                for (size_t count = 0; count < totalSizePerDirection; count++)
                {
                    // transpose row-major (ONNX) into column-major (CNTK);
                    // plain RNN has a single gate, so no block permutation.
                    int row = count / input_size;
                    int col = count % input_size;
                    int sourceIndex = dir * totalSizePerDirection + count;
                    int targetIndex = col * cell_size + row;

                    CopyFromProto(valueProto, data, sourceIndex, targetIndex);
                }

                Constant constant = CreateConstantWithRawData(&data[0], weightShape, nodeName, computeDevice);
                inputs.push_back(constant);
            }
            return inputs;
        }
        case RNNInputIndexB:
            // B
            {
                // see ONNX spec for the tensor shape:
                // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6-1
                // shape of bias is [num_directions, 2*hidden_size] thus we divide dim(1) by 2
                // to get cell_size.
                int num_directions = valueProto.dims(0);
                int cell_size = valueProto.dims(1) / 2;
                NDShape weightShape({(size_t)(cell_size)});
                for (int dir = 0; dir < num_directions; dir++)
                {
                    std::string nodeName = name + std::string(1, '0' + dir) + LSTMInputBiasNameHint;
                    int totalSizePerDirection = cell_size;
                    DType *data = new DType[totalSizePerDirection];
                    for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++)
                    {
                        int row = targetIndex;
                        // source is column major
                        int src_index = row;
                        // "fuse"
                        // RNN only has one bias vector. It is applied after element-wise addition
                        // of projected input and hidden states. Therefore we need to fuse two biases
                        // in ONNX into one.
                        // RNNBiasMultiplier = 2

                        vector<int> srcIndexRange = {
                            dir * RNNBiasMultiplier * totalSizePerDirection + src_index,
                            dir * RNNBiasMultiplier * totalSizePerDirection + totalSizePerDirection + src_index};

                        CopyFromProto(valueProto, data, srcIndexRange, targetIndex);
                    }

                    Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
                    inputs.push_back(constant);
                }
                return inputs;
            }
        case RNNInputIndexSequenceLens:
            // sequence length is treated as free dimension
            return inputs;
        case RNNInitialH:
        {
            // initial_h: [num_directions, batch_size, hidden_size]
            int num_directions = valueProto.dims(0);
            int cell_size = valueProto.dims(2);
            NDShape weightShape({(size_t)(cell_size)});
            for (int dir = 0; dir < num_directions; dir++)
            {
                std::string nodeName = name + std::string(1, (char) ('0' + dir)) + LSTMInputInitialHNameHint;

                DType *data = new DType[cell_size];
                for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
                {
                    CopyFromProto(valueProto, data, dir * cell_size + targetIndex, targetIndex);
                }

                Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
                inputs.push_back(constant);
            }
            return inputs;
        }
        default:
            // BUGFIX: message previously said "GRU" in the RNN branch (copy-paste error)
            CNTK::LogicError("CreateRNNConstant for RNN op received unexpected index: %d", index);
        }
    }
    else
    {
        NOT_IMPLEMENTED;
    }
}