void LstmBackwardSingleLayer()

in src/operator/rnn_impl.h [328:496]


template <typename DType>
void LstmBackwardSingleLayer(DType* ws,
                             DType* rs,
                             DType* tmp_buf,
                             bool bid,
                             const int T,
                             const int N,
                             const int I,
                             const int H,
                             const Tensor<cpu, 2, DType> &x,
                             const Tensor<cpu, 2, DType> &hx,
                             const Tensor<cpu, 2, DType> &cx,
                             const Tensor<cpu, 3, DType> &y,
                             const Tensor<cpu, 3, DType> &dy,
                             const Tensor<cpu, 2, DType> &dx,
                             const Tensor<cpu, 2, DType> &dhx,
                             const Tensor<cpu, 2, DType> &dcx,
                             DType* dhy_ptr,
                             DType* dcy_ptr,
                             DType* w_ptr,
                             DType* dw_ptr,
                             DType* db_ptr,
                             int req_data,
                             int req_params,
                             int req_state,
                             int req_statecell) {
  using namespace mshadow;
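  // Views over the packed parameter block: wx is the 4H x I input-to-hidden
  // weight, wh the 4H x H hidden-to-hidden weight; dwx, dwh, dbx and dbh are
  // the matching gradient buffers (the fused LSTM keeps two bias vectors).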
  const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
  Tensor<cpu, 2, DType> dwx(dw_ptr, Shape2(H * 4, I));
  Tensor<cpu, 2, DType> dwh(dw_ptr + I * H * 4, Shape2(H * 4, H));
  Tensor<cpu, 1, DType> dbx(db_ptr, Shape1(H * 4));
  Tensor<cpu, 1, DType> dbh(dbx.dptr_ + H * 4, Shape1(H * 4));
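  // Forward-pass results cached in the reserve space: c holds the cell states
  // and ifgo the gate activations for every time step. For the reverse
  // direction of a bidirectional layer (bid), this cache starts T * N * H * 7
  // elements into rs.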
  DType *c_ptr = bid ? rs + T * N * H * 7 : rs;
  const Tensor<cpu, 3, DType> c(c_ptr, Shape3(T, N, H));
  const Tensor<cpu, 4, DType> ifgo(c_ptr + T * N * H, Shape4(T, N, H, 4));
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
  if (req_params != kNullOp && req_params != kAddTo) {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < H * 4 * H; ++i) {
      dwh.dptr_[i] = 0;
    }
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < 4 * H; ++i) {
      dbx.dptr_[i] = 0;
      dbh.dptr_[i] = 0;
    }
  }
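  // Workspace views: difgo collects the gradients w.r.t. the four gate
  // pre-activations for every time step, dh and dc carry the running
  // hidden- and cell-state gradients across time steps, and htmp buffers
  // the previous step's output.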
  Tensor<cpu, 4, DType> difgo(ws, Shape4(T, N, 4, H));
  Tensor<cpu, 2, DType> dh(ws + T * N * H * 4, Shape2(N, H));
  Tensor<cpu, 2, DType> dc(dh.dptr_ + N * H, Shape2(N, H));
  Tensor<cpu, 2, DType> htmp(dc.dptr_ + N * H, Shape2(N, H));
  const int offset = bid ? H : 0;
  const DType alpha = 1.0;
  const DType beta0 = 0.0;
  const DType beta1 = 1.0;
  const DType beta2 = 2.0;
  const int cell_size = N * H;
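  // Seed dh/dc with the incoming gradients w.r.t. the final hidden and cell
  // states, or with zeros when the caller did not supply them.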
  if (dhy_ptr != NULL) {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < cell_size; ++i) {
      dh.dptr_[i] = dhy_ptr[i];
    }
  } else {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < cell_size; ++i) {
      dh.dptr_[i] = 0;
    }
  }
  if (dcy_ptr != NULL) {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < cell_size; ++i) {
      dc.dptr_[i] = dcy_ptr[i];
    }
  } else {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < cell_size; ++i) {
      dc.dptr_[i] = 0;
    }
  }
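  // Walk the time steps in reverse. For the reverse direction (bid) the loop
  // index i maps to time step t = T - 1 - i, and tnext is the step whose
  // output fed this one; at i == 0 the gradients go straight into dhx/dcx and
  // the initial states hx/cx are used.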

  for (int i = T - 1; i >= 0; --i) {
    int t = bid ? T - 1 - i : i;
    int tnext = bid ? t + 1 : t - 1;
    const Tensor<cpu, 2, DType>& dhnext = i ? dh : dhx;
    const Tensor<cpu, 2, DType>& dcnext = i ? dc : dcx;
    const Tensor<cpu, 2, DType>& hnext = i ? htmp : hx;
    const Tensor<cpu, 2, DType>& cnext = i ? c[i - 1] : cx;
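    // Element-wise LSTM cell backward: add this step's output gradient to dh,
    // accumulate the cell-state gradient dc, then form the gradients w.r.t.
    // the gate pre-activations (order i, f, g, o in difgo) and pass dc on to
    // the previous time step.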
    #pragma omp parallel for num_threads(omp_threads)
    for (int jk = 0; jk < cell_size; ++jk) {
      int j = jk / H;
      int k = jk % H;
      DType tc = tanh(c[i][j][k]);
      DType it = ifgo[i][j][k][0];
      DType ft = ifgo[i][j][k][1];
      DType gt = ifgo[i][j][k][2];
      DType ot = ifgo[i][j][k][3];
      dh[j][k] += dy[t][j][k + offset];
      dc[j][k] += dh[j][k] * ot * (1 - tc * tc);
      difgo[t][j][0][k] = dc[j][k] * gt * it * (1 - it);
      difgo[t][j][1][k] = dc[j][k] * cnext[j][k] * ft * (1 - ft);
      difgo[t][j][2][k] = dc[j][k] * it * (1 - gt * gt);
      difgo[t][j][3][k] = dh[j][k] * tc * ot * (1 - ot);
      if (req_statecell != kNullOp || i > 0) {
        dcnext[j][k] = dc[j][k] * ft;
      }
      if (i) {
        htmp[j][k] = y[tnext][j][k + offset];
      }
    }
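    // View this step's gate gradients as an N x (4H) matrix. dhnext = dyh * wh
    // carries the hidden-state gradient back one step (into dhx at the first
    // step); dwh is updated with dyh^T * hnext, and under kAddTo dwx is also
    // updated here, one time step at a time.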
    Tensor<cpu, 2, DType> dyh(difgo[t].dptr_, Shape2(N, H * 4));
    if (req_state != kNullOp || i > 0) {
      linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false);
    }
    if (req_params != kNullOp) {
      if (req_params != kAddTo) {
        linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false);
      } else {
        linalg_gemm(dyh, hnext, dwh, alpha, beta2, true, false);

        // Generate dwx at every time step for kAddTo.
        Tensor<cpu, 2, DType> x_t(x.dptr_ + i * N * I, Shape2(N, I));
        Tensor<cpu, 2, DType> dyx_t(difgo.dptr_ + i * N * H * 4, Shape2(N, H * 4));
        linalg_gemm(dyx_t, x_t, dwx, alpha, beta2, true, false);
      }
    }
  }
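  // With all time steps processed, one GEMM over the full T * N rows produces
  // the input gradient dx = dyx * wx (accumulated when bid, since the other
  // direction has already written dx) and, in the overwrite case, the
  // input-to-hidden weight gradient dwx = dyx^T * x.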
  Tensor<cpu, 2, DType> dyx(difgo.dptr_, Shape2(T * N, H * 4));
  if (req_data != kNullOp) {
    linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false);
  }
  if (req_params != kNullOp && req_params != kAddTo) {
    linalg_gemm(dyx, x, dwx, alpha, beta0, true, false);
  }
  const int row = T * N;
  const int col = H * 4;
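  // Bias gradients: sum the gate gradients over all T * N rows; dbh mirrors
  // dbx because both bias vectors feed the same pre-activation and therefore
  // receive identical gradients. The kAddTo branch first reduces per time
  // step into tmp_buf before updating dbx/dbh.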
  if (req_params != kNullOp) {
    if (req_params != kAddTo) {
      for (int i = 0; i < row; ++i) {
        #pragma omp parallel for num_threads(omp_threads)
        for (int j = 0; j < col; ++j) {
          dbx[j] += dyx[i][j];
          dbh[j] = dbx[j];
        }
      }
    } else {
      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf, Shape2(col, T));
      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + col * T, Shape2(col, T));
      #pragma omp parallel for num_threads(omp_threads)
      for (int i = 0; i < col * T; ++i) {
        tmp_dbx.dptr_[i] = 0;
        tmp_dbh.dptr_[i] = 0;
      }
      for (int t = T - 1; t >= 0; --t) {
        #pragma omp parallel for num_threads(omp_threads)
        for (int j = 0; j < col; ++j) {
          for (int i = 0; i < N; ++i) {
            tmp_dbx[j][t] += dyx[t * N + i][j];
            tmp_dbh[j][t] = tmp_dbx[j][t];
          }
        }
        #pragma omp parallel for num_threads(omp_threads)
        for (int j = 0; j < col; ++j) {
          dbx[j] += tmp_dbx[j][t] + dbx[j];
          dbh[j] += tmp_dbh[j][t] + dbh[j];
        }
      }
    }
  }
}
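
For reference, the element-wise loop above follows the standard LSTM backward
equations. A sketch, assuming i_t, f_t, g_t, o_t are the cached gate
activations stored in ifgo, c_t the cached cell state, dh_t the hidden-state
gradient (the output gradient dy_t plus the gradient carried back from step
t + 1), and the hatted symbols the pre-activation gradients written to difgo:

\begin{aligned}
dc_t &\mathrel{+}= dh_t \odot o_t \odot (1 - \tanh^2 c_t) \\
d\hat{i}_t &= dc_t \odot g_t \odot i_t \odot (1 - i_t) \\
d\hat{f}_t &= dc_t \odot c_{t-1} \odot f_t \odot (1 - f_t) \\
d\hat{g}_t &= dc_t \odot i_t \odot (1 - g_t^2) \\
d\hat{o}_t &= dh_t \odot \tanh(c_t) \odot o_t \odot (1 - o_t) \\
dc_{t-1} &= dc_t \odot f_t, \qquad
dh_{t-1} = \big[d\hat{i}_t \; d\hat{f}_t \; d\hat{g}_t \; d\hat{o}_t\big]\, W_h
\end{aligned}

The GEMMs then give the parameter gradients
dW_h = \sum_t d\hat{a}_t^\top h_{t-1}, dW_x = \sum_t d\hat{a}_t^\top x_t, and
db_x = db_h = \sum_{t,n} d\hat{a}_t, where d\hat{a}_t stacks the four gate
gradients at time step t.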