in recipes/joint_training_vox_populi/cpc/CPCSpecAugment.cpp [46:91]
/**
 * Builds a random mask along axis `dim` and replaces the masked positions of
 * `input` with `mask_emb`.
 *
 * For each of the N examples (N = input.dims(2)), `numMask` spans of up to
 * `mask_length` consecutive positions are sampled along `dim`; spans may
 * overlap. Each example's mask is then randomly subsampled so that every
 * example ends up with the same number of masked positions, namely the
 * minimum over the batch.
 *
 * @param input       Input variable; axis 2 is treated as the batch axis.
 * @param mask_emb    Values written into masked positions. They are tiled with
 *                    tileAs to input.dims(); presumably broadcastable along
 *                    the masked axis and the batch — TODO confirm.
 * @param mask_prob   Controls the number of spans per example:
 *                    numMask = truncate(mask_prob * T / mask_length).
 * @param mask_length Span length. Non-positive values disable masking.
 * @param dim         Axis to mask: 0, or any other value, which is treated
 *                    as axis 1.
 * @return `input` with masked positions replaced, cast back to input.type().
 *         The blend itself is computed in f32.
 */
fl::Variable CPCSpecAugment::maskFunction(
    const fl::Variable& input,
    const fl::Variable& mask_emb,
    double mask_prob,
    int mask_length,
    int dim) {
  const int T = input.dims(dim);
  const int N = input.dims(2);
  // A non-positive span length would divide by zero (or go negative);
  // converting an infinite double to int is undefined, so disable masking.
  const int numMask =
      mask_length > 0 ? static_cast<int>((mask_prob * T) / mask_length) : 0;

  // Sample numMask (possibly overlapping) spans per example.
  auto mask = af::constant(0., af::dim4(T, N), f32);
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < numMask; j++) {
      // NOTE(review): assumes generateRandomInt(0, T) returns a value in
      // [0, T). An inclusive upper bound would produce an inverted af::seq.
      int startIdx = generateRandomInt(0, T);
      int endIdx = std::min(startIdx + mask_length - 1, T - 1);
      mask(af::seq(startIdx, endIdx), i) = 1.;
    }
  }

  // Restrict every example to the batch-minimum number of masked positions
  // by keeping a random subset of each example's masked indices.
  const int minLen =
      (numMask > 0 && N > 0) ? af::min<int>(af::sum(mask, 0)) : 0;
  auto maskMinLen = af::constant(0., af::dim4(T, N), f32);
  // With minLen == 0 nothing must be kept. af::seq(0, minLen - 1) would be
  // af::seq(0, -1), i.e. "through the last element" (af::end), and would also
  // run sort/index on empty arrays. So the restriction is skipped entirely.
  if (minLen > 0) {
    for (int i = 0; i < N; i++) {
      auto maskIdx = af::where(mask(af::span, i));
      // Sorting uniform noise yields a random permutation of the indices.
      af::array val, idx;
      af::sort(val, idx, af::randu(maskIdx.dims(0)));
      idx = idx(af::seq(0, minLen - 1));
      maskMinLen(maskIdx(idx), i) = 1.;
    }
  }
  mask = maskMinLen;

  // Put the masked axis in position `dim` and the batch in position 2 so the
  // mask can be tiled to the input's shape.
  if (dim == 0) {
    mask = af::moddims(mask, af::dim4(T, 1, N));
  } else {
    mask = af::moddims(mask, af::dim4(1, T, N));
  }
  auto totalMask = tileAs(fl::Variable(mask, false), input.dims());
  auto maskEmbedding = tileAs(mask_emb, input.dims());
  // Blend in f32: unmasked positions keep the input, masked ones take the
  // embedding.
  auto inputMasked =
      input.as(f32) * (1 - totalMask) + maskEmbedding * totalMask;
  return inputMasked.as(input.type());
}