fl::Variable CPCSpecAugment::maskFunction()

in recipes/joint_training_vox_populi/cpc/CPCSpecAugment.cpp [46:91]


/**
 * Replaces randomly chosen spans along axis `dim` of `input` with the mask
 * embedding, then equalizes the number of masked positions across the batch.
 *
 * For each sequence i in [0, N) (N = input.dims(2)), `numMask` spans of
 * `mask_length` positions are drawn (spans may overlap and are clipped at the
 * end of the axis). Every sequence is then subsampled, uniformly at random, down
 * to the smallest masked-position count found in the batch. Masked positions
 * take the value of `mask_emb` (tiled to input.dims()); all other positions
 * keep `input`.
 *
 * @param input       Input to mask. Axis `dim` is the masked axis of length T
 *                    and axis 2 is iterated as the batch axis N.
 * @param mask_emb    Embedding written at masked positions; broadcast to
 *                    input.dims() via tileAs.
 * @param mask_prob   Fraction of the axis to cover with spans; the span count is
 *                    floor(mask_prob * T / mask_length).
 * @param mask_length Length of each span. Non-positive values disable masking.
 * @param dim         Masked axis. 0 reshapes the mask to (T, 1, N); any other
 *                    value reshapes it to (1, T, N) — NOTE(review): presumably
 *                    only 0 or 1 are intended, confirm with callers.
 * @return Masked tensor with the same type as `input`.
 */
fl::Variable CPCSpecAugment::maskFunction(
    const fl::Variable& input,
    const fl::Variable& mask_emb,
    double mask_prob,
    int mask_length,
    int dim) {
  int T = input.dims(dim);
  int N = input.dims(2);

  // A non-positive span length would divide by zero and convert inf to int
  // (undefined behaviour); treat it as "no spans".
  int numMask = mask_length > 0
      ? static_cast<int>((mask_prob * T) / mask_length)
      : 0;

  // (T x N) binary mask: 1 marks a position to be replaced by the embedding.
  auto mask = af::constant(0., af::dim4(T, N), f32);
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < numMask; j++) {
      // NOTE(review): assumes generateRandomInt(0, T) returns a value in
      // [0, T); the span end is clipped to the last valid position.
      int startIdx = generateRandomInt(0, T);
      int endIdx = std::min(startIdx + mask_length - 1, T - 1);
      mask(af::seq(startIdx, endIdx), i) = 1.;
    }
  }

  // Restrict by min len: every sequence keeps exactly `minLen` masked
  // positions, the smallest count over the batch. Skip this when nothing was
  // masked. af::where would yield an empty array, and af::seq(0, -1) would mean
  // "through the last element" in ArrayFire rather than an empty range.
  if (numMask > 0 && N > 0) {
    int minLen = af::min<int>(af::sum(mask, 0));
    if (minLen > 0) {
      auto maskMinLen = af::constant(0., af::dim4(T, N), f32);
      for (int i = 0; i < N; i++) {
        auto maskIdx = af::where(mask(af::span, i));
        // Sorting uniform noise yields a random permutation; keep the first
        // minLen entries as a uniformly random subset of masked positions.
        auto noise = af::randu(maskIdx.dims(0));
        af::array sortedNoise, perm;
        af::sort(sortedNoise, perm, noise);
        maskIdx = maskIdx(perm(af::seq(0, minLen - 1)));
        maskMinLen(maskIdx, i) = 1.;
      }
      // Plain handle assignment; the former `* 1` only added an extra kernel.
      mask = maskMinLen;
    }
  }

  // Insert a singleton axis so the mask lines up with input's layout before
  // tiling: (T, 1, N) when masking axis 0, otherwise (1, T, N).
  if (dim == 0) {
    mask = af::moddims(mask, af::dim4(T, 1, N));
  } else {
    mask = af::moddims(mask, af::dim4(1, T, N));
  }

  // The mask is a constant (no gradient). Keeping the blend even when the mask
  // is all zeros preserves the autograd path to mask_emb.
  auto totalMask = tileAs(fl::Variable(mask, false), input.dims());
  auto maskEmbedding = tileAs(mask_emb, input.dims());
  auto inputMasked =
      input.as(f32) * (1 - totalMask) + maskEmbedding * totalMask;

  return inputMasked.as(input.type());
}