void Function::createOnnxLSTM()

in lib/Graph/Graph.cpp [4912:5284]


/// Build an ONNX-style LSTM, fully unrolled over the sequence dimension, out
/// of Slice/Reshape/Transpose/FullyConnected/MatMul/Add/Mul/Sub/Splat/Tile/
/// Concat nodes created in this Function.
///
/// \p namePrefix is used as the prefix of every created node name.
/// \p X is the input sequence with dims [seqLength, batchSize, inputSize].
/// \p W holds the input weights with dims
///    [numDirections, 4 * hiddenSize, inputSize]; the four gate blocks along
///    dimension 1 are read in the order input, output, forget, cell.
/// \p R holds the recurrence weights with dims
///    [numDirections, 4 * hiddenSize, hiddenSize], same gate order as \p W.
/// \p B is optional (may be a null NodeValue). When given it has dims
///    [numDirections, 8 * hiddenSize]: four input-weight biases followed by
///    four recurrence-weight biases, each group in the gate order of \p W.
/// \p initial_h and \p initial_c are optional. When given they have dims
///    [numDirections, batchSize, hiddenSize]; when null a FloatTy Splat of
///    zeros with that shape is created instead.
/// \p P is optional peephole weights with dims [numDirections, 3 * hiddenSize],
///    read in the order input, output, forget.
/// \p Y receives the per-step hidden states, reshaped to
///    [seqLength, numDirections, batchSize, hiddenSize].
/// \p Y_h and \p Y_c receive the last hidden and cell states, reshaped to
///    [numDirections, batchSize, hiddenSize].
/// \p hiddenSize is the number of hidden units.
/// \p direction selects Forward, Reverse or Bidirectional processing;
///    numDirections is 2 for Bidirectional and 1 otherwise.
/// \p activations must contain three callables per direction, used as
///    f (gates), g (cell candidate) and h (hidden output) in that order.
/// \p inputForget when true computes the input gate as (1 - forget gate)
///    instead of from its own weights.
///
/// All shape requirements are checked with assert only.
void Function::createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X,
                              NodeValue W, NodeValue R, NodeValue B,
                              NodeValue initial_h, NodeValue initial_c,
                              NodeValue P, NodeValue &Y, NodeValue &Y_h,
                              NodeValue &Y_c, unsigned hiddenSize,
                              RnnDirection direction,
                              std::vector<RnnActivation> &activations,
                              bool inputForget) {

// Each *_SLICE_RANGE macro expands to the two brace-initialized coordinate
// lists (start, end) that are passed as the trailing createSlice arguments.
// They rely on the locals batchSize, inputSize and hiddenSize defined below.
// For W/R/B/P, idx0 selects the direction and idx1 selects the gate block of
// hiddenSize rows/elements.
#define LSTM_X_SLICE_RANGE(idx)                                                \
  {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
#define LSTM_H_SLICE_RANGE(idx)                                                \
  {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
#define LSTM_C_SLICE_RANGE(idx)                                                \
  {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
#define LSTM_W_SLICE_RANGE(idx0, idx1)                                         \
  {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
#define LSTM_R_SLICE_RANGE(idx0, idx1)                                         \
  {idx0, idx1 * hiddenSize, 0}, {                                              \
    idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize                              \
  }
#define LSTM_B_SLICE_RANGE(idx0, idx1)                                         \
  {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
#define LSTM_P_SLICE_RANGE(idx0, idx1)                                         \
  {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
// Creates a FullyConnected node when a bias is available, otherwise a plain
// MatMul; both are returned as a generic Node pointer.
#define LSTM_CREATE_FC(name, LHS, RHS, BIAS)                                   \
  BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS)                    \
       : (Node *)createMatMul(name, LHS, RHS)

  // Operator name.
  const std::string &opName = namePrefix.str();

  // Get all size parameters.
  dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
  assert(X.dims().size() == 3 &&
         "ONNX LSTM input 'X' should have 3 dimensions!");
  dim_t seqLength = X.dims()[0];
  dim_t batchSize = X.dims()[1];
  dim_t inputSize = X.dims()[2];

  // With no time steps every slice list below would be empty and an empty
  // list would be handed to createConcat when assembling the outputs.
  assert(seqLength > 0 &&
         "ONNX LSTM input 'X' should have a non-zero sequence length!");

  // Validate W size.
  assert(W.dims().size() == 3 &&
         "ONNX LSTM input 'W' should have 3 dimensions!");
  assert(W.dims()[0] == numDirections && W.dims()[1] == 4 * hiddenSize &&
         W.dims()[2] == inputSize && "ONNX LSTM 'W' tensor size invalid!");

  // Validate R size.
  assert(R.dims().size() == 3 &&
         "ONNX LSTM input 'R' should have 3 dimensions!");
  assert(R.dims()[0] == numDirections && R.dims()[1] == 4 * hiddenSize &&
         R.dims()[2] == hiddenSize && "ONNX LSTM 'R' tensor size invalid!");

  // Validate B size.
  if (B.getNode()) {
    assert(B.dims().size() == 2 &&
           "ONNX LSTM input 'B' should have 2 dimensions!");
    assert(B.dims()[0] == numDirections && B.dims()[1] == 8 * hiddenSize &&
           "ONNX LSTM 'B' tensor size invalid!");
  }

  // Validate initial_h size if given else create Splat with 0.
  if (initial_h.getNode()) {
    assert(initial_h.dims().size() == 3 &&
           "ONNX LSTM input 'initial_h' should have 3 dimensions!");
    assert(initial_h.dims()[0] == numDirections &&
           initial_h.dims()[1] == batchSize &&
           initial_h.dims()[2] == hiddenSize &&
           "ONNX LSTM 'initial_h' tensor size invalid!");
  } else {
    auto splatTy = getParent()->uniqueType(
        ElemKind::FloatTy, {numDirections, batchSize, hiddenSize});
    initial_h = createSplat(opName + ".initial_h", splatTy, 0.0);
  }

  // Validate initial_c size if given else create Splat with 0.
  if (initial_c.getNode()) {
    assert(initial_c.dims().size() == 3 &&
           "ONNX LSTM input 'initial_c' should have 3 dimensions!");
    assert(initial_c.dims()[0] == numDirections &&
           initial_c.dims()[1] == batchSize &&
           initial_c.dims()[2] == hiddenSize &&
           "ONNX LSTM 'initial_c' tensor size invalid!");
  } else {
    auto splatTy = getParent()->uniqueType(
        ElemKind::FloatTy, {numDirections, batchSize, hiddenSize});
    initial_c = createSplat(opName + ".initial_c", splatTy, 0.0);
  }

  // Validate P size.
  if (P.getNode()) {
    assert(P.dims().size() == 2 &&
           "ONNX LSTM input 'P' should have 2 dimensions!");
    assert(P.dims()[0] == numDirections && P.dims()[1] == 3 * hiddenSize &&
           "ONNX LSTM 'P' tensor size invalid!");
  }

  // Validate number of activations.
  assert(activations.size() == numDirections * 3 &&
         "ONNX LSTM activations vector invalid!");

  // Create X slices: one [batchSize, inputSize] node per time step. These are
  // shared by the forward and the backward cell.
  std::vector<Node *> Xslices;
  for (dim_t t = 0; t < seqLength; t++) {
    auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
    Node *Xt = createSlice(XsliceName, X, LSTM_X_SLICE_RANGE(t));
    auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
    Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
    Xslices.push_back(Xt);
  }

  // Lambda to load forward/backward LSTM cell.
  // It appends one hidden-state node per processed step to Yslices (in
  // processing order) and stores the final hidden/cell state nodes in
  // Hslice/Cslice.
  auto loadLSTMCell = [&](bool forward, std::vector<NodeValue> &Yslices,
                          NodeValue &Hslice, NodeValue &Cslice) {
    // Name prefix. The direction label is only added when both directions
    // are built.
    std::string dirLabel = forward ? ".fw" : ".bw";
    std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");

    // Slice index used for creating weights slices.
    // Only a bidirectional LSTM has a second direction slot; a Reverse-only
    // LSTM reads its parameters from slot 0.
    dim_t sliceIdx0 = 0;
    if (direction == RnnDirection::Bidirectional) {
      sliceIdx0 = forward ? 0 : 1;
    }

    // Activations.
    size_t activationOffset = sliceIdx0 * 3;
    auto activationF = activations[activationOffset + 0];
    auto activationG = activations[activationOffset + 1];
    auto activationH = activations[activationOffset + 2];

    // Create W slices (Required).
    NodeValue Wi =
        createSlice(prefix + ".Wi.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 0));
    NodeValue Wo =
        createSlice(prefix + ".Wo.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 1));
    NodeValue Wf =
        createSlice(prefix + ".Wf.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 2));
    NodeValue Wc =
        createSlice(prefix + ".Wc.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 3));

    // Drop the leading direction dimension.
    Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize});
    Wo = createReshape(prefix + ".Wo.reshape", Wo, {hiddenSize, inputSize});
    Wf = createReshape(prefix + ".Wf.reshape", Wf, {hiddenSize, inputSize});
    Wc = createReshape(prefix + ".Wc.reshape", Wc, {hiddenSize, inputSize});

    // Swap the two axes so the weights can be used as the right-hand side of
    // Xt * W.
    Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0});
    Wo = createTranspose(prefix + ".Wo.transp", Wo, {1, 0});
    Wf = createTranspose(prefix + ".Wf.transp", Wf, {1, 0});
    Wc = createTranspose(prefix + ".Wc.transp", Wc, {1, 0});

    // Create R slices (Required).
    NodeValue Ri =
        createSlice(prefix + ".Ri.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 0));
    NodeValue Ro =
        createSlice(prefix + ".Ro.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 1));
    NodeValue Rf =
        createSlice(prefix + ".Rf.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 2));
    NodeValue Rc =
        createSlice(prefix + ".Rc.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 3));

    Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize});
    Ro = createReshape(prefix + ".Ro.reshape", Ro, {hiddenSize, hiddenSize});
    Rf = createReshape(prefix + ".Rf.reshape", Rf, {hiddenSize, hiddenSize});
    Rc = createReshape(prefix + ".Rc.reshape", Rc, {hiddenSize, hiddenSize});

    Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0});
    Ro = createTranspose(prefix + ".Ro.transp", Ro, {1, 0});
    Rf = createTranspose(prefix + ".Rf.transp", Rf, {1, 0});
    Rc = createTranspose(prefix + ".Rc.transp", Rc, {1, 0});

    // Create B slices (optional).
    // When B is absent these stay null and LSTM_CREATE_FC falls back to
    // MatMul.
    NodeValue bWi = nullptr;
    NodeValue bWo = nullptr;
    NodeValue bWf = nullptr;
    NodeValue bWc = nullptr;
    NodeValue bRi = nullptr;
    NodeValue bRo = nullptr;
    NodeValue bRf = nullptr;
    NodeValue bRc = nullptr;

    if (B) {

      bWi = createSlice(prefix + ".bWi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 0));
      bWo = createSlice(prefix + ".bWo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 1));
      bWf = createSlice(prefix + ".bWf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 2));
      bWc = createSlice(prefix + ".bWc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 3));
      bRi = createSlice(prefix + ".bRi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 4));
      bRo = createSlice(prefix + ".bRo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 5));
      bRf = createSlice(prefix + ".bRf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 6));
      bRc = createSlice(prefix + ".bRc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 7));

      bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize});
      bWo = createReshape(prefix + ".bWo.reshape", bWo, {hiddenSize});
      bWf = createReshape(prefix + ".bWf.reshape", bWf, {hiddenSize});
      bWc = createReshape(prefix + ".bWc.reshape", bWc, {hiddenSize});
      bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize});
      bRo = createReshape(prefix + ".bRo.reshape", bRo, {hiddenSize});
      bRf = createReshape(prefix + ".bRf.reshape", bRf, {hiddenSize});
      bRc = createReshape(prefix + ".bRc.reshape", bRc, {hiddenSize});
    }

    // Create P slices (optional).
    // When P is absent these stay null and the peephole terms are skipped.
    NodeValue Pi = nullptr;
    NodeValue Po = nullptr;
    NodeValue Pf = nullptr;

    if (P) {

      Pi = createSlice(prefix + ".Pi.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 0));
      Po = createSlice(prefix + ".Po.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 1));
      Pf = createSlice(prefix + ".Pf.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 2));

      // Repeat P slices to match [batchSize, hiddenSize].
      Pi = createTile(prefix + ".Pi.repeat", Pi, batchSize, 0);
      Po = createTile(prefix + ".Po.repeat", Po, batchSize, 0);
      Pf = createTile(prefix + ".Pf.repeat", Pf, batchSize, 0);
    }

    // Create H slice for this direction.
    Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
                              LSTM_H_SLICE_RANGE(sliceIdx0));
    Hinit =
        createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});

    // Create C slice for this direction.
    Node *Cinit = createSlice(prefix + ".C.slice", initial_c,
                              LSTM_C_SLICE_RANGE(sliceIdx0));
    Cinit =
        createReshape(prefix + ".C.reshape", Cinit, {batchSize, hiddenSize});

    // Initialize.
    Node *Ht = Hinit;
    Node *Ct = Cinit;

    // Unroll LSTM cell for all time steps.
    // Note that the node names below do not include the step index, so the
    // same names are requested for every step.
    for (dim_t t = 0; t < seqLength; t++) {

      // Input for current time step.
      // For the reverse LSTM cell the inputs are provided in reverse order.
      Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];

      // Forget gate: ft = f(Xt * Wf + bWf + Ht-1 * Rf + bRf + Pf . Ct-1).
      Node *ft = createAdd(prefix + ".F.add1",
                           LSTM_CREATE_FC(prefix + ".F.fc1", Xt, Wf, bWf),
                           LSTM_CREATE_FC(prefix + ".F.fc2", Ht, Rf, bRf));
      if (Pf) {
        ft = createAdd(prefix + ".F.add2", ft,
                       createMul(prefix + ".F.mult", Pf, Ct));
      }
      ft = activationF(prefix + ".F.act", ft);

      // Cell state candidate: ctild = g(Xt * Wc + bWc + Ht-1 * Rc + bRc).
      Node *ctild =
          createAdd(prefix + ".Ctild.add",
                    LSTM_CREATE_FC(prefix + ".Ctild.fc1", Xt, Wc, bWc),
                    LSTM_CREATE_FC(prefix + ".Ctild.fc2", Ht, Rc, bRc));
      ctild = activationG(prefix + ".Ctild.act", ctild);

      // Input gate:
      // For inputForget == true:
      //   it = 1 - ft.
      // For inputForget == false:
      //   it = f(Xt * Wi + bWi + Ht-1 * Ri + bRi + Pi . Ct-1).
      Node *it;
      if (inputForget) {
        auto splatTy = ft->getNthResult(0).getType();
        it = createSub(prefix + ".I.sub",
                       createSplat(prefix + ".I.splat", splatTy, 1.0), ft);
      } else {
        it = createAdd(prefix + ".I.add1",
                       LSTM_CREATE_FC(prefix + ".I.fc1", Xt, Wi, bWi),
                       LSTM_CREATE_FC(prefix + ".I.fc2", Ht, Ri, bRi));
        if (Pi) {
          it = createAdd(prefix + ".I.add2", it,
                         createMul(prefix + ".I.mult", Pi, Ct));
        }
        it = activationF(prefix + ".I.act", it);
      }

      // Cell state update: Ct = ft . Ct-1 + it . ctild.
      // This must stay after the forget/input gates (which use Ct-1) and
      // before the output gate (which uses the updated Ct).
      Ct = createAdd(prefix + ".C.add", createMul(prefix + ".C.mult1", ft, Ct),
                     createMul(prefix + ".C.mult2", it, ctild));

      // Output gate: ot = f(Xt * Wo + bWo + Ht-1 * Ro + bRo + Po . Ct).
      Node *ot = createAdd(prefix + ".O.add1",
                           LSTM_CREATE_FC(prefix + ".O.fc1", Xt, Wo, bWo),
                           LSTM_CREATE_FC(prefix + ".O.fc2", Ht, Ro, bRo));
      if (Po) {
        ot = createAdd(prefix + ".O.add2", ot,
                       createMul(prefix + ".O.mult", Po, Ct));
      }
      ot = activationF(prefix + ".O.act", ot);

      // Hidden state update: Ht = ot . h(Ct).
      Ht =
          createMul(prefix + ".H.mult", ot, activationH(prefix + ".H.act", Ct));

      // Output.
      Yslices.push_back(Ht);
    }

    // Updated states nodes.
    Hslice = Ht;
    Cslice = Ct;
  }; // End of local lambda "loadLSTMCell".

  bool forwardEnabled = ((direction == RnnDirection::Forward) ||
                         (direction == RnnDirection::Bidirectional));
  bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
                          (direction == RnnDirection::Bidirectional));

  std::vector<NodeValue> YSlices;
  std::vector<NodeValue> Hslices;
  std::vector<NodeValue> Cslices;

  // Load forward LSTM.
  std::vector<NodeValue> forwardYslices;
  if (forwardEnabled) {
    NodeValue forwardHslice;
    NodeValue forwardCslice;
    loadLSTMCell(/* forward */ true, forwardYslices, forwardHslice,
                 forwardCslice);
    Hslices.push_back(forwardHslice);
    Cslices.push_back(forwardCslice);
  }

  // Load backward LSTM.
  std::vector<NodeValue> backwardYslices;
  if (backwardEnabled) {
    NodeValue backwardHslice;
    NodeValue backwardCslice;
    loadLSTMCell(/* forward */ false, backwardYslices, backwardHslice,
                 backwardCslice);
    Hslices.push_back(backwardHslice);
    Cslices.push_back(backwardCslice);
  }

  // Gather Y slices.
  // Slices are interleaved per time step (forward first, then backward). The
  // backward cell produced its outputs in reverse time order, so they are
  // indexed from the end to line up with time step t.
  for (dim_t t = 0; t < seqLength; t++) {
    if (forwardEnabled) {
      YSlices.push_back(forwardYslices[t]);
    }
    if (backwardEnabled) {
      YSlices.push_back(backwardYslices[seqLength - 1 - t]);
    }
  }

  // Concatenate Y slices.
  // Y size is [seqLength, numDirections, batchSize, hiddenSize].
  Y = createReshape(opName + ".Y.reshape",
                    createConcat(opName + ".Y.concat", YSlices, 0),
                    {seqLength, numDirections, batchSize, hiddenSize});

  // Concatenate Y_h slices.
  // Y_h size is [numDirections, batchSize, hiddenSize].
  Y_h = createReshape(opName + ".Y_h.reshape",
                      createConcat(opName + ".Y_h.concat", Hslices, 0),
                      {numDirections, batchSize, hiddenSize});

  // Concatenate Y_c slices.
  // Y_c size is [numDirections, batchSize, hiddenSize].
  Y_c = createReshape(opName + ".Y_c.reshape",
                      createConcat(opName + ".Y_c.concat", Cslices, 0),
                      {numDirections, batchSize, hiddenSize});

#undef LSTM_X_SLICE_RANGE
#undef LSTM_H_SLICE_RANGE
#undef LSTM_C_SLICE_RANGE
#undef LSTM_W_SLICE_RANGE
#undef LSTM_R_SLICE_RANGE
#undef LSTM_B_SLICE_RANGE
#undef LSTM_P_SLICE_RANGE
#undef LSTM_CREATE_FC
}