in src/operator/contrib/ctc_include/detail/gpu_ctc_kernels.h [220:452]
// Backward pass of CTC, one utterance per thread block.
//
// What it does:
//   - Runs the beta (backward) recursion in log space over the extended label
//     sequence (labels with blanks, length S = 2*L + 1).
//   - Combines each beta row with the matching alpha row from the forward pass.
//   - Writes per-time-step gradients into `grads`.
//   - Writes the backward negative log-likelihood into nll_backward[blockIdx.x].
//
// Launch layout (inferred from the indexing below):
//   - One block per utterance. Every per-utterance array (label_sizes,
//     utt_length, repeats_in_labels, nll_forward, nll_backward) is indexed by
//     blockIdx.x.
//   - Some loops stride by NT and others by blockDim.x. The per-thread register
//     arrays (beta_val, accum, seg_start, seg_end) hold VT entries.
//   - So this assumes blockDim.x == NT and S <= NT * VT. Neither is checked here.
//
// Shared memory (static, with NV = NT * VT):
//   - temp_buffer: a union of NV ProbT / NV int;
//   - label, keys_shared, gather_indices: NV ints each;
//   - output: NV ProbT;
//   - plus any shared memory that CTAMergesort / SegReduce::preprocessKeys use
//     internally (not visible here).
//
// Parameters:
//   probs              - Class probabilities. Element (t, utterance b, class k)
//                        is read at t * out_dim * stride + b * out_dim + k, i.e.
//                        a time-major layout. stride is presumably the
//                        minibatch size — confirm against the caller. Values go
//                        through log(), so they are presumably softmax outputs,
//                        not logits.
//   label_sizes        - L, the number of labels per utterance.
//   utt_length         - T, the number of time steps per utterance.
//   repeats_in_labels  - Count of repeated labels per utterance. Utterances
//                        with L + repeats > T are skipped.
//   labels_with_blanks - Extended label sequence per utterance: S entries at a
//                        pitch of S_memoffset.
//   alphas             - Forward log-values. Row t of utterance b starts at
//                        b * S_memoffset * T_memoffset + t * S. Note that the
//                        row pitch is S, not S_memoffset.
//   nll_forward        - Forward negative log-likelihood. Its negation is used
//                        as log_partition.
//   nll_backward       - Output: backward negative log-likelihood per utterance.
//   grads              - Output, same indexing as probs. For every processed
//                        utterance, all out_dim entries of each t in [0, T) are
//                        written.
//   blank_label        - Class id of the blank symbol.
//
// Skipped utterances (L + repeats > T) return before writing grads or
// nll_backward.
void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes,
const int *utt_length, const int *repeats_in_labels,
const int *labels_with_blanks, ProbT *alphas,
const ProbT* nll_forward, ProbT *nll_backward,
ProbT *grads, int stride, int out_dim,
int S_memoffset, int T_memoffset, int blank_label) {
ctc_helper::log_plus<ProbT> log_plus_f;
typedef CTASegReduce<NT, VT, ProbT, int, ctc_helper::log_plus<ProbT>> SegReduce;
// Per-utterance scalars; all of them depend only on blockIdx.x.
const int tid = threadIdx.x;
const int L = label_sizes[blockIdx.x];
const int T = utt_length[blockIdx.x];
const int S = 2*L + 1;
const int prob_offset = out_dim * blockIdx.x;
const int repeats = repeats_in_labels[blockIdx.x];
// Negated forward NLL (a log-likelihood); normalizes the gradient below.
const ProbT log_partition = -nll_forward[blockIdx.x];
const int* labels = labels_with_blanks;
const int* label_global = &labels[blockIdx.x * S_memoffset];
ProbT* alpha = &alphas[blockIdx.x * (S_memoffset * T_memoffset)];
const int NV = NT * VT;
// temp_buffer.result is scratch for key preprocessing.
// temp_buffer.beta later holds the current beta row.
// The two alias, so beta must not be written until preprocessKeys is done.
union TempStorage {
ProbT beta[NV];
int result[NV];
};
__shared__ TempStorage temp_buffer;
__shared__ int label[NV];
// Temporaries needed for segmented reduce
// TODO: see if we can combine the shared memory requirements
__shared__ int keys_shared[NV];
__shared__ int gather_indices[NV];
__shared__ ProbT output[NV];
// Per-thread staging for the new beta row (one slot per NT-strided index).
ProbT beta_val[VT];
// Block-uniform early exit (depends only on blockIdx.x-indexed values), so no
// thread is left waiting at a __syncthreads() below. Presumably skipped because
// no valid alignment exists when L + repeats > T.
if ((L + repeats) > T)
return;
// [start, end) are the beta entries seeded at the last time step:
//   - positions S-2 and S-1 when L + repeats < T;
//   - only S-2 when L + repeats == T.
// When S == 1 (L == 0), start is 0.
int start = S > 1 ? (S - 2) : 0;
int end = (L + repeats < T) ? S : S-1;
// Stage the extended labels in shared memory. Pad up to NV with INT_MAX so the
// padding sorts after every real label under mgpu::less<int>.
#pragma unroll
for (int idx = tid; idx < NV; idx += NT) {
label[idx] = (idx < S) ? label_global[idx] : INT_MAX;
}
__syncthreads();
// int flags;
// Filled by preprocessKeys:
//   - uniquelabels: the number of distinct labels;
//   - seg_start / seg_end: per-thread inclusive segment bounds, used below as
//     indices into gather_indices.
int uniquelabels;
int seg_start[VT];
int seg_end[VT];
// Sort the labels and remember each sorted element's original position.
// Equal labels then form contiguous segments for the reduction below.
{
int key[VT];
int gather_val[VT];
// Blocked layout: thread tid owns positions tid*VT .. tid*VT + VT - 1.
#pragma unroll
for (int i = 0; i < VT; ++i) {
const int idx = tid * VT + i;
gather_val[i] = idx;
key[i] = label[idx];
}
__syncthreads();
// Presumably leaves the sorted keys in keys_shared and the matching original
// positions in gather_val (see moderngpu's CTAMergesort).
CTAMergesort<NT, VT, true, true, int, int, mgpu::less<int>>
(key, gather_val, keys_shared, gather_indices, S, tid, mgpu::less<int>());
__syncthreads();
// gather_indices[k] = original position of the k-th sorted label.
for (int i = 0; i < VT; ++i) {
const int idx = tid * VT + i;
gather_indices[idx] = gather_val[i];
}
__syncthreads();
// Uses temp_buffer.result as scratch.
// Presumably it also compacts keys_shared so that keys_shared[0 .. uniquelabels)
// are the distinct labels; the gradient write below uses them that way.
SegReduce::preprocessKeys(keys_shared, S, &uniquelabels, seg_start, seg_end,
temp_buffer.result);
__syncthreads();
}
// TODO: probably not necessary
__syncthreads();
// Reset the beta row to log(0) = -inf.
// The aliased temp_buffer.result is no longer needed.
#pragma unroll
for (int idx = tid; idx < NV; idx += NT) {
temp_buffer.beta[idx] = ctc_helper::neg_inf<ProbT>();
}
__syncthreads();
// Seed the last time step (T-1): beta[s] = log(prob of label[s]) for s in
// [start, end). Every other entry stays -inf.
for(int i = tid; i < (end-start); i += blockDim.x)
temp_buffer.beta[i + start] =
log(probs[prob_offset + (T - 1) * (out_dim * stride) + label[i + start]]);
__syncthreads();
// output[s] = alpha[T-1][s] + beta[T-1][s] (log space), stored in shared
// memory for the reduction at t == T-1.
#pragma unroll
for (int idx = tid; idx < S; idx += NT) {
output[idx] = alpha[idx + (T - 1) * S] + temp_buffer.beta[idx];
}
__syncthreads();
// Walk backward in time starting at the last row.
//   - At t == T-1, the beta row and output were prepared above, so only the
//     reduction/gradient part runs.
//   - For t < T-1, the beta recursion runs first.
for(int t = T - 1; t >= 0; --t) {
// Offset of row t in this utterance's alpha matrix (row pitch S).
const int start_cur_row = t * S;
// Offset of time step t in probs/grads (add prob_offset and a class id).
const int start_prob_col = t * (out_dim * stride);
if (t < T-1) {
// Compute row t from row t+1, which is still in temp_buffer.beta.
// Entries within a row are independent of each other. Each one log-sums up
// to three successors (s, s+1, s+2), then adds log(prob of label[s]) at time t.
// The last position S-1 is handled separately below.
#pragma unroll
for(int idx = tid, i = 0; idx < (S-1); idx += NT, i++) {
ProbT next_sum = log_plus_f(temp_buffer.beta[idx], temp_buffer.beta[idx+1]);
// The s -> s+2 transition is allowed only when all of these hold:
//   - label[s] is not blank;
//   - s+2 is in range (s != S-2);
//   - label[s] != label[s+2].
if ((label[idx] != blank_label) &&
(idx != (S-2)) && (label[idx] != label[idx+2]))
next_sum = log_plus_f(next_sum, temp_buffer.beta[idx+2]);
// i counts this thread's iterations; assumes it stays < VT (S-1 <= NT*VT).
beta_val[i] = next_sum + log(probs[prob_offset + start_prob_col + label[idx]]);
}
// All reads of row t+1 must finish before any thread overwrites it.
__syncthreads();
// Position S-1 has no successors, so it keeps only its self-transition.
// When end == S-1 it was never seeded and stays -inf, so it is left untouched.
if ((tid == 0) && (end == S))
temp_buffer.beta[(S-1)] = temp_buffer.beta[(S-1)] +
log(probs[prob_offset + start_prob_col + blank_label]);
// Publish the new row. tid 0's write above touches only index S-1, which is
// disjoint from the indices written here.
#pragma unroll
for(int idx = tid, i = 0; idx < (S-1); idx += NT, i++) {
temp_buffer.beta[idx] = beta_val[i];
}
__syncthreads();
// output[s] = alpha[t][s] + beta[t][s] (log space), consumed by the
// reduction below.
#pragma unroll
for(int idx = tid; idx < S; idx += NT) {
output[idx] = alpha[idx + start_cur_row] + temp_buffer.beta[idx];
}
__syncthreads();
}
__syncthreads();
// For each distinct label, log-sum output[] over every extended position that
// carries that label.
{
// Each thread reduces the segments described by its seg_start/seg_end slots.
// j indexes those slots and accum.
ProbT accum[VT];
for (int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) {
accum[j] = ctc_helper::neg_inf<ProbT>();
for (int i = seg_start[j]; i <= seg_end[j]; ++i) {
accum[j] = log_plus_f(accum[j], output[gather_indices[i]]);
}
}
__syncthreads();
// All reads of output[] for this t finished at the barrier above.
// Reuse output so that output[k] holds the log-sum for the k-th distinct label.
for (int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) {
output[idx] = accum[j];
}
__syncthreads();
// Default gradient for every class at time t: grads = probs.
for (int idx = tid; idx < out_dim; idx += blockDim.x) {
const int grads_offset = prob_offset + start_prob_col + idx;
grads[grads_offset] = probs[grads_offset];
}
__syncthreads();
// For classes that appear in the label sequence:
//   grads = p - exp(grad) / (p * exp(log_partition)),
// evaluated in log space below.
for (int idx = tid; idx < uniquelabels; idx += blockDim.x) {
const int grads_offset = prob_offset + start_prob_col + keys_shared[idx];
ProbT grad = output[idx];
// The update is skipped, leaving grads == probs, when:
//   - the probability is 0 (log would be -inf), or
//   - the accumulated value is -inf.
// NOTE(review): `grad == 0.0` tests the log-space sum, i.e. it also skips when
// the accumulated alpha*beta equals exactly 1 — confirm this is intended.
if ((grad == 0.0) || (probs[grads_offset] == 0.0) ||
(grad == ctc_helper::neg_inf<ProbT>())) {
} else {
grads[grads_offset] =
probs[grads_offset] - exp(grad - log(probs[grads_offset]) - log_partition);
}
}
__syncthreads();
}
// At t == 0, tid 0 log-sums the first beta row over the valid start positions
// and stores the negated result as the backward NLL.
if ((t == 0) && (tid == 0)) {
ProbT loglike = ctc_helper::neg_inf<ProbT>();
// For L != 0 this moves [start, end) to [0, 2), or to [1, 2) when
// L + repeats == T. For L == 0 the range is unchanged.
const int val = 2 * (L-1) + 1 - (((L + repeats) == T) ? 1 : 0);
start = (-val * (L != 0) + start);
end = (-val * (L != 0) + end);
// Sum the leftmost one or two value(s) in the first row.
for(int i = start; i < end; ++i)
loglike = log_plus_f(loglike, temp_buffer.beta[i]);
nll_backward[blockIdx.x] = -loglike;
}
// NOTE(review): the original comment here only said "For some reason this is
// important". The iteration already ends with a barrier after the gradient
// write, so the reason is not evident from this code. Keep this barrier unless
// it is shown to be redundant (e.g. with compute-sanitizer racecheck/synccheck).
__syncthreads();
}
}