inline GRUOutputs GRU()

in Libraries/DirectMLX.h [3086:3194]


    // Adds a DML_OPERATOR_GRU node to the graph that owns `input` and returns handles to the
    // requested outputs.
    //
    //   input             - data the recurrence runs over. Its sizes[1] is used as the sequence
    //                       length and sizes[2] as the batch size of the output tensors.
    //   weight            - input weight tensor, forwarded to DML as-is.
    //   recurrence        - recurrence weight tensor. Its sizes[3] is used as the hidden size of
    //                       the output tensors.
    //   bias              - optional bias tensor.
    //   hiddenInit        - optional initial hidden state tensor.
    //   sequenceLengths   - optional sequence-lengths tensor.
    //   activationDescs   - gate activations; at most 4 entries are accepted. Every entry must
    //                       describe an actual activation operator (see Throws).
    //   direction         - recurrence direction; BIDIRECTIONAL makes the direction dimension
    //                       (sizes[1]) of each output 2 instead of 1.
    //   linearBeforeReset - forwarded to DML_GRU_OPERATOR_DESC::LinearBeforeReset.
    //   outputOptions     - selects which outputs the node produces: the full sequence
    //                       (operator output 0), the single final step (operator output 1),
    //                       or both.
    //
    // Returns a GRUOutputs built from the two node outputs; an output that was not requested is
    // constructed from a null NodeOutput pointer.
    //
    // Throws (via DMLX_THROW) E_INVALIDARG when more than 4 activations are supplied, or when an
    // activation entry yields no operator descriptor from detail::GetFusedActivationPtr
    // (presumably a FusedActivation::None() entry — confirm against that helper).
    inline GRUOutputs GRU(
        Expression input,
        Expression weight,
        Expression recurrence,
        Optional<Expression> bias,
        Optional<Expression> hiddenInit,
        Optional<Expression> sequenceLengths,
        Span<const FusedActivation> activationDescs,
        DML_RECURRENT_NETWORK_DIRECTION direction,
        bool linearBeforeReset,
        GRUOutputOptions outputOptions)
    {
        detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
        TensorDesc inputTensor = input.Impl()->GetOutputDesc();
        TensorDesc weightTensor = weight.Impl()->GetOutputDesc();
        TensorDesc recurrenceTensor = recurrence.Impl()->GetOutputDesc();

        // Optional inputs: descriptors stay default-constructed when the input is absent.
        TensorDesc biasTensor;
        TensorDesc hiddenInitTensor;
        TensorDesc sequenceLengthsTensor;
        if (bias)
        {
            biasTensor = bias->Impl()->GetOutputDesc();
        }
        if (hiddenInit)
        {
            hiddenInitTensor = hiddenInit->Impl()->GetOutputDesc();
        }
        if (sequenceLengths)
        {
            sequenceLengthsTensor = sequenceLengths->Impl()->GetOutputDesc();
        }

        // Decide once which outputs are requested so the desc and node outputs stay in sync.
        const bool wantSequence = (outputOptions == GRUOutputOptions::Sequence || outputOptions == GRUOutputOptions::Both);
        const bool wantSingle = (outputOptions == GRUOutputOptions::Single || outputOptions == GRUOutputOptions::Both);
        const uint32_t directionCount = (direction == DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL) ? 2 : 1;

        // Output descriptors must outlive CreateOperatorNode because the DML desc points into them.
        TensorDesc outputSequenceTensor;
        TensorDesc outputSingleTensor;
        if (wantSequence)
        {
            TensorDesc::Dimensions outputSequenceSizes(4);
            outputSequenceSizes[0] = inputTensor.sizes[1]; // SequenceLength
            outputSequenceSizes[1] = directionCount;
            outputSequenceSizes[2] = inputTensor.sizes[2]; // BatchSize
            outputSequenceSizes[3] = recurrenceTensor.sizes[3]; // HiddenSize
            outputSequenceTensor = TensorDesc(inputTensor.dataType, outputSequenceSizes, builder->GetTensorPolicy());
        }
        if (wantSingle)
        {
            TensorDesc::Dimensions outputSingleSizes(4);
            outputSingleSizes[0] = 1;
            outputSingleSizes[1] = directionCount;
            outputSingleSizes[2] = inputTensor.sizes[2]; // BatchSize
            outputSingleSizes[3] = recurrenceTensor.sizes[3]; // HiddenSize
            outputSingleTensor = TensorDesc(inputTensor.dataType, outputSingleSizes, builder->GetTensorPolicy());
        }

        uint32_t activationCount = static_cast<uint32_t>(activationDescs.size());
        if (activationCount > 4)
        {
            DMLX_THROW(E_INVALIDARG);
        }

        // `storage` backs the per-activation parameter structs that the copied
        // DML_OPERATOR_DESC entries may point into, so it must live as long as `desc`.
        detail::FusedActivationStorage storage[4];
        DML_OPERATOR_DESC activationDescArray[4];
        for (uint32_t i = 0; i < activationCount; ++i)
        {
            const DML_OPERATOR_DESC* activationDesc = detail::GetFusedActivationPtr(activationDescs[i], &storage[i]);
            if (activationDesc == nullptr)
            {
                // GRU has no "no activation" slot; reject the entry instead of dereferencing null.
                DMLX_THROW(E_INVALIDARG);
            }
            activationDescArray[i] = *activationDesc;
        }

        DML_GRU_OPERATOR_DESC desc = {};
        desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
        desc.WeightTensor = weightTensor.AsPtr<DML_TENSOR_DESC>();
        desc.RecurrenceTensor = recurrenceTensor.AsPtr<DML_TENSOR_DESC>();
        desc.BiasTensor = bias ? biasTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
        desc.HiddenInitTensor = hiddenInit ? hiddenInitTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
        desc.SequenceLengthsTensor = sequenceLengths ? sequenceLengthsTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
        desc.OutputSequenceTensor = wantSequence ? outputSequenceTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
        desc.OutputSingleTensor = wantSingle ? outputSingleTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
        desc.ActivationDescCount = activationCount;
        desc.ActivationDescs = activationDescArray;
        desc.Direction = direction;
        desc.LinearBeforeReset = linearBeforeReset;

        // Node inputs are listed in the same order as the optional tensors appear in the desc.
        SmallVector<detail::NodeOutput*, 6> inputs = { input.Impl(), weight.Impl(), recurrence.Impl() };
        if (bias)
        {
            inputs.push_back(bias->Impl());
        }
        if (hiddenInit)
        {
            inputs.push_back(hiddenInit->Impl());
        }
        if (sequenceLengths)
        {
            inputs.push_back(sequenceLengths->Impl());
        }

        detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_GRU, &desc, inputs);

        // Operator output 0 is the full sequence, output 1 the single final step.
        detail::NodeOutput* outputSequenceExpr = nullptr;
        detail::NodeOutput* outputSingleExpr = nullptr;
        if (wantSequence)
        {
            outputSequenceExpr = builder->CreateNodeOutput(node, 0, std::move(outputSequenceTensor));
        }
        if (wantSingle)
        {
            outputSingleExpr = builder->CreateNodeOutput(node, 1, std::move(outputSingleTensor));
        }
        return { outputSequenceExpr, outputSingleExpr };
    }