in Libraries/DirectMLX.h [3086:3194]
/// Builds a DML_OPERATOR_GRU node in the graph that owns `input`.
///
/// @param input             Input sequence. Its sizes[1] / sizes[2] are used as the SequenceLength /
///                          BatchSize of the produced outputs.
/// @param weight            Weight tensor (DML_GRU_OPERATOR_DESC::WeightTensor).
/// @param recurrence        Recurrence tensor. Its sizes[3] is used as the HiddenSize of the outputs.
/// @param bias              Optional bias tensor.
/// @param hiddenInit        Optional initial hidden-state tensor.
/// @param sequenceLengths   Optional sequence-lengths tensor.
/// @param activationDescs   Activations forwarded to the operator; at most 4 are accepted.
/// @param direction         Recurrence direction. BIDIRECTIONAL gives a direction count of 2, anything
///                          else gives 1 (this becomes dimension 1 of each output).
/// @param linearBeforeReset Forwarded to DML_GRU_OPERATOR_DESC::LinearBeforeReset.
/// @param outputOptions     Selects which outputs are produced: Sequence, Single, or Both.
/// @return GRUOutputs holding the sequence output (operator output 0) and the single output
///         (operator output 1). An output that was not requested is left null.
/// @throws Via DMLX_THROW(E_INVALIDARG) when more than 4 activations are supplied, or when an
///         activation does not resolve to an operator desc.
inline GRUOutputs GRU(
    Expression input,
    Expression weight,
    Expression recurrence,
    Optional<Expression> bias,
    Optional<Expression> hiddenInit,
    Optional<Expression> sequenceLengths,
    Span<const FusedActivation> activationDescs,
    DML_RECURRENT_NETWORK_DIRECTION direction,
    bool linearBeforeReset,
    GRUOutputOptions outputOptions)
{
    detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();

    TensorDesc inputTensor = input.Impl()->GetOutputDesc();
    TensorDesc weightTensor = weight.Impl()->GetOutputDesc();
    TensorDesc recurrenceTensor = recurrence.Impl()->GetOutputDesc();
    TensorDesc biasTensor;
    TensorDesc hiddenInitTensor;
    TensorDesc sequenceLengthsTensor;
    TensorDesc outputSequenceTensor;
    TensorDesc outputSingleTensor;
    if (bias)
    {
        biasTensor = bias->Impl()->GetOutputDesc();
    }
    if (hiddenInit)
    {
        hiddenInitTensor = hiddenInit->Impl()->GetOutputDesc();
    }
    if (sequenceLengths)
    {
        sequenceLengthsTensor = sequenceLengths->Impl()->GetOutputDesc();
    }

    const bool wantSequence =
        outputOptions == GRUOutputOptions::Sequence || outputOptions == GRUOutputOptions::Both;
    const bool wantSingle =
        outputOptions == GRUOutputOptions::Single || outputOptions == GRUOutputOptions::Both;

    const uint32_t directionCount = (direction == DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL) ? 2 : 1;
    if (wantSequence)
    {
        TensorDesc::Dimensions outputSequenceSizes(4);
        outputSequenceSizes[0] = inputTensor.sizes[1]; // SequenceLength
        outputSequenceSizes[1] = directionCount;
        outputSequenceSizes[2] = inputTensor.sizes[2]; // BatchSize
        outputSequenceSizes[3] = recurrenceTensor.sizes[3]; // HiddenSize
        outputSequenceTensor = TensorDesc(inputTensor.dataType, outputSequenceSizes, builder->GetTensorPolicy());
    }
    if (wantSingle)
    {
        TensorDesc::Dimensions outputSingleSizes(4);
        outputSingleSizes[0] = 1;
        outputSingleSizes[1] = directionCount;
        outputSingleSizes[2] = inputTensor.sizes[2]; // BatchSize
        outputSingleSizes[3] = recurrenceTensor.sizes[3]; // HiddenSize
        outputSingleTensor = TensorDesc(inputTensor.dataType, outputSingleSizes, builder->GetTensorPolicy());
    }

    const uint32_t activationCount = static_cast<uint32_t>(activationDescs.size());
    if (activationCount > 4)
    {
        DMLX_THROW(E_INVALIDARG);
    }

    // `storage` backs the activation descs copied into `activationDescArray`, so it must stay
    // alive until the operator node has been created below.
    detail::FusedActivationStorage storage[4];
    DML_OPERATOR_DESC activationDescArray[4];
    for (uint32_t i = 0; i < activationCount; ++i)
    {
        const DML_OPERATOR_DESC* activationDesc = detail::GetFusedActivationPtr(activationDescs[i], &storage[i]);
        if (activationDesc == nullptr)
        {
            // Never dereference a null desc. NOTE(review): presumably the helper returns null for a
            // "no activation" FusedActivation, which is not a valid GRU gate activation — confirm.
            DMLX_THROW(E_INVALIDARG);
        }
        activationDescArray[i] = *activationDesc;
    }

    DML_GRU_OPERATOR_DESC desc = {};
    desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
    desc.WeightTensor = weightTensor.AsPtr<DML_TENSOR_DESC>();
    desc.RecurrenceTensor = recurrenceTensor.AsPtr<DML_TENSOR_DESC>();
    desc.BiasTensor = bias ? biasTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
    desc.HiddenInitTensor = hiddenInit ? hiddenInitTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
    desc.SequenceLengthsTensor = sequenceLengths ? sequenceLengthsTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
    desc.OutputSequenceTensor = wantSequence ? outputSequenceTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
    desc.OutputSingleTensor = wantSingle ? outputSingleTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
    desc.ActivationDescCount = activationCount;
    desc.ActivationDescs = activationDescArray;
    desc.Direction = direction;
    desc.LinearBeforeReset = linearBeforeReset;

    // Operator inputs are addressed by tensor slot, in DML_GRU_OPERATOR_DESC field order:
    // Input(0), Weight(1), Recurrence(2), Bias(3), HiddenInit(4), SequenceLengths(5).
    // An absent optional tensor must keep its slot (as nullptr). Skipping it would shift later
    // inputs into the wrong slot, e.g. hiddenInit without bias would be wired to the Bias slot.
    // NOTE(review): relies on CreateOperatorNode ignoring null entries when it builds graph
    // edges, as for other optional-input operators — confirm.
    SmallVector<detail::NodeOutput*, 6> inputs = { input.Impl(), weight.Impl(), recurrence.Impl() };
    inputs.push_back(bias ? bias->Impl() : nullptr);
    inputs.push_back(hiddenInit ? hiddenInit->Impl() : nullptr);
    inputs.push_back(sequenceLengths ? sequenceLengths->Impl() : nullptr);

    detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_GRU, &desc, inputs);

    // Output 0 is OutputSequenceTensor and output 1 is OutputSingleTensor.
    detail::NodeOutput* outputSequenceExpr = nullptr;
    detail::NodeOutput* outputSingleExpr = nullptr;
    if (wantSequence)
    {
        outputSequenceExpr = builder->CreateNodeOutput(node, 0, std::move(outputSequenceTensor));
    }
    if (wantSingle)
    {
        outputSingleExpr = builder->CreateNodeOutput(node, 1, std::move(outputSingleTensor));
    }
    return { outputSequenceExpr, outputSingleExpr };
}