in src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNLSTM.java [60:262]
/**
 * Computes the LSTM forward pass for the rows {@code [start, end)}, processing them in
 * row tiles of {@code row_tile_size} rows and, inside each tile, in column tiles of the
 * input/hidden dimension (32) and of the gate dimension (1024).
 * <p>
 * Array layouts as indexed by this method (all arrays are row-major dense value arrays):
 * <ul>
 * <li>{@code x}: row {@code i} starts at {@code i * x.getNumColumns()}; timestep {@code t}
 *     occupies {@code d} values starting at offset {@code t*d} within the row.</li>
 * <li>{@code w}: each row holds {@code 4*m} gate weights; rows {@code 0..d-1} are applied to
 *     the input, rows {@code d..d+m-1} to the previous hidden state.</li>
 * <li>gates: the {@code 4*m} pre-activations of a row are split into four blocks of
 *     {@code m}; the first three go through a sigmoid (i, f, o), the fourth through tanh (g).</li>
 * <li>{@code out}: {@code i*T*m + t*m + j} if {@code return_sequences}, else {@code i*m + j}
 *     (only the last hidden state is written).</li>
 * <li>{@code cout}: {@code i*m + j} (last cell state).</li>
 * <li>{@code cache_out}, {@code cache_c}: {@code t*m*n + i*m + j}.</li>
 * <li>{@code cache_ifog}: {@code 4*(t*m*n + i*m) + gate*m + j}.</li>
 * </ul>
 * If the class-level {@code kahan} flag is set, the gate pre-activations are accumulated
 * with {@code KahanObject}/{@code KahanPlus}; otherwise with plain double additions. The
 * order in which products are added to each gate value is fixed (ascending input index,
 * then ascending hidden index), so results do not depend on the tile sizes.
 * <p>
 * NOTE(review): all arrays are obtained via {@code getDenseBlockValues()}; this presumably
 * requires dense inputs/outputs that fit into a single dense block - confirm with callers.
 *
 * @param n row count used as the per-timestep stride of the caches (presumably the full batch size)
 * @param d number of input features per timestep
 * @param T number of timesteps
 * @param m number of hidden units
 * @param start first row (inclusive) handled by this call
 * @param end last row (exclusive) handled by this call
 * @param x input sequences; its contribution is skipped if {@code x} or {@code w} is not allocated
 * @param w weights; if not allocated, neither input nor hidden-state contributions are added
 * @param bias gate bias of length {@code 4*m}; treated as zeros if not allocated
 * @param out0 initial hidden state; treated as zeros if it has no dense values
 * @param c0 initial cell state; treated as zeros if it has no dense values
 * @param return_sequences if true, the hidden state of every timestep is written to {@code out},
 *        otherwise only the hidden state after the last timestep
 * @param out output hidden states (written)
 * @param cout output cell state after the last timestep (written)
 * @param cache_out per-timestep hidden states (written)
 * @param cache_c per-timestep cell states (written)
 * @param cache_ifog per-timestep activated gate values (written)
 */
public static void lstmTile(int n, int d, int T, int m, int start, int end, MatrixBlock x, MatrixBlock w,
	MatrixBlock bias, MatrixBlock out0, MatrixBlock c0, boolean return_sequences,
	MatrixBlock out, MatrixBlock cout, MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock cache_ifog){
	//input and output arrays
	final double[] c_0_values = c0.getDenseBlockValues();
	final double[] bias_values = bias.getDenseBlockValues();
	final double[] out0_values = out0.getDenseBlockValues();
	final double[] w_values = w.getDenseBlockValues();
	final double[] x_values = x.getDenseBlockValues();
	final double[] out_values = out.getDenseBlockValues();
	final double[] cout_values = cout.getDenseBlockValues();
	final double[] cache_out_values = cache_out.getDenseBlockValues();
	final double[] cache_c_values = cache_c.getDenseBlockValues();
	final double[] cache_ifog_values = cache_ifog.getDenseBlockValues();

	//constants
	final boolean biasAllocated = bias.isAllocated();
	final boolean xAllocated = x.isAllocated();
	final boolean wAllocated = w.isAllocated();
	final boolean useKahan = kahan;
	final int tile_size_i = row_tile_size;
	final int tile_size_j = 32;
	final int tile_size_k = 1024;
	final int m_4 = 4*m;
	final int m_T = T*m;
	final int x_cols = x.getNumColumns();

	//scratch buffers, allocated once and reused by every row tile;
	//only the accumulator matching the summation mode is created
	final int[] pos_in_x = new int[tile_size_i];
	final double[] ifog = useKahan ? null : new double[tile_size_i*m_4];
	final KahanObject[] kbuff = useKahan ? new KahanObject[tile_size_i*m_4] : null;
	if( useKahan ) {
		for( int i = 0; i < kbuff.length; i++ )
			kbuff[i] = new KahanObject(0,0);
	}
	final KahanPlus kplus = useKahan ? getKahanPlusFnObject() : null;
	final double[] out_prev_values = new double[m * tile_size_i];
	final double[] c_prev_values = new double[m * tile_size_i];

	for( int bi = start; bi < end; bi += tile_size_i ) {
		final int bimin = Math.min(end, bi + tile_size_i);
		final int tile_rows = bimin - bi;
		//the rows of a tile are contiguous in all (N,M) shaped arrays
		final int tile_offset = bi * m;
		final int tile_len = tile_rows * m;

		//init out_prev with the initial hidden state (zeros if absent)
		if( out0_values != null )
			System.arraycopy(out0_values, tile_offset, out_prev_values, 0, tile_len);
		else
			java.util.Arrays.fill(out_prev_values, 0.0);

		//init c_prev with the initial cell state (zeros if absent)
		if( c_0_values != null )
			System.arraycopy(c_0_values, tile_offset, c_prev_values, 0, tile_len);
		else
			java.util.Arrays.fill(c_prev_values, 0.0);

		//calculate position of input token sequence for all rows in tile
		for( int r = 0; r < tile_rows; r++ )
			pos_in_x[r] = (bi + r) * x_cols;

		//iterate timesteps
		for( int t = 0; t < T; t++ ) {
			final int pos_in_sequence = t * d;
			final int offset_t_internal = t*m;
			final int offset_t = offset_t_internal*n;
			final int offset_t2 = offset_t*4;

			//init ifog with bias values for all rows in the row tile
			for( int k = 0; k < m_4; k++ ) {
				final double b = biasAllocated ? bias_values[k] : 0.0;
				for( int r = 0; r < tile_rows; r++ ) {
					if( useKahan )
						kbuff[k + r * m_4].set(b, 0.0);
					else
						ifog[k + r * m_4] = b;
				}
			}

			//add the input token to the ifog-gates: iterate input token tiles (bj)
			//and weight tiles (bk); nothing to add if x or w is empty
			if( xAllocated && wAllocated ) {
				for( int bj = 0; bj < d; bj += tile_size_j ) {
					final int bjmin = Math.min(d, bj + tile_size_j);
					for( int bk = 0; bk < m_4; bk += tile_size_k ) {
						final int bkmin = Math.min(m_4, bk + tile_size_k);
						for( int r = 0; r < tile_rows; r++ ) {
							final int ifog_base = r * m_4;
							final int x_base = pos_in_x[r] + pos_in_sequence;
							for( int j = bj; j < bjmin; j++ ) {
								final int offset_w = j * m_4;
								final double xv = x_values[x_base + j];
								if( useKahan ) {
									for( int k = bk; k < bkmin; k++ )
										kplus.execute2(kbuff[ifog_base + k], xv * w_values[offset_w + k]);
								}
								else {
									for( int k = bk; k < bkmin; k++ )
										ifog[ifog_base + k] += xv * w_values[offset_w + k];
								}
							}
						}
					}
				}
			}

			//add the hidden state to the ifog-gates: iterate hidden state tiles (bj)
			//and weight tiles (bk); recurrent weights start at row d of w
			if( wAllocated ) {
				for( int bj = 0; bj < m; bj += tile_size_j ) {
					final int bjmin = Math.min(m, bj + tile_size_j);
					for( int bk = 0; bk < m_4; bk += tile_size_k ) {
						final int bkmin = Math.min(m_4, bk + tile_size_k);
						for( int r = 0; r < tile_rows; r++ ) {
							final int offset_out_prev = r * m;
							final int ifog_base = offset_out_prev * 4;
							for( int j = bj; j < bjmin; j++ ) {
								final int offset_w = (j + d) * m_4;
								final double hv = out_prev_values[offset_out_prev + j];
								if( useKahan ) {
									for( int k = bk; k < bkmin; k++ )
										kplus.execute2(kbuff[ifog_base + k], hv * w_values[offset_w + k]);
								}
								else {
									for( int k = bk; k < bkmin; k++ )
										ifog[ifog_base + k] += hv * w_values[offset_w + k];
								}
							}
						}
					}
				}
			}

			//calculate new hidden state for the current tile (elementwise operations only)
			for( int r = 0; r < tile_rows; r++ ) {
				final int i = bi + r;
				//offsets of the four gate blocks in the tile-local accumulator
				final int offset_internal_i = r * m_4;
				final int offset_internal_f = offset_internal_i + m;
				final int offset_internal_o = offset_internal_f + m;
				final int offset_internal_g = offset_internal_o + m;
				final int offset_c_internal = r * m;
				//offsets into the global output and cache arrays
				final int offset_out = i*m_T + offset_t_internal;
				final int offset_i = i*m;
				final int offset_cache = offset_t + offset_i;
				final int offset_cache_i = offset_t2 + offset_i*4;
				final int offset_cache_f = offset_cache_i + m;
				final int offset_cache_o = offset_cache_f + m;
				final int offset_cache_g = offset_cache_o + m;

				for( int j = 0; j < m; j++ ) {
					final double sum_i = useKahan ? kbuff[offset_internal_i + j]._sum : ifog[offset_internal_i + j];
					final double sum_f = useKahan ? kbuff[offset_internal_f + j]._sum : ifog[offset_internal_f + j];
					final double sum_o = useKahan ? kbuff[offset_internal_o + j]._sum : ifog[offset_internal_o + j];
					final double sum_g = useKahan ? kbuff[offset_internal_g + j]._sum : ifog[offset_internal_g + j];

					//sigmoid for input, forget and output gate, tanh for the g gate
					final double ig = 1.0 / (FastMath.exp(-sum_i) + 1.0);
					final double fg = 1.0 / (FastMath.exp(-sum_f) + 1.0);
					final double og = 1.0 / (FastMath.exp(-sum_o) + 1.0);
					final double gg = FastMath.tanh(sum_g);

					//c_prev_values.shape = (N,M)
					final double c = c_prev_values[offset_c_internal + j] * fg + ig * gg;
					final double o = FastMath.tanh(c) * og;

					//out.shape = (N,T*M)
					if( return_sequences )
						out_values[offset_out + j] = o;

					//set caches
					cache_out_values[offset_cache + j] = o;
					cache_c_values[offset_cache + j] = c;
					cache_ifog_values[offset_cache_i + j] = ig;
					cache_ifog_values[offset_cache_f + j] = fg;
					cache_ifog_values[offset_cache_o + j] = og;
					cache_ifog_values[offset_cache_g + j] = gg;

					//carry the state over to the next timestep
					c_prev_values[offset_c_internal + j] = c;
					out_prev_values[offset_c_internal + j] = o;
				}
			}
		}

		//write the final cell state and, if no sequences are returned, the final hidden state
		System.arraycopy(c_prev_values, 0, cout_values, tile_offset, tile_len);
		if( !return_sequences )
			System.arraycopy(out_prev_values, 0, out_values, tile_offset, tile_len);
	}
}