in src/operator/rnn_impl.h [328:496]
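CPU backward pass for one (forward- or reverse-direction) LSTM layer: per-step gate gradients are recomputed from the activations saved in the reserved space, then dx, dhx, dcx and the weight/bias gradients are filled according to the req_* request modes.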
void LstmBackwardSingleLayer(DType* ws,
                             DType* rs,
                             DType* tmp_buf,
                             bool bid,
                             const int T,
                             const int N,
                             const int I,
                             const int H,
                             const Tensor<cpu, 2, DType> &x,
                             const Tensor<cpu, 2, DType> &hx,
                             const Tensor<cpu, 2, DType> &cx,
                             const Tensor<cpu, 3, DType> &y,
                             const Tensor<cpu, 3, DType> &dy,
                             const Tensor<cpu, 2, DType> &dx,
                             const Tensor<cpu, 2, DType> &dhx,
                             const Tensor<cpu, 2, DType> &dcx,
                             DType* dhy_ptr,
                             DType* dcy_ptr,
                             DType* w_ptr,
                             DType* dw_ptr,
                             DType* db_ptr,
                             int req_data,
                             int req_params,
                             int req_state,
                             int req_statecell) {
  using namespace mshadow;
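  // Views over the packed parameters: wx/dwx are (4H, I), wh/dwh are (4H, H),
  // and dbx/dbh are the two 4H bias-gradient vectors packed in db_ptr.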
  const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
  Tensor<cpu, 2, DType> dwx(dw_ptr, Shape2(H * 4, I));
  Tensor<cpu, 2, DType> dwh(dw_ptr + I * H * 4, Shape2(H * 4, H));
  Tensor<cpu, 1, DType> dbx(db_ptr, Shape1(H * 4));
  Tensor<cpu, 1, DType> dbh(dbx.dptr_ + H * 4, Shape1(H * 4));
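  // Cell states and gate outputs saved by the forward pass live in the reserved
  // space; the reverse (bid) direction's block starts T * N * H * 7 elements in.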
  DType *c_ptr = bid ? rs + T * N * H * 7 : rs;
  const Tensor<cpu, 3, DType> c(c_ptr, Shape3(T, N, H));
  const Tensor<cpu, 4, DType> ifgo(c_ptr + T * N * H, Shape4(T, N, H, 4));
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
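  // For plain write requests, clear dwh and the bias gradients before they are
  // accumulated across time steps (kAddTo keeps the existing values).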
  if (req_params != kNullOp && req_params != kAddTo) {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < H * 4 * H; ++i) {
      dwh.dptr_[i] = 0;
    }
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < 4 * H; ++i) {
      dbx.dptr_[i] = 0;
      dbh.dptr_[i] = 0;
    }
  }
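  // Workspace layout: per-step pre-activation gate gradients (difgo), the running
  // dh/dc carried between steps, and htmp holding the previous step's output.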
  Tensor<cpu, 4, DType> difgo(ws, Shape4(T, N, 4, H));
  Tensor<cpu, 2, DType> dh(ws + T * N * H * 4, Shape2(N, H));
  Tensor<cpu, 2, DType> dc(dh.dptr_ + N * H, Shape2(N, H));
  Tensor<cpu, 2, DType> htmp(dc.dptr_ + N * H, Shape2(N, H));
  const int offset = bid ? H : 0;
  const DType alpha = 1.0;
  const DType beta0 = 0.0;
  const DType beta1 = 1.0;
  const DType beta2 = 2.0;
  const int cell_size = N * H;
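  // Seed the running hidden/cell gradients from dhy/dcy when provided, else zero.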
  if (dhy_ptr != NULL) {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < cell_size; ++i) {
      dh.dptr_[i] = dhy_ptr[i];
    }
  } else {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < cell_size; ++i) {
      dh.dptr_[i] = 0;
    }
  }
  if (dcy_ptr != NULL) {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < cell_size; ++i) {
      dc.dptr_[i] = dcy_ptr[i];
    }
  } else {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < cell_size; ++i) {
      dc.dptr_[i] = 0;
    }
  }
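  // Walk the steps in reverse. i indexes the saved cell/gate activations while t
  // indexes y, dy, and difgo (they run in opposite orders for the bid direction);
  // the *next tensors refer to the preceding step, or to hx/cx/dhx/dcx at i == 0.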
  for (int i = T - 1; i >= 0; --i) {
    int t = bid ? T - 1 - i : i;
    int tnext = bid ? t + 1 : t - 1;
    const Tensor<cpu, 2, DType>& dhnext = i ? dh : dhx;
    const Tensor<cpu, 2, DType>& dcnext = i ? dc : dcx;
    const Tensor<cpu, 2, DType>& hnext = i ? htmp : hx;
    const Tensor<cpu, 2, DType>& cnext = i ? c[i - 1] : cx;
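    // Element-wise backward through h_t = o * tanh(c_t) and c_t = f * c_{t-1} + i * g:
    // i, f, o use the logistic derivative, g and tanh(c_t) the tanh derivative.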
    #pragma omp parallel for num_threads(omp_threads)
    for (int jk = 0; jk < cell_size; ++jk) {
      int j = jk / H;
      int k = jk % H;
      DType tc = tanh(c[i][j][k]);
      DType it = ifgo[i][j][k][0];
      DType ft = ifgo[i][j][k][1];
      DType gt = ifgo[i][j][k][2];
      DType ot = ifgo[i][j][k][3];
      dh[j][k] += dy[t][j][k + offset];
      dc[j][k] += dh[j][k] * ot * (1 - tc * tc);
      difgo[t][j][0][k] = dc[j][k] * gt * it * (1 - it);
      difgo[t][j][1][k] = dc[j][k] * cnext[j][k] * ft * (1 - ft);
      difgo[t][j][2][k] = dc[j][k] * it * (1 - gt * gt);
      difgo[t][j][3][k] = dh[j][k] * tc * ot * (1 - ot);
      if (req_statecell != kNullOp || i > 0) {
        dcnext[j][k] = dc[j][k] * ft;
      }
      if (i) {
        htmp[j][k] = y[tnext][j][k + offset];
      }
    }
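    // Flatten this step's gate gradients to (N, 4H); dhnext = dyh * wh then carries
    // the hidden-state gradient to the preceding step (into dhx at i == 0, and is
    // skipped there when the state gradient is not requested).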
    Tensor<cpu, 2, DType> dyh(difgo[t].dptr_, Shape2(N, H * 4));
    if (req_state != kNullOp || i > 0) {
      linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false);
    }
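    // Accumulate the hidden-to-hidden weight gradient from this step (dyh^T * hnext);
    // the kAddTo branch also folds dwx in per step instead of in one pass below.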
    if (req_params != kNullOp) {
      if (req_params != kAddTo) {
        linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false);
      } else {
        linalg_gemm(dyh, hnext, dwh, alpha, beta2, true, false);
        // generate dwx every time step for AddTo
        Tensor<cpu, 2, DType> x_t(x.dptr_ + i * N * I, Shape2(N, I));
        Tensor<cpu, 2, DType> dyx_t(difgo.dptr_ + i * N * H * 4, Shape2(N, H * 4));
        linalg_gemm(dyx_t, x_t, dwx, alpha, beta2, true, false);
      }
    }
  }
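  // With every step's gate gradients in the workspace, dx and (for non-AddTo) dwx
  // are computed with a single GEMM over the whole sequence; the reverse direction
  // adds onto the dx already written by the forward direction.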
  Tensor<cpu, 2, DType> dyx(difgo.dptr_, Shape2(T * N, H * 4));
  if (req_data != kNullOp) {
    linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false);
  }
  if (req_params != kNullOp && req_params != kAddTo) {
    linalg_gemm(dyx, x, dwx, alpha, beta0, true, false);
  }
  const int row = T * N;
  const int col = H * 4;
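  // Bias gradients are the gate gradients summed over all T * N rows; dbx and dbh
  // receive the same values.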
  if (req_params != kNullOp) {
    if (req_params != kAddTo) {
      for (int i = 0; i < row; ++i) {
        #pragma omp parallel for num_threads(omp_threads)
        for (int j = 0; j < col; ++j) {
          dbx[j] += dyx[i][j];
          dbh[j] = dbx[j];
        }
      }
    } else {
      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf, Shape2(col, T));
      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + col * T, Shape2(col, T));
      #pragma omp parallel for num_threads(omp_threads)
      for (int i = 0; i < col * T; ++i) {
        tmp_dbx.dptr_[i] = 0;
        tmp_dbh.dptr_[i] = 0;
      }
      for (int t = T - 1; t >= 0; --t) {
        #pragma omp parallel for num_threads(omp_threads)
        for (int j = 0; j < col; ++j) {
          for (int i = 0; i < N; ++i) {
            tmp_dbx[j][t] += dyx[t * N + i][j];
            tmp_dbh[j][t] = tmp_dbx[j][t];
          }
        }
        #pragma omp parallel for num_threads(omp_threads)
        for (int j = 0; j < col; ++j) {
          dbx[j] += tmp_dbx[j][t] + dbx[j];
          dbh[j] += tmp_dbh[j][t] + dbh[j];
        }
      }
    }
  }
}
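
For reference, a minimal scalar sketch of the per-element math in the time-step loop above. It is illustrative only (LstmCellGrads and LstmCellBackward are not names from rnn_impl.h); it just spells out the backward of h_t = o * tanh(c_t) and c_t = f * c_{t-1} + i * g that the jk loop applies element-wise before the results are handed to the GEMMs.

#include <cmath>

struct LstmCellGrads {
  double di, df, dg, dout;  // pre-activation gate gradients (i, f, g, o)
  double dc_prev;           // gradient carried to the previous cell state
};

// it, ft, ot are post-sigmoid gate values, gt is post-tanh; ct is this step's
// cell state and c_prev the previous one; dh and dc are the incoming gradients.
inline LstmCellGrads LstmCellBackward(double it, double ft, double gt, double ot,
                                      double ct, double c_prev,
                                      double dh, double dc) {
  const double tc = std::tanh(ct);
  dc += dh * ot * (1 - tc * tc);        // through h_t = o * tanh(c_t)
  return {dc * gt * it * (1 - it),      // input gate, logistic derivative
          dc * c_prev * ft * (1 - ft),  // forget gate
          dc * it * (1 - gt * gt),      // candidate g, tanh derivative
          dh * tc * ot * (1 - ot),      // output gate
          dc * ft};                     // flows into dc for the previous step
}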