void GruBackwardSingleLayer()

in src/operator/rnn_impl.h [1044:1384]
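
Backward pass of a single GRU layer on CPU, covering both directions of a bidirectional layer (D == 2). Dimension names follow the code: T is the sequence length, N the batch size, I the input size, H the hidden size, and D the number of directions. gateR, gateZ, and gateN hold the reset-, update-, and new-gate activations saved by the forward pass, and Mnh holds the cached wh_n * h_{t-1} + bh_n terms used by the reset-gate gradient. req_data, req_params, and req_state are the usual MXNet gradient-request modes (kNullOp, kWriteTo, kAddTo), applied to dx, to dwx/dwh/dbx/dbh, and to dhx respectively.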


void GruBackwardSingleLayer(DType* ws,
                            DType* tmp_buf,
                            const int D,
                            const int T,
                            const int N,
                            const int I,
                            const int H,
                            const Tensor<cpu, 2, DType> &x,
                            const Tensor<cpu, 2, DType> &hx,
                            DType* wx_ptr,
                            DType* wh_ptr,
                            DType* y_ptr,
                            DType* dy_ptr,
                            DType* dhy_ptr,
                            DType* gateR,
                            DType* gateZ,
                            DType* gateN,
                            DType* Mnh,
                            DType* dx,
                            DType* dhx,
                            DType* dwx,
                            DType* dwh,
                            DType* dbx,
                            DType* dbh,
                            int req_data,
                            int req_params,
                            int req_state) {
  DType* dyt;   // dy slice at timestep t, [N, D, H]
  DType* ht1;   // hidden state at t - 1, [N, D, H]
  DType* rt;    // saved reset-gate activations at t, [N, H]
  DType* zt;    // saved update-gate activations at t, [N, H]
  DType* nt;    // saved new-gate (candidate) activations at t, [N, H]
  DType* dat;   // gradient w.r.t. the pre-activation gates, [N, 3 * H]
  DType* dart;  // gate gradient routed to the recurrent gemm, [N, 3 * H]
  DType* dar = ws;  // [T, N, 3 * H]
  DType* da = dar + T * N * 3 * H;  // [T, N, 3 * H]
  DType* dht1 = da + T * N * 3 * H;  // [D, N, H]
  DType* hx_ = dht1 + D * N * H;  // [N, D, H]
  DType* Mnht = Mnh;  // cursor into the saved wh_n * h_{t-1} + bh_n terms
  // Slices of the buffers above for the reverse direction when D == 2.
  DType* back_ht1;
  DType* back_dht1 = dht1 + N * H;  // [N, H]
  DType* back_Mnht = Mnh + T * N * H;
  DType* back_gateR = gateR + T * N * H;
  DType* back_gateZ = gateZ + T * N * H;
  DType* back_gateN = gateN + T * N * H;
  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
  DType* back_dwx = dwx + I * 3 * H + H * 3 * H;
  DType* back_dwh = dwh + I * 3 * H + H * 3 * H;
  DType* back_dbx = dbx + 3 * H * 2;
  DType* back_dbh = dbh + 3 * H * 2;

  DType alpha = 1.0;
  DType beta = 0.0;
  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
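  // When overwriting (not kAddTo), clear dwh and the bias gradients up front;
  // dwx needs no clearing because its final gemm below uses beta = 0.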
  if (req_params != kNullOp && req_params != kAddTo) {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < D * H * 3 * H; ++i) {
      dwh[i] = 0;
    }
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < D * 3 * H; ++i) {
      dbx[i] = 0;
      dbh[i] = 0;
    }
  }
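  // Seed the running hidden-state gradient with dhy (zero when no gradient
  // for the final state is given).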
  #pragma omp parallel for num_threads(omp_threads)
  for (int i = 0; i < N * H; ++i) {
    if (dhy_ptr) {
      dht1[i] = dhy_ptr[i];
    } else {
      dht1[i] = 0;
    }
  }

  #pragma omp parallel for num_threads(omp_threads)
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < H; ++j) {
      hx_[i * D * H + j] = hx[i][j];
    }
  }

  if (D == 2) {
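    // Reverse direction: seed back_dht1 from the second half of dhy and copy
    // hx rows N .. 2N - 1 into the direction-1 slots of hx_.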
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < N * H; ++i) {
      if (dhy_ptr) {
        back_dht1[i] = dhy_ptr[N * H + i];
      } else {
        back_dht1[i] = 0;
      }
    }
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < N; ++i) {
      for (int j = 0; j < H; ++j) {
        hx_[i * D * H + H + j] = hx[N + i][j];
      }
    }
  }
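  // Forward direction: walk time from T - 1 down to 0.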
  for (int t = T - 1; t >= 0; --t) {
    if (t) {
      ht1 = y_ptr + (t - 1) * N * D * H;
    } else {
      ht1 = hx_;
    }
    // accumulate this timestep's dy slice [N, D, H] into the running
    // hidden-state gradient dht1 [N, H]
    dyt = dy_ptr + t * N * D * H;

    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < N; ++i) {
      for (int j = 0; j < H; ++j) {
        dht1[i * H + j] += dyt[i * D * H + j];
      }
    }

    rt = gateR + t * N * H;
    zt = gateZ + t * N * H;
    nt = gateN + t * N * H;
    Mnht = Mnh + t * N * H;
    dat = da + t * N * 3 * H;
    dart = dar + t * N * 3 * H;
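    // Elementwise gate gradients, derived from h_t = (1 - z_t) * n_t + z_t * h_{t-1}:
    //   dn = dh * (1 - z) * (1 - n^2)             through tanh of the new gate
    //   dz = dh * (h_{t-1} - n) * z * (1 - z)     through sigmoid of the update gate
    //   dr = dn * Mnh * r * (1 - r)               with Mnh = wh_n * h_{t-1} + bh_n
    // dar carries dn * r (the recurrent path of the new gate); dht1 keeps the
    // z-weighted carry to the previous timestep.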
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < N; ++i) {
      for (int j = 0; j < H; ++j) {
        int nid = i * 3 * H + 2 * H + j;
        int zid = i * 3 * H + H + j;
        int rid = i * 3 * H + j;
        int id = i * H + j;
        dat[nid] = dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
        dart[zid] = dat[zid] = dht1[id] * (ht1[i * D * H + j] - nt[id]) *
            zt[id] * (1 - zt[id]);
        dart[rid] = dat[rid] = dat[nid] * Mnht[id] * rt[id] *
            (1 - rt[id]);
        dart[nid] = dat[nid] * rt[id];
        dht1[id] = dht1[id] * zt[id];
      }
    }
    if (req_params != kNullOp) {
      alpha = 1.0;
      beta = 1.0;
      // dht1 = dart * wh + dht1    [N, H] = [N, 3 * H] * [3 * H, H]
      Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H));
      Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
      linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false);

      if (req_params == kAddTo) {
        beta = 2.0;
        // kAddTo: dwx = dat.T * x_t + 2 * dwx    [3 * H, I] = [3 * H, N] * [N, I]
        Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
        Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H));
        Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
        linalg_gemm(d_dat, d_xt, d_dwx, alpha, beta, true, false);
      }
      // dwh += dart.T * ht1    [3 * H, H] = [3 * H, N] * [N, H]
      // (ht1 [N, D * H] is transposed into tmp_buf as [D, H, N] so the gemm
      // can consume the direction-0 slice; both operands are flagged transposed)
      Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H));
      Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(3 * H, H));
      Tensor<cpu, 3, DType> d_ht1_tmp = Tensor<cpu, 3, DType>
          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
      d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N));
      linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true);
    }
  }

  if (req_params != kNullOp) {
    // Bias gradients: reduce da (for dbx) and dar (for dbh) over all T * N rows.
    if (req_params != kAddTo) {
      #pragma omp parallel for num_threads(omp_threads)
      for (int i = 0; i < 3 * H; ++i) {
        for (int j = 0; j < N * T; ++j) {
          dbx[i] += da[j * 3 * H + i];
          dbh[i] += dar[j * 3 * H + i];
        }
      }
    } else {
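      // kAddTo: accumulate per-timestep bias sums into tmp_buf first, then
      // fold them into dbx/dbh timestep by timestep.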
      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
      #pragma omp parallel for num_threads(omp_threads)
      for (int i = 0; i < H * T * 3; ++i) {
        tmp_dbx.dptr_[i] = 0;
        tmp_dbh.dptr_[i] = 0;
      }

      for (int t = T - 1; t >= 0; --t) {
        #pragma omp parallel for num_threads(omp_threads)
        for (int i = 0; i < 3 * H; ++i) {
          for (int j = 0; j < N; ++j) {
            tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
            tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
          }
        }
        #pragma omp parallel for num_threads(omp_threads)
        for (int i = 0; i < 3 * H; ++i) {
          // Note: with dbx/dbh appearing on both sides, each step computes
          // dbx = 2 * dbx + tmp_dbx[i][t] (and likewise for dbh).
          dbx[i] += tmp_dbx[i][t] + dbx[i];
          dbh[i] += tmp_dbh[i][t] + dbh[i];
        }
      }
    }
  }
  alpha = 1.0;
  beta = 0.0;

  // dx = da * wx    [T * N, I] = [T * N, 3 * H] * [3 * H, I]
  Tensor<cpu, 2, DType> d_da(da, Shape2(T * N, 3 * H));
  if (req_data != kNullOp) {
    Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
    linalg_gemm(d_da, wx, d_dx, alpha, beta, false, false);
  }

  // dwx = da.T * x    [3 * H, I] = [3 * H, T * N] * [T * N, I]
  if (req_params != kNullOp && req_params != kAddTo) {
    Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
    linalg_gemm(d_da, x, d_dwx, alpha, beta, true, false);
  }

  // Reverse direction of a bidirectional layer: time is walked forward here,
  // since in reverse time the "previous" hidden state lives at t + 1.
  if (D == 2) {
    for (int t = 0; t < T; ++t) {
      if (t == T - 1) {
        back_ht1 = hx_;
      } else {
        back_ht1 = y_ptr + (t + 1) * N * D * H;
      }

      // accumulate this timestep's dy slice (direction 1) into back_dht1 [N, H]
      dyt = dy_ptr + t * N * D * H;
      #pragma omp parallel for num_threads(omp_threads)
      for (int i = 0; i < N; ++i) {
        for (int j = 0; j < H; ++j) {
          back_dht1[i * H + j] += dyt[i * D * H + H + j];
        }
      }

      rt = back_gateR + t * N * H;
      zt = back_gateZ + t * N * H;
      nt = back_gateN + t * N * H;
      back_Mnht = Mnh + (T + t) * N * H;
      dat = da + t * N * 3 * H;
      dart = dar + t * N * 3 * H;

      // Elementwise gate gradients for the reverse direction; same math as the
      // forward-direction loop above, on the direction-1 slices.
      #pragma omp parallel for num_threads(omp_threads)
      for (int i = 0; i < N; ++i) {
        for (int j = 0; j < H; ++j) {
          int nid = i * 3 * H + 2 * H + j;
          int zid = i * 3 * H + H + j;
          int rid = i * 3 * H + j;
          int id = i * H + j;
          dat[nid] = back_dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
          dart[zid] = dat[zid] = back_dht1[id] * (back_ht1[i * D * H + H + j] -
              nt[id]) * zt[id] * (1 - zt[id]);
          dart[rid] = dat[rid] = dat[nid] * back_Mnht[id] * rt[id] *
              (1 - rt[id]);
          dart[nid] = dat[nid] * rt[id];
          back_dht1[id] = back_dht1[id] * zt[id];
        }
      }

      if (req_params != kNullOp) {
        alpha = 1.0;
        beta = 1.0;
        // back_dht1 = dart * back_wh + back_dht1    [N, H] = [N, 3 * H] * [3 * H, H]
        Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
        Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H));
        linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false);

        // back_dwh += dart.T * back_ht1    [3 * H, H] = [3 * H, N] * [N, H]
        Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(3 * H, H));
        Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H));
        Tensor<cpu, 3, DType> d_back_ht1_tmp = Tensor<cpu, 3, DType>
            (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
        d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N));
        if (req_params == kAddTo) {
          beta = 2.0;
          // kAddTo: back_dwx = dat.T * x_t + 2 * back_dwx    [3 * H, I] = [3 * H, N] * [N, I]
          Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
          Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H));
          Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I));
          linalg_gemm(d_dat, d_xt, d_back_dwx, alpha, beta, true, false);
        }
        linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true);
      }
    }

    if (req_params != kNullOp) {
      // Bias gradients for the reverse direction: reduce da and dar over all T * N rows.
      if (req_params != kAddTo) {
        #pragma omp parallel for num_threads(omp_threads)
        for (int i = 0; i < 3 * H; ++i) {
          for (int j = 0; j < N * T; ++j) {
            back_dbx[i] += da[j * 3 * H + i];
            back_dbh[i] += dar[j * 3 * H + i];
          }
        }
      } else {
        // kAddTo: same per-timestep tmp_buf accumulation (including the
        // doubling noted above) for the reverse-direction biases.
        const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
        const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
        #pragma omp parallel for num_threads(omp_threads)
        for (int i = 0; i < H * T * 3; ++i) {
          tmp_dbx.dptr_[i] = 0;
          tmp_dbh.dptr_[i] = 0;
        }
        for (int t = T - 1; t >= 0; --t) {
          #pragma omp parallel for num_threads(omp_threads)
          for (int i = 0; i < 3 * H; ++i) {
            for (int j = 0; j < N; ++j) {
              tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
              tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
            }
          }
          #pragma omp parallel for num_threads(omp_threads)
          for (int i = 0; i < 3 * H; ++i) {
            back_dbx[i] += tmp_dbx[i][t] + back_dbx[i];
            back_dbh[i] += tmp_dbh[i][t] + back_dbh[i];
          }
        }
      }
    }
    alpha = 1.0;
    beta = 1.0;
    // dx += da * back_wx    [T * N, I] = [T * N, 3 * H] * [3 * H, I]
    // (beta = 1 accumulates onto the forward-direction contribution in dx)
    Tensor<cpu, 2, DType> d_da2(da, Shape2(T * N, 3 * H));
    if (req_data != kNullOp) {
      Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
      linalg_gemm(d_da2, back_wx, d_dx, alpha, beta, false, false);
    }
    alpha = 1.0;
    beta = 0.0;
    // back_dwx = da.T * x    [3 * H, I] = [3 * H, T * N] * [T * N, I]
    if (req_params != kNullOp && req_params != kAddTo) {
      Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I));
      linalg_gemm(d_da2, x, d_back_dwx, alpha, beta, true, false);
    }
  }
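  // Write the accumulated hidden-state gradient out to dhx [D, N, H].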
  if (req_state != kNullOp) {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < N * H * D; ++i) {
      dhx[i] = dht1[i];
    }
  }
}
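
The elementwise step inside both OpenMP gate loops can be read as a scalar helper. The sketch below is illustrative only and not part of rnn_impl.h; the name GruCellBackwardElem and its signature are hypothetical.

// Illustrative scalar form of the per-element gate-gradient update above.
template <typename DType>
inline void GruCellBackwardElem(DType dh,   // incoming dL/dh_t for one unit
                                DType ht1,  // h_{t-1} for the same unit
                                DType r, DType z, DType n,  // saved activations
                                DType mnh,  // saved wh_n * h_{t-1} + bh_n
                                DType* da_n, DType* da_z, DType* da_r,
                                DType* dar_n, DType* dh_prev) {
  *da_n = dh * (1 - z) * (1 - n * n);    // through tanh of the new gate
  *da_z = dh * (ht1 - n) * z * (1 - z);  // through sigmoid of the update gate
  *da_r = *da_n * mnh * r * (1 - r);     // through sigmoid of the reset gate
  *dar_n = *da_n * r;                    // new-gate gradient on the recurrent path
  *dh_prev = dh * z;                     // carry to the previous timestep
}

The reset- and update-gate gradients go to both dat and dart; only the new gate differs between the two buffers (dat keeps da_n for the input-weight gemm, dart keeps da_n * r for the recurrent gemm).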