void compute_betas_and_grad_kernel()

in src/operator/contrib/ctc_include/detail/gpu_ctc_kernels.h [220:452]


// Backward (beta) recursion fused with the gradient computation. Each thread
// block handles the single sequence selected by blockIdx.x: it walks the time
// steps from t = T-1 down to t = 0, keeps the current beta row in shared
// memory, combines it with the stored alpha row, reduces the result per
// distinct label value, and writes one row of `grads` per time step. At t == 0
// thread 0 also writes the backward negative log-likelihood.
//
// NOTE(review): ProbT, NT and VT are used below but are not declared in this
// excerpt — presumably template parameters (together with a __global__
// qualifier) on a line preceding this signature; confirm against the header.
// NOTE(review): loops below stride by NT in some places and by blockDim.x in
// others, and per-thread items are addressed as tid * VT + i; this presumably
// requires a launch with blockDim.x == NT — confirm at the launch site.
// NOTE(review): every shared array holds NV = NT * VT entries and is indexed
// up to S - 1 (= 2 * L); presumably the caller guarantees S <= NV — confirm.
//
// Parameters:
//   probs              Read at [out_dim * blockIdx.x + t * (out_dim * stride) + k];
//                      log() is applied to the values read.
//   label_sizes        L for this block (label_sizes[blockIdx.x]).
//   utt_length         T for this block (utt_length[blockIdx.x]).
//   repeats_in_labels  Per-block `repeats` count, used in the L + repeats vs T tests.
//   labels_with_blanks Label sequences; this block reads S = 2*L + 1 entries
//                      starting at blockIdx.x * S_memoffset.
//   alphas             Only read here: row t of this block is
//                      alpha[idx + t * S], with the block's slab starting at
//                      blockIdx.x * (S_memoffset * T_memoffset).
//   nll_forward        Negated to obtain log_partition for this block.
//   nll_backward       Output: nll_backward[blockIdx.x], written by thread 0 at t == 0.
//   grads              Output: written with the same indexing as probs.
//   stride             Multiplies out_dim to form the per-time-step offset
//                      (presumably the minibatch size — verify against caller).
//   out_dim            Number of entries copied/updated per time step for this block.
//   S_memoffset,
//   T_memoffset        Strides used to locate this block's labels and alphas.
//   blank_label        Label value treated as blank in the skip test and used
//                      for the rightmost-column update.
//
// If L + repeats > T the whole block returns immediately and neither grads
// nor nll_backward is written for it.
void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes,
                                    const int *utt_length, const int *repeats_in_labels,
                                    const int *labels_with_blanks, ProbT *alphas,
                                    const ProbT* nll_forward, ProbT *nll_backward,
                                    ProbT *grads, int stride, int out_dim,
                                    int S_memoffset, int T_memoffset, int blank_label) {

    // Binary functor used for every accumulation below
    // (presumably log(exp(a) + exp(b)) — see ctc_helper::log_plus).
    ctc_helper::log_plus<ProbT> log_plus_f;
    typedef CTASegReduce<NT, VT, ProbT, int, ctc_helper::log_plus<ProbT>> SegReduce;

    // Per-block problem sizes and offsets; everything is keyed by blockIdx.x.
    const int tid = threadIdx.x;
    const int L = label_sizes[blockIdx.x];
    const int T = utt_length[blockIdx.x];
    const int S = 2*L + 1;
    const int prob_offset = out_dim * blockIdx.x;
    const int repeats = repeats_in_labels[blockIdx.x];
    const ProbT log_partition = -nll_forward[blockIdx.x];

    const int* labels = labels_with_blanks;
    const int* label_global = &labels[blockIdx.x * S_memoffset];
    ProbT* alpha = &alphas[blockIdx.x * (S_memoffset * T_memoffset)];

    const int NV = NT * VT;

    // `result` is scratch for SegReduce::preprocessKeys; the same storage is
    // reused afterwards as the beta row, so the two uses must not overlap.
    union TempStorage {
        ProbT beta[NV];
        int result[NV];
    };

    __shared__ TempStorage temp_buffer;

    // Shared copy of this block's label sequence (padded with INT_MAX past S).
    __shared__ int label[NV];

    // Temporaries needed for segmented reduce
    // TODO: see if we can combine the shared memory requirements
    __shared__ int keys_shared[NV];
    __shared__ int gather_indices[NV];
    __shared__ ProbT output[NV];

    // Per-thread staging for the new beta row, so the old row can still be
    // read by other threads until the barrier.
    ProbT beta_val[VT];

    // Block-uniform early exit taken before any barrier; no output is written.
    if ((L + repeats) > T)
        return;

    // [start, end) is the range of positions initialized in the last row:
    // start is S-2 (or 0 when S == 1); end is S when L + repeats < T, else S-1.
    int start = S > 1 ? (S - 2) : 0;
    int end = (L + repeats < T) ? S : S-1;

    // Setup shared memory buffers
    #pragma unroll
    for (int idx = tid; idx < NV; idx += NT) {
        label[idx] = (idx < S) ? label_global[idx] : INT_MAX;
    }

    __syncthreads();

    // int flags;
    // Outputs of SegReduce::preprocessKeys: number of segments, and per-thread
    // inclusive [seg_start[j], seg_end[j]] ranges consumed by the reduction loop.
    int uniquelabels;
    int seg_start[VT];
    int seg_end[VT];

    // Sort labels and record indices from which to gather from
    {
        int key[VT];
        int gather_val[VT];

        // Each thread takes VT consecutive labels and remembers their positions.
        #pragma unroll
        for (int i = 0; i < VT; ++i) {
            const int idx = tid * VT + i;
            gather_val[i] = idx;
            key[i] = label[idx];
        }

        __syncthreads();

        // Presumably leaves the sorted keys in keys_shared and the permuted
        // positions in gather_val — verify against CTAMergesort.
        CTAMergesort<NT, VT, true, true, int, int, mgpu::less<int>>
            (key, gather_val, keys_shared, gather_indices, S, tid, mgpu::less<int>());

        __syncthreads();

        // Publish each thread's gather_val entries to shared memory.
        for (int i = 0; i < VT; ++i) {
            const int idx = tid * VT + i;
            gather_indices[idx] = gather_val[i];
        }

        __syncthreads();

        // Uses temp_buffer.result as scratch; must finish before beta is written.
        SegReduce::preprocessKeys(keys_shared, S, &uniquelabels, seg_start, seg_end,
                                  temp_buffer.result);
        __syncthreads();
    }

    // TODO: probably not necessary
    __syncthreads();

    // Fill the whole beta row with ctc_helper::neg_inf<ProbT>()
    #pragma unroll
    for (int idx = tid; idx < NV; idx += NT) {
        temp_buffer.beta[idx] = ctc_helper::neg_inf<ProbT>();
    }
    __syncthreads();

    // Initialize the two rightmost values in the last row (assuming L non-zero)
    // with log(probs) at time T-1 for positions [start, end).
    for(int i = tid; i < (end-start); i += blockDim.x)
        temp_buffer.beta[i + start] =
            log(probs[prob_offset + (T - 1) * (out_dim * stride) + label[i + start]]);

    __syncthreads();

    // Seed the shared `output` row with alpha + beta for the last time step
    // (t = T-1); the main loop skips the recurrence for that step.
    #pragma unroll
    for (int idx = tid; idx < S; idx += NT) {
        output[idx] = alpha[idx + (T - 1) * S] + temp_buffer.beta[idx];
    }

    __syncthreads();

    // Walk backward in time from the last row (t = T-1) to the first (t = 0).
    // The beta recurrence runs only for t < T-1; the gradient step runs for every t.
    for(int t = T - 1; t >= 0; --t) {

        // Start offsets into the current and next row
        const int start_cur_row = t * S;

        // Starting offset of column that we read from the probs array
        const int start_prob_col = t * (out_dim * stride);

        // Block-uniform branch, so the barriers inside are reached by all threads.
        if (t < T-1) {

            // Filling up one row at a time but going back in time from the last row
            // to the first. As in the forward pass, there is no loop dependence and we
            // do a variable length filter of maximum filter size of 3
            #pragma unroll
            for(int idx = tid, i = 0; idx < (S-1); idx += NT, i++) {
                ProbT next_sum = log_plus_f(temp_buffer.beta[idx], temp_buffer.beta[idx+1]);

                    // Skip two if not on blank and not on repeat.
                    // The (idx != S-2) test short-circuits before label[idx+2]
                    // and beta[idx+2] are touched for the last handled position.
                if ((label[idx] != blank_label) &&
                    (idx != (S-2)) && (label[idx] != label[idx+2]))
                    next_sum = log_plus_f(next_sum, temp_buffer.beta[idx+2]);

                // Staged in registers: the shared row is still being read by others.
                beta_val[i] = next_sum + log(probs[prob_offset + start_prob_col + label[idx]]);
            }

            __syncthreads();

            // Initialize values for the rightmost column since there is nothing to the right
            // Update input buffer for next iteration
            // (only done when end == S, i.e. when L + repeats < T).
            if ((tid == 0) && (end == S))
                temp_buffer.beta[(S-1)] = temp_buffer.beta[(S-1)] +
                                          log(probs[prob_offset + start_prob_col + blank_label]);

            // Commit the staged values for positions 0 .. S-2.
            #pragma unroll
            for(int idx = tid, i = 0; idx < (S-1); idx += NT, i++) {
               temp_buffer.beta[idx] = beta_val[i];
            }

            __syncthreads();

            // Beta Computation done - add to alpha and update the gradient. Reload
            // the gradient back for segmented reduce later on
            #pragma unroll
            for(int idx = tid; idx < S; idx += NT) {
               output[idx] = alpha[idx + start_cur_row] + temp_buffer.beta[idx];
            }

            __syncthreads();

        }

        __syncthreads();

        // Compute segmented reduction of output by using label as key
        {
            // Somewhat faster key value reduce
            ProbT accum[VT];

            // Segment idx is owned by thread (idx % blockDim.x); fold every
            // output entry of that segment with log_plus_f, addressed through
            // gather_indices.
            for (int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) {

                accum[j] = ctc_helper::neg_inf<ProbT>();
                for (int i = seg_start[j]; i <= seg_end[j]; ++i) {
                    accum[j] = log_plus_f(accum[j], output[gather_indices[i]]);
                }
            }
            // All reads of output must complete before it is overwritten below.
            __syncthreads();

            // Write accumulated value into output since that is not used
            for (int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) {
                output[idx] = accum[j];
            }
            __syncthreads();

            // Default gradient for this time step: a copy of probs for all
            // out_dim entries.
            for (int idx = tid; idx < out_dim; idx += blockDim.x) {
                const int grads_offset = prob_offset + start_prob_col + idx;
                grads[grads_offset] = probs[grads_offset];
            }

            __syncthreads();

            // For each segment, overwrite the entry addressed by keys_shared[idx]
            // with probs - exp(grad - log(probs) - log_partition).
            for (int idx = tid; idx < uniquelabels; idx += blockDim.x) {
                const int grads_offset = prob_offset + start_prob_col + keys_shared[idx];

                ProbT grad = output[idx];

                // Leave the default (grads == probs) when grad is 0 or -inf, or
                // when probs is 0 (which would make log(probs) -inf).
                if ((grad == 0.0) || (probs[grads_offset] == 0.0) ||
                    (grad == ctc_helper::neg_inf<ProbT>())) {
                } else {
                    grads[grads_offset] =
                        probs[grads_offset] - exp(grad - log(probs[grads_offset]) - log_partition);
                }
            }

            __syncthreads();
        }

        // Output backward log likelihood
        if ((t == 0) && (tid == 0)) {
            ProbT loglike = ctc_helper::neg_inf<ProbT>();

            // Shift [start, end) from the right edge of the row to its left edge:
            // for L != 0 this yields [1, 2) when L + repeats == T and [0, 2)
            // otherwise; for L == 0 the range is left unchanged.
            const int val = 2 * (L-1) + 1 - (((L + repeats) == T) ? 1 : 0);

            start = (-val * (L != 0) + start);
            end = (-val * (L != 0) + end);

            // Sum and return the leftmost one/two value(s) in first row
            for(int i = start; i < end; ++i)
                loglike = log_plus_f(loglike, temp_buffer.beta[i]);

            nll_backward[blockIdx.x] = -loglike;
        }

        // For some reason this is important
        __syncthreads();
    }
}