in src/operator/rnn_impl.h [1926:2241]
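// Single-layer backward pass of the vanilla RNN (ReLU/tanh) cell on CPU.
// Computes dx, dhx, dwx, dwh, dbx and dbh from dy/dhy for one layer, handling
// the bidirectional case when D == 2. ws and tmp_buf are caller-provided
// workspaces; gateN holds the activations saved by the forward pass.
// mode == 1 backpropagates through tanh, otherwise through ReLU.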
void VanillaRNNBackwardSingleLayer(DType* ws,
DType* tmp_buf,
const int D,
const int T,
const int N,
const int I,
const int H,
const Tensor<cpu, 2, DType> &x,
const Tensor<cpu, 2, DType> &hx,
DType* wx_ptr,
DType* wh_ptr,
DType* y_ptr,
DType* dy_ptr,
DType* dhy_ptr,
DType* gateN,
DType* dx,
DType* dhx,
DType* dwx,
DType* dwh,
DType* dbx,
DType* dbh,
int req_data,
int req_params,
int req_state,
int mode) {
DType* dyt;
DType* ht1; // [N, D, H]
DType* dart;
DType* nt;
DType* dar = ws; // [T, N, H]
DType* dht1 = dar + T * N * H; // [D, N, H]
DType* hx_ = dht1 + D * N * H; // [N, D, H]
DType* back_ht1;
DType* back_dht1 = dht1 + N * H; // [N, H]
DType* back_gateN = gateN + T * N * H;
DType* back_wx_ptr = wx_ptr + I * H + H * H;
DType* back_wh_ptr = wh_ptr + I * H + H * H;
DType* back_dwx = dwx + I * H + H * H;
DType* back_dwh = dwh + I * H + H * H;
DType* back_dbx = dbx + H * 2;
DType* back_dbh = dbh + H * 2;
DType alpha = 1.0;
DType beta = 0.0;
const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I));
const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H));
const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H, I));
const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H, H));
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
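// dwh, dbx and dbh are accumulated across time steps, so clear them for plain
// write requests; dwx is overwritten by a single GEMM, kAddTo keeps existing
// values, and kNullOp skips parameter gradients entirely.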
if (req_params != kNullOp && req_params != kAddTo) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < D * H * H; ++i) {
dwh[i] = 0;
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < D * H; ++i) {
dbx[i] = 0;
dbh[i] = 0;
}
}
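// Seed the running hidden-state gradient dht1 with dhy, or zero if no state
// gradient was given.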
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N * H; ++i) {
if (dhy_ptr) {
dht1[i] = dhy_ptr[i];
} else {
dht1[i] = 0;
}
}
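// Stash hx in the same [N, D, H] layout as y so it can stand in for y[t - 1]
// at t == 0.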
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
hx_[i * D * H + j] = hx[i][j];
}
}
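// Bidirectional case: seed the reverse direction's state gradient and copy its
// initial state into the second half of hx_.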
if (D == 2) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N * H; ++i) {
if (dhy_ptr) {
back_dht1[i] = dhy_ptr[N * H + i];
} else {
back_dht1[i] = 0;
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
hx_[i * D * H + H + j] = hx[N + i][j];
}
}
}
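// Forward (left-to-right) direction: walk the time steps in reverse,
// propagating dht1 through the cell.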
for (int t = T - 1; t >= 0; --t) {
if (t) {
ht1 = y_ptr + (t - 1) * N * D * H;
} else {
ht1 = hx_;
}
// add dy[t] (the forward half of its [N, D, H] slice) into dht1 [N, H]
dyt = dy_ptr + t * N * D * H;
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
dht1[i * H + j] += dyt[i * D * H + j];
}
}
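// Backpropagate through the activation: d(tanh)/da = 1 - n^2 when mode == 1,
// d(ReLU)/da otherwise. dht1 is cleared so the GEMM below can rebuild it for
// the previous step.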
nt = gateN + t * N * H;
dart = dar + t * N * H;
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
int id = i * H + j;
if (mode == 1) {
dart[id] = dht1[id] * (1 - nt[id] * nt[id]);
} else {
dart[id] = nt[id] > 0.0f ? static_cast<float>(dht1[id]) : 0.0f;
}
dht1[id] = 0;
}
}
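// When parameter gradients are requested, accumulate dwh (and, for kAddTo,
// dwx) for this step and propagate dht1 = dart * wh to the previous step.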
if (req_params != kNullOp) {
alpha = 1.0;
beta = 1.0;
// dht1 = dart * wh [N, H] = [N, H] * [H, H]
Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H));
Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, H));
linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false);
if (req_params == kAddTo) {
beta = 2.0;
// dwx = da.T * x [H, I] = [H, N] * [N, I] for AddTo
Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(H, I));
linalg_gemm(d_dart, d_xt, d_dwx, alpha, beta, true, false);
}
// dwh = dart.T * ht1 [H, H] = [H, N] * [N, H]
Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H));
Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(H, H));
Tensor<cpu, 3, DType> d_ht1_tmp = Tensor<cpu, 3, DType>
(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N));
linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true);
}
}
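// Bias gradients for the forward direction: sum dar over batch and time
// (the kAddTo path stages per-step sums in tmp_buf).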
if (req_params != kNullOp) {
// dbx = e * da [1, H] = [1, T * N] * [T * N, H]
if (req_params != kAddTo) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H; ++i) {
for (int j = 0; j < N * T; ++j) {
dbx[i] += dar[j * H + i];
dbh[i] = dbx[i];
}
}
} else {
const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T));
const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T));
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H * T; ++i) {
tmp_dbx.dptr_[i] = 0;
tmp_dbh.dptr_[i] = 0;
}
for (int t = T - 1; t >= 0; --t) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H; ++i) {
for (int j = 0; j < N; ++j) {
tmp_dbx[i][t] += dar[t * N * H + j * H + i];
tmp_dbh[i][t] = tmp_dbx[i][t];
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H; ++i) {
dbx[i] += tmp_dbx[i][t] + dbx[i];
dbh[i] = dbx[i];
}
}
}
}
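// Input-side gradients for the forward direction, computed over all T steps
// in single GEMMs.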
alpha = 1.0;
beta = 0.0;
// dx = da * wx [T * N, I] = [T * N, H] * [H, I]
Tensor<cpu, 2, DType> d_dar(dar, Shape2(T * N, H));
if (req_data != kNullOp) {
Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
linalg_gemm(d_dar, wx, d_dx, alpha, beta, false, false);
}
// dwx = da.T * x [H, I] = [H, T * N] * [T * N, I]
if (req_params != kNullOp && req_params != kAddTo) {
Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(H, I));
linalg_gemm(d_dar, x, d_dwx, alpha, beta, true, false);
}
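// Reverse (right-to-left) direction: the same procedure with the
// backward-direction weights; its "previous" state lives at t + 1.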
if (D == 2) {
for (int t = 0; t < T; ++t) {
if (t == T - 1) {
back_ht1 = hx_;
} else {
back_ht1 = y_ptr + (t + 1) * N * D * H;
}
// add dy[t] (the backward half of its [N, D, H] slice) into back_dht1 [N, H]
dyt = dy_ptr + t * N * D * H;
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
back_dht1[i * H + j] += dyt[i * D * H + H + j];
}
}
nt = back_gateN + t * N * H;
dart = dar + t * N * H;
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; ++i) {
for (int j = 0; j < H; ++j) {
int id = i * H + j;
if (mode == 1) {
dart[id] = back_dht1[id] * (1 - nt[id] * nt[id]);
} else {
dart[id] = nt[id] > 0.0f ? static_cast<float>(back_dht1[id]) : 0.0f;
}
back_dht1[id] = 0;
}
}
if (req_params != kNullOp) {
alpha = 1.0;
beta = 1.0;
// back_dht1 = dart * back_wh [N, H] = [N, H] * [H, H]
Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, H));
Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H));
linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false);
// dwh = da.T * ht1 [H, H] = [H, N] * [N, H]
Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(H, H));
Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H));
Tensor<cpu, 3, DType> d_back_ht1_tmp = Tensor<cpu, 3, DType>
(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N));
if (req_params == kAddTo) {
beta = 2.0;
// back_dwx = da.T * x [H, I] = [H, N] * [N, I] for AddTo
Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(H, I));
linalg_gemm(d_dart, d_xt, d_back_dwx, alpha, beta, true, false);
}
linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true);
}
}
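// Bias gradients for the reverse direction, with the same write/add-to
// handling as the forward case.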
if (req_params != kNullOp) {
// back_dbx = e * da [1, H] = [1, T * N] * [T * N, H]
if (req_params != kAddTo) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H; ++i) {
for (int j = 0; j < N * T; ++j) {
back_dbx[i] += dar[j * H + i];
back_dbh[i] = back_dbx[i];
}
}
} else {
const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T));
const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T));
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H * T; ++i) {
tmp_dbx.dptr_[i] = 0;
tmp_dbh.dptr_[i] = 0;
}
for (int t = T - 1; t >= 0; --t) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H; ++i) {
for (int j = 0; j < N; ++j) {
tmp_dbx[i][t] += dar[t * N * H + j * H + i];
tmp_dbh[i][t] = tmp_dbx[i][t];
}
}
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < H; ++i) {
back_dbx[i] += tmp_dbx[i][t] + back_dbx[i];
back_dbh[i] = back_dbx[i];
}
}
}
}
alpha = 1.0;
beta = 1.0;
// dx += da * back_wx [T * N, I] = [T * N, H] * [H, I] (accumulated onto the forward-direction dx)
Tensor<cpu, 2, DType> d_dar2(dar, Shape2(T * N, H));
if (req_data != kNullOp) {
Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
linalg_gemm(d_dar2, back_wx, d_dx, alpha, beta, false, false);
}
alpha = 1.0;
beta = 0.0;
// back_dwx = da.T * x [H, I] = [H, T * N] * [T * N, I]
if (req_params != kNullOp && req_params != kAddTo) {
Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(H, I));
linalg_gemm(d_dar2, x, d_back_dwx, alpha, beta, true, false);
}
}
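// Export the accumulated state gradient to dhx when requested
// ([D, N, H]; both directions when D == 2).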
if (req_state != kNullOp) {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N * H * D; ++i) {
dhx[i] = dht1[i];
}
}
}