in src/operator/rnn_impl.h [1044:1384]
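// Backward pass of one GRU layer (both directions when D == 2) on CPU.
// T: sequence length, N: batch size, I: input size, H: hidden size,
// D: number of directions. Gates are packed [r, z, n] along the 3 * H axis.
// The forward recurrences being differentiated (MXNet/cuDNN convention):
//   r_t = sigmoid(a_r),  z_t = sigmoid(a_z),  n_t = tanh(a_xn + r_t * Mnh_t)
//   h_t = (1 - z_t) * n_t + z_t * h_{t-1}
// where Mnh_t is the hidden-side n-gate pre-activation saved by the forward
// pass, along with the gate activations gateR/gateZ/gateN.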
void GruBackwardSingleLayer(DType* ws,
DType* tmp_buf,
const int D,
const int T,
const int N,
const int I,
const int H,
const Tensor<cpu, 2, DType> &x,
const Tensor<cpu, 2, DType> &hx,
DType* wx_ptr,
DType* wh_ptr,
DType* y_ptr,
DType* dy_ptr,
DType* dhy_ptr,
DType* gateR,
DType* gateZ,
DType* gateN,
DType* Mnh,
DType* dx,
DType* dhx,
DType* dwx,
DType* dwh,
DType* dbx,
DType* dbh,
int req_data,
int req_params,
int req_state) {
DType* dyt;
DType* ht1; // [N, D, H]
DType* rt;
DType* zt;
DType* nt;
DType* dat;
DType* dart;
DType* dar = ws; // [T, N, 3 * H]
DType* da = dar + T * N * 3 * H; // [T, N, 3 * H]
DType* dht1 = da + T * N * 3 * H; // [D, N, H]
DType* hx_ = dht1 + D * N * H; // [N, D, H]
DType* Mnht = Mnh;
DType* back_ht1;
DType* back_dht1 = dht1 + N * H; // [N, H]
DType* back_Mnht = Mnh + T * N * H;
DType* back_gateR = gateR + T * N * H;
DType* back_gateZ = gateZ + T * N * H;
DType* back_gateN = gateN + T * N * H;
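  // Reverse-direction views (used when D == 2): each direction owns an
  // I x 3H input-weight block, an H x 3H hidden-weight block, and two 3 * H
  // bias vectors, all stored after the forward direction's blocks.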
DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
DType* back_dwx = dwx + I * 3 * H + H * 3 * H;
DType* back_dwh = dwh + I * 3 * H + H * 3 * H;
DType* back_dbx = dbx + 3 * H * 2;
DType* back_dbh = dbh + 3 * H * 2;
DType alpha = 1.0;
DType beta = 0.0;
const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
if (req_params != kNullOp && req_params != kAddTo) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < D * H * 3 * H; ++i) {
dwh[i] = 0;
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < D * 3 * H; ++i) {
dbx[i] = 0;
dbh[i] = 0;
}
}
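  // Seed the recurrent gradient dL/dh at t = T - 1 with the incoming state
  // gradient dhy, or zero when none is given.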
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N * H; ++i) {
if (dhy_ptr) {
dht1[i] = dhy_ptr[i];
} else {
dht1[i] = 0;
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
hx_[i * D * H + j] = hx[i][j];
}
}
if (D == 2) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N * H; ++i) {
if (dhy_ptr) {
back_dht1[i] = dhy_ptr[N * H + i];
} else {
back_dht1[i] = 0;
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
hx_[i * D * H + H + j] = hx[N + i][j];
}
}
}
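  // Forward direction: walk time backwards, turning dL/dh_t into gate
  // pre-activation gradients (da, dar) and into dL/dh_{t-1}.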
for (int t = T - 1; t >= 0; --t) {
if (t) {
ht1 = y_ptr + (t - 1) * N * D * H;
} else {
ht1 = hx_;
}
    // accumulate dy_t into dht1 (the [:, 0:H] slice of dy [T, N, D, H])
dyt = dy_ptr + t * N * D * H;
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
dht1[i * H + j] += dyt[i * D * H + j];
}
}
rt = gateR + t * N * H;
zt = gateZ + t * N * H;
nt = gateN + t * N * H;
Mnht = Mnh + t * N * H;
dat = da + t * N * 3 * H;
dart = dar + t * N * 3 * H;
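    // Elementwise gate gradients. With h_t = (1 - z_t) * n_t + z_t * h_{t-1}
    // and n_t = tanh(a_xn + r_t * Mnh_t):
    //   da_n = dh_t * (1 - z_t) * (1 - n_t^2)              (tanh')
    //   da_z = dh_t * (h_{t-1} - n_t) * z_t * (1 - z_t)    (sigmoid')
    //   da_r = da_n * Mnh_t * r_t * (1 - r_t)
    // dar carries the hidden-side pre-activation gradients: it matches da in
    // the r and z slots, while dar_n = da_n * r_t because the hidden n term
    // is gated by r_t. The z-path piece of dL/dh_{t-1} (dh_t * z_t) is kept
    // in dht1; the wh path is added by the gemm below.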
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
int nid = i * 3 * H + 2 * H + j;
int zid = i * 3 * H + H + j;
int rid = i * 3 * H + j;
int id = i * H + j;
dat[nid] = dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
dart[zid] = dat[zid] = dht1[id] * (ht1[i * D * H + j] - nt[id]) *
zt[id] * (1 - zt[id]);
dart[rid] = dat[rid] = dat[nid] * Mnht[id] * rt[id] *
(1 - rt[id]);
dart[nid] = dat[nid] * rt[id];
dht1[id] = dht1[id] * zt[id];
}
}
if (req_params != kNullOp) {
alpha = 1.0;
beta = 1.0;
      // dht1 += dart * wh    [N, H] += [N, 3 * H] * [3 * H, H]
Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H));
Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false);
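      // dht1 now holds the complete dL/dh_{t-1}: dh_t * z_t plus dart * wh.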
      if (req_params == kAddTo) {
        // dwx += dat.T * x_t    [3 * H, I] += [3 * H, N] * [N, I]
        // For kAddTo the bulk dwx gemm after the time loop is skipped, so the
        // per-timestep contribution is accumulated here with beta == 1.
        Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
        Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H));
        Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
        linalg_gemm(d_dat, d_xt, d_dwx, alpha, beta, true, false);
      }
      // dwh += dart.T * ht1    [3 * H, H] += [3 * H, N] * [N, H]
Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H));
Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(3 * H, H));
Tensor<cpu, 3, DType> d_ht1_tmp = Tensor<cpu, 3, DType>
(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N));
linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true);
}
}
if (req_params != kNullOp) {
    // dbx = e.T * da    [1, 3 * H] = [1, T * N] * [T * N, 3 * H]
    // (column sums of da; dbh likewise from dar)
if (req_params != kAddTo) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < 3 * H; ++i) {
for (int j = 0; j < N * T; ++j) {
dbx[i] += da[j * 3 * H + i];
dbh[i] += dar[j * 3 * H + i];
}
}
} else {
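      // kAddTo: reduce each timestep into tmp_dbx/tmp_dbh first so the sum
      // over the batch can run parallel per gate index, then fold the
      // per-timestep sums into the existing gradients.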
const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H * T * 3; ++i) {
tmp_dbx.dptr_[i] = 0;
tmp_dbh.dptr_[i] = 0;
}
for (int t = T - 1; t >= 0; --t) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < 3 * H; ++i) {
for (int j = 0; j < N; ++j) {
tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < 3 * H; ++i) {
          dbx[i] += tmp_dbx[i][t];
          dbh[i] += tmp_dbh[i][t];
}
}
}
}
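  // Batched gemms over the whole sequence; da is contiguous as [T * N, 3 * H].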
alpha = 1.0;
beta = 0.0;
// dx = da * wx [T * N, I] = [T * N, 3 * H] * [3 * H, I]
Tensor<cpu, 2, DType> d_da(da, Shape2(T * N, 3 * H));
if (req_data != kNullOp) {
Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
linalg_gemm(d_da, wx, d_dx, alpha, beta, false, false);
}
// dwx = da.T * x [3 * H, I] = [3 * H, T * N] * [T * N, I]
if (req_params != kNullOp && req_params != kAddTo) {
Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
linalg_gemm(d_da, x, d_dwx, alpha, beta, true, false);
}
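  // Reverse direction (D == 2): time runs forward here because this direction
  // consumed the sequence back to front, so its "previous" state at step t is
  // h_{t+1}. Results go to the back_* views of the parameter gradients.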
if (D == 2) {
for (int t = 0; t < T; ++t) {
      if (t == T - 1) {
back_ht1 = hx_;
} else {
back_ht1 = y_ptr + (t + 1) * N * D * H;
}
      // accumulate dy_t into back_dht1 (the [:, H:2H] slice of dy [T, N, D, H])
dyt = dy_ptr + t * N * D * H;
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
back_dht1[i * H + j] += dyt[i * D * H + H + j];
}
}
rt = back_gateR + t * N * H;
zt = back_gateZ + t * N * H;
nt = back_gateN + t * N * H;
back_Mnht = Mnh + (T + t) * N * H;
dat = da + t * N * 3 * H;
dart = dar + t * N * 3 * H;
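      // Same gate-gradient formulas as the forward direction, reading the
      // reverse-direction gates and the [:, H:2H] slice of the saved outputs.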
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
int nid = i * 3 * H + 2 * H + j;
int zid = i * 3 * H + H + j;
int rid = i * 3 * H + j;
int id = i * H + j;
dat[nid] = back_dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
dart[zid] = dat[zid] = back_dht1[id] * (back_ht1[i * D * H + H + j] -
nt[id]) * zt[id] * (1 - zt[id]);
dart[rid] = dat[rid] = dat[nid] * back_Mnht[id] * rt[id] *
(1 - rt[id]);
dart[nid] = dat[nid] * rt[id];
back_dht1[id] = back_dht1[id] * zt[id];
}
}
if (req_params != kNullOp) {
alpha = 1.0;
beta = 1.0;
        // back_dht1 += dart * back_wh    [N, H] += [N, 3 * H] * [3 * H, H]
Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H));
linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false);
        // back_dwh += dart.T * ht1    [3 * H, H] += [3 * H, N] * [N, H]
Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(3 * H, H));
Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H));
Tensor<cpu, 3, DType> d_back_ht1_tmp = Tensor<cpu, 3, DType>
(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N));
        if (req_params == kAddTo) {
          // back_dwx += dat.T * x_t    [3 * H, I] += [3 * H, N] * [N, I]
          // Accumulated per timestep with beta == 1; the bulk gemm after the
          // loop is skipped for kAddTo.
          Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
          Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H));
          Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I));
          linalg_gemm(d_dat, d_xt, d_back_dwx, alpha, beta, true, false);
        }
linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true);
}
}
if (req_params != kNullOp) {
      // back_dbx = e.T * da    [1, 3 * H] = [1, T * N] * [T * N, 3 * H]
      // (column sums; back_dbh likewise from dar)
if (req_params != kAddTo) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < 3 * H; ++i) {
for (int j = 0; j < N * T; ++j) {
back_dbx[i] += da[j * 3 * H + i];
back_dbh[i] += dar[j * 3 * H + i];
}
}
} else {
const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H * T * 3; ++i) {
tmp_dbx.dptr_[i] = 0;
tmp_dbh.dptr_[i] = 0;
}
for (int t = T - 1; t >= 0; --t) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < 3 * H; ++i) {
for (int j = 0; j < N; ++j) {
tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < 3 * H; ++i) {
            back_dbx[i] += tmp_dbx[i][t];
            back_dbh[i] += tmp_dbh[i][t];
}
}
}
}
alpha = 1.0;
beta = 1.0;
    // dx += da * back_wx    [T * N, I] += [T * N, 3 * H] * [3 * H, I]
    // (beta == 1 adds the reverse-direction term to dx)
Tensor<cpu, 2, DType> d_da2(da, Shape2(T * N, 3 * H));
if (req_data != kNullOp) {
Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
linalg_gemm(d_da2, back_wx, d_dx, alpha, beta, false, false);
}
alpha = 1.0;
beta = 0.0;
// dwx = da.T * x [3 * H, I] = [3 * H, T * N] * [T * N, I]
if (req_params != kNullOp && req_params != kAddTo) {
Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I));
linalg_gemm(d_da2, x, d_back_dwx, alpha, beta, true, false);
}
}
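  // Export the final recurrent gradient: dht1 is laid out [D, N, H], the
  // reverse half having been written by the D == 2 pass above.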
if (req_state != kNullOp) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N * H * D; ++i) {
dhx[i] = dht1[i];
}
}
}