libvmaf/src/feature/cuda/integer_vif/vif_statistics.cuh
/**
*
* Copyright 2016-2023 Netflix, Inc.
* Copyright 2021 NVIDIA Corporation.
*
* Licensed under the BSD+Patent License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSDplusPatent
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#include "cuda_helper.cuh"
#include "cuda/integer_vif_cuda.h"
#include "common.h"
// Normalize a 32-bit value down to its 16 most significant bits. Assumes
// temp >= 2^16 (the only caller passes sigma_nsq + sigma1_sq >= 2^18), so
// the shift count k is non-negative. *x records the power-of-two
// correction applied, so the original value is approximately temp << -*x.
__device__ __forceinline__ uint16_t get_best16_from32(uint32_t temp, int *x) {
int k = __clz(temp);
k = 16 - k; // number of significant bits above the top 16
temp = temp >> k; // keep only the 16 most significant bits
*x = -k; // remember the shift for the log-domain correction
return temp;
}
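// Illustrative values (not from the original source): for temp = 0x40000
// (2^18), __clz returns 13, so k = 3, the result is 0x8000 (2^15), and
// *x = -3; a caller can then recover log2(2^18) as log2(0x8000) + 3.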
// Normalize a 64-bit value to its 16 most significant bits. *x records the
// power-of-two correction: positive when the value was shifted left (fewer
// than 16 significant bits), negative when it was shifted right.
__device__ __forceinline__ uint16_t get_best16_from64(uint64_t temp, int *x) {
int k = __clzll(temp);
if (k > 48) { // fewer than 16 significant bits: shift left to fill them
k -= 48;
temp = temp << k;
*x = k;
} else if (k < 47) { // 18 or more significant bits: shift right
k = 48 - k;
temp = temp >> k;
*x = -k;
} else { // exactly 16 or 17 significant bits: at most one shift needed
*x = 0;
if (temp >> 16) {
temp = temp >> 1;
*x = -1;
}
}
return (uint16_t)temp;
}
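// Illustrative values (not from the original source): for temp = 0x12345
// (17 significant bits) the else branch fires and the value is halved to
// 0x91A2 with *x = -1; for temp = 0x123 (9 bits), __clzll returns 55, so
// the value is shifted left by 7 to 0x9180 and *x = 7.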
// Q11 fixed-point log2: returns round(log2(i) * 2048). This mirrors the
// look-up table built by the CPU-side log_generate(); valid inputs are the
// 16-bit normalized values produced by the get_best16_* helpers above.
__device__ __forceinline__ uint16_t log_generate(int i) {
// if (i < 32767 || i >= 65536)
//   return 0;
return (uint16_t)roundf(log2f(float(i)) * 2048.f);
}
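// Illustrative value (not from the original source):
// log_generate(0x8000) = round(log2(32768) * 2048) = 15 * 2048 = 30720.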
// Accumulates per-thread VIF numerator/denominator statistics for the
// consecutive pixels preloaded into mu1/mu2 and the xx/yy/xy filter
// outputs, starting at column cur_col of a w x h plane.
template <typename aligned_dtype = uint4>
__device__ __forceinline__ void vif_statistic_calculation(
const aligned_dtype &mu1, const aligned_dtype &mu2,
const aligned_dtype &xx_filt, const aligned_dtype &yy_filt,
const aligned_dtype &xy_filt, int cur_col, int w, int h,
double vif_enhn_gain_limit, vif_accums &thread_accum) {
// fixed-point equivalent of 2.0 in Q16 format (2 * 65536)
constexpr int32_t sigma_nsq = 65536 << 1;
const uint32_t *mu1_val = reinterpret_cast<const uint32_t *>(&mu1);
const uint32_t *mu2_val = reinterpret_cast<const uint32_t *>(&mu2);
const uint32_t *xx_filt_val = reinterpret_cast<const uint32_t *>(&xx_filt);
const uint32_t *yy_filt_val = reinterpret_cast<const uint32_t *>(&yy_filt);
const uint32_t *xy_filt_val = reinterpret_cast<const uint32_t *>(&xy_filt);
constexpr int aligned_dtype_values = sizeof(aligned_dtype) / sizeof(int32_t);
// accumulate per-thread sums over all preloaded values
for (int v = 0; v < aligned_dtype_values; ++v) {
if (cur_col + v < w) {
int64_t num_val, den_val;
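// rounding trick: (a * b + 2^31) >> 32 keeps the upper 32 bits of the
// 64-bit product, rounded to nearest (2147483648 == 2^31)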
uint32_t mu1_sq_val =
(uint32_t)((((uint64_t)mu1_val[v] * mu1_val[v]) + 2147483648) >> 32);
uint32_t mu2_sq_val =
(uint32_t)((((uint64_t)mu2_val[v] * mu2_val[v]) + 2147483648) >> 32);
uint32_t mu1_mu2_val =
(uint32_t)((((uint64_t)mu1_val[v] * mu2_val[v]) + 2147483648) >> 32);
int32_t sigma1_sq = (int32_t)(xx_filt_val[v] - mu1_sq_val);
int32_t sigma2_sq = (int32_t)(yy_filt_val[v] - mu2_sq_val);
int32_t sigma12 = (int32_t)(xy_filt_val[v] - mu1_mu2_val);
sigma1_sq = max(sigma1_sq, 0);
sigma2_sq = max(sigma2_sq, 0);
// eps is effectively zero here: an int will never be less than 1.0e-10,
// so it could equally be changed to one
const double eps = 65536 * 1.0e-10;
double g = 0.0;
int32_t sv_sq = sigma2_sq; // unconditionally recomputed below
// if sigma1_sq > 0 then sigma1_sq >= 1 and thus greater than eps, so only
// the case sigma1_sq == 0 matters; and since g can only be < 0 when
// sigma12 < 0, we check for that as well
double tmp = sigma12 / (sigma1_sq + eps);
if (sigma12 > 0 && sigma1_sq != 0 && sigma2_sq != 0) {
g = tmp;
}
sv_sq = sigma2_sq - g * sigma12; // residual (noise) variance
sv_sq = (uint32_t)(max(sv_sq, (int32_t)eps)); // (int32_t)eps == 0: clamp at zero
g = min(g, vif_enhn_gain_limit); // cap the enhancement gain
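// Illustrative values (not from the original source): with
// sigma1_sq = 4 * 65536 and sigma12 = 2 * 65536, g is ~0.5 and
// sv_sq = sigma2_sq - 65536, i.e. the variance explained by the linear
// predictor g is removed from sigma2_sq.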
if (sigma1_sq >= sigma_nsq) {
uint32_t log_den_stage1 = (uint32_t)(sigma_nsq + sigma1_sq);
int x;
uint16_t log_den1 = get_best16_from32(log_den_stage1, &x);
/**
* log values are taken from the look-up table generated by the
* log_generate() function, which is called in integer_combo_threadfunc.
* In floating point, den_val is log2(1 + sigma1_sq / 2); here it is
* converted to the equivalent log2(2 + sigma1_sq) - log2(2), i.e.
* log2(2 * 65536 + sigma1_sq) - 17, multiplied by 2048 since
* log_value = log2(i) * 2048 for i = 16384 to 65535. The shift x is
* tracked because only the best 16 bits are kept.
*/
thread_accum.num_x++;
thread_accum.x += x;
den_val = log_generate(log_den1);
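// Illustrative values (not from the original source): for
// sigma1_sq == sigma_nsq, log_den_stage1 = 2^18, get_best16_from32 yields
// log_den1 = 0x8000 with x = -3, so den_val = 30720; the accumulated x
// lets the final reduction restore 2048 * log2(2^18) = 30720 + 2048 * 3.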
if (sigma12 >= 0) {
/**
* Floating-point reference:
*   num_val = log2f(1.0f + (g * g * sigma1_sq) / (sv_sq + sigma_nsq));
*
* In fixed point this is computed as
*   numerator = log2((sv_sq + sigma_nsq) + g * g * sigma1_sq)
*             - log2(sv_sq + sigma_nsq)
*/
int x1, x2;
uint32_t numer1 = (sv_sq + sigma_nsq);
int64_t numer1_tmp =
(int64_t)((g * g * sigma1_sq)) + numer1; // numerator
uint16_t numlog = get_best16_from64((uint64_t)numer1_tmp, &x1);
// we do not check numer1 > 0: sv_sq >= 0 and sigma_nsq > 0, so the
// sum is always > 0
uint16_t denlog = get_best16_from64((uint64_t)numer1, &x2);
thread_accum.x2 += (x2 - x1);
num_val = log_generate(numlog) - log_generate(denlog);
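// the x2 - x1 term accumulated above compensates for the normalization
// shifts: 2048 * log2(numer1_tmp / numer1) equals
// num_val + 2048 * (x2 - x1)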
thread_accum.num_log += num_val;
thread_accum.den_log += den_val;
} else {
// negative covariance: the numerator term contributes nothing, but the
// denominator is still accumulated
num_val = 0;
thread_accum.num_log += num_val;
thread_accum.den_log += den_val;
}
} else {
// low-energy region (sigma1_sq < sigma_nsq): accumulated separately in
// the non-log domain
den_val = 1;
thread_accum.num_non_log += sigma2_sq;
thread_accum.den_non_log += den_val;
}
}
}
}
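/**
* A minimal usage sketch, not part of the original file: the real kernels
* live in the integer_vif CUDA sources and may map threads differently.
* This hypothetical kernel preloads four adjacent columns per thread as a
* uint4 from each filtered plane and feeds them to
* vif_statistic_calculation. The buffer names, the padded-stride
* assumption, and the omitted block/grid reduction of thread_accum are all
* illustrative.
*/
__global__ void vif_statistic_sketch(
const uint32_t *mu1_buf, const uint32_t *mu2_buf, const uint32_t *xx_buf,
const uint32_t *yy_buf, const uint32_t *xy_buf, int stride, int w, int h,
double vif_enhn_gain_limit, vif_accums *out) {
const int row = blockIdx.y * blockDim.y + threadIdx.y;
const int col = (blockIdx.x * blockDim.x + threadIdx.x) * 4; // 4 lanes per uint4
vif_accums thread_accum = {}; // zero all numerator/denominator sums
if (row < h && col < w) {
// assumes stride is a multiple of 4 so the uint4 loads stay 16-byte aligned
const int idx = row * stride + col;
const uint4 mu1 = *reinterpret_cast<const uint4 *>(mu1_buf + idx);
const uint4 mu2 = *reinterpret_cast<const uint4 *>(mu2_buf + idx);
const uint4 xx = *reinterpret_cast<const uint4 *>(xx_buf + idx);
const uint4 yy = *reinterpret_cast<const uint4 *>(yy_buf + idx);
const uint4 xy = *reinterpret_cast<const uint4 *>(xy_buf + idx);
vif_statistic_calculation(mu1, mu2, xx, yy, xy, col, w, h,
vif_enhn_gain_limit, thread_accum);
}
// a real kernel would now reduce thread_accum across the block (e.g. warp
// shuffles plus atomics) before writing to out; omitted in this sketch
(void)out;
}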