recipes/joint_training_vox_populi/cpc/CPCSpecAugment.cpp (95 lines of code) (raw):
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <sstream>
#include <stdexcept>
#include "CPCSpecAugment.h"
namespace w2l {
/**
 * Stores the augmentation settings and rejects inconsistent combinations.
 *
 * @param tWarpW stored as timeWarpW_ (not validated here).
 * @param fMaskF stored as freqMaskF_; must be > 0 whenever nFMask > 0.
 * @param nFMask stored as numFreqMask_.
 * @param tMaskT stored as timeMaskT_; must be > 0 whenever nTMask > 0.
 * @param tMaskP stored as timeMaskP_; must lie in (0, 1] whenever nTMask > 0.
 * @param nTMask stored as numTimeMask_.
 * @param mStrategy stored as maskStrategy_.
 * @throws std::invalid_argument when masking is enabled (count > 0) but the
 *         matching size / probability is out of range.
 */
CPCSpecAugment::CPCSpecAugment(
    int tWarpW,
    int fMaskF,
    int nFMask,
    int tMaskT,
    float tMaskP,
    int nTMask,
    MaskingStrategy mStrategy /* = MaskingStrategy::ZERO */)
    : timeWarpW_(tWarpW),
      freqMaskF_(fMaskF),
      numFreqMask_(nFMask),
      timeMaskT_(tMaskT),
      timeMaskP_(tMaskP),
      numTimeMask_(nTMask),
      maskStrategy_(mStrategy) {
  const bool freqMaskingEnabled = numFreqMask_ > 0;
  if (freqMaskingEnabled && freqMaskF_ <= 0) {
    throw std::invalid_argument("invalid arguments for frequency masking.");
  }

  // The settings are only checked when the corresponding mask count is
  // positive; a zero count leaves the remaining values unchecked.
  const bool timeMaskingEnabled = numTimeMask_ > 0;
  if (timeMaskingEnabled) {
    const bool badLength = timeMaskT_ <= 0;
    const bool badProbability = timeMaskP_ <= 0 || timeMaskP_ > 1.0;
    if (badLength || badProbability) {
      throw std::invalid_argument("invalid arguments for time masking.");
    }
  }
}
/**
 * Draws an integer uniformly from the half-open range [low, high) using the
 * member engine eng_.
 *
 * @param low inclusive lower bound.
 * @param high exclusive upper bound; must be strictly greater than `low`.
 * @return a value v with low <= v < high.
 * @throws std::invalid_argument if the range is empty (high <= low).
 */
int CPCSpecAugment::generateRandomInt(int low, int high) {
  // std::uniform_int_distribution requires a <= b; violating that is
  // undefined behaviour. Checking first also avoids evaluating `high - 1`
  // when high is the smallest representable int.
  if (high <= low) {
    throw std::invalid_argument(
        "CPCSpecAugment::generateRandomInt: empty range, high must be > low.");
  }
  std::uniform_int_distribution<int> uniformDist(low, high - 1);
  return uniformDist(eng_);
}
/**
 * Replaces randomly selected positions along dimension `dim` of `input` with
 * `mask_emb`.
 *
 * With T = input.dims(dim) and N = input.dims(2) (presumably the batch
 * dimension -- TODO confirm against callers):
 *  1. For each of the N columns, floor(mask_prob * T / mask_length) spans of
 *     up to `mask_length` positions are marked, each starting at a uniformly
 *     random index in [0, T). Spans are clipped at T - 1 and may overlap.
 *  2. The smallest number of marked positions over all columns is computed and
 *     every column is randomly subsampled down to exactly that many marked
 *     positions.
 *  3. The resulting 0/1 mask is reshaped to (T, 1, N) when dim == 0 and to
 *     (1, T, N) otherwise, tiled to input's dims, and used to blend `input`
 *     with the tiled `mask_emb`.
 *
 * @param input variable to mask.
 * @param mask_emb value written at masked positions (tiled to input's dims).
 * @param mask_prob scales how many spans are drawn per column.
 * @param mask_length maximum length of each span.
 * @param dim dimension of `input` along which positions are masked.
 * @return the masked variable, cast back to input's type. `input` itself is
 *         returned unchanged when no span can be placed (non-positive
 *         mask_length, empty T or N, or fewer than one span requested).
 */
fl::Variable CPCSpecAugment::maskFunction(
    const fl::Variable& input,
    const fl::Variable& mask_emb,
    double mask_prob,
    int mask_length,
    int dim) {
  const int T = input.dims(dim);
  const int N = input.dims(2);

  // Nothing can be masked: bail out before dividing by mask_length and before
  // the subsampling step below, which would otherwise index empty arrays with
  // af::seq(0, -1).
  if (mask_length <= 0 || T <= 0 || N <= 0) {
    return input;
  }
  const double requestedSpans = (mask_prob * T) / mask_length;
  // Written as a negated >= so that a NaN probability also takes this branch.
  if (!(requestedSpans >= 1.0)) {
    return input;
  }
  const int numMask = static_cast<int>(requestedSpans);

  auto mask = af::constant(0., af::dim4(T, N), f32);
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < numMask; j++) {
      const int startIdx = generateRandomInt(0, T);
      // Clip the span so it never runs past the last valid index.
      const int endIdx = std::min(startIdx + mask_length - 1, T - 1);
      mask(af::seq(startIdx, endIdx), i) = 1.;
    }
  }

  // restrict by min len
  // Every column received at least one span of at least one position, so
  // minLen >= 1 and af::seq(0, minLen - 1) below is a valid forward range.
  const int minLen = af::min<int>(af::sum(mask, 0));
  auto maskMinLen = af::constant(0., af::dim4(T, N), f32);
  for (int i = 0; i < N; i++) {
    auto maskIdx = af::where(mask(af::span, i));
    // Sorting random keys yields a random permutation of the marked
    // positions; keeping the first minLen picks a uniform random subset.
    auto tmp = af::randu(maskIdx.dims(0));
    af::array val, idx;
    af::sort(val, idx, tmp);
    idx = idx(af::seq(0, minLen - 1));
    maskIdx = maskIdx(idx);
    maskMinLen(maskIdx, i) = 1.;
  }

  // Place T on the masked dimension and leave the other leading dimension as
  // a singleton so tileAs can broadcast the mask over it.
  const af::dim4 maskDims =
      (dim == 0) ? af::dim4(T, 1, N) : af::dim4(1, T, N);
  mask = af::moddims(maskMinLen, maskDims);

  auto totalMask = tileAs(fl::Variable(mask, false), input.dims());
  auto maskEmbedding = tileAs(mask_emb, input.dims());
  // Keep the input where the mask is 0 and take the embedding where it is 1.
  auto inputMasked =
      input.as(f32) * (1 - totalMask) + maskEmbedding * totalMask;
  return inputMasked.as(input.type());
}
/**
 * Sets the embedding that forward() hands to maskFunction for masked
 * positions.
 *
 * @param input the embedding; the member mask_emb_ is assigned `input * 1`.
 */
void CPCSpecAugment::setMaskEmbedding(const fl::Variable& input) {
  // NOTE(review): the multiplication by 1 presumably exists to store a new
  // Variable instead of a handle to the caller's -- confirm before removing.
  const auto scaledEmbedding = input * 1;
  mask_emb_ = scaledEmbedding;
}
/**
 * Applies masking to `input` when the module is in training mode.
 *
 * @param input variable to augment.
 * @return `input` unchanged when train_ is false; otherwise the result of
 *         maskFunction along dimension 1 using mask_emb_, timeMaskP_ and
 *         timeMaskT_.
 */
fl::Variable CPCSpecAugment::forward(const fl::Variable& input) {
  if (train_) {
    // A second pass along dimension 0 is left disabled:
    // output = maskFunction(output, fl::constant(0.0, af::dim4(1)), 0.25, 64, 0);
    return maskFunction(input, mask_emb_, timeMaskP_, timeMaskT_, 1);
  }
  return input;
}
/**
 * Builds a one-line, human-readable summary of the configuration.
 *
 * The summary lists timeWarpW_ (W), freqMaskF_ (F), numFreqMask_ (mF),
 * timeMaskP_ (p) and numTimeMask_ (mT). timeMaskT_ is not part of the output.
 *
 * @return the formatted description string.
 */
std::string CPCSpecAugment::prettyString() const {
  std::ostringstream repr;
  repr << "CPCSpecAugment ( "
       << "W: " << timeWarpW_ << ", "
       << "F: " << freqMaskF_ << ", "
       << "mF: " << numFreqMask_ << ", "
       // << "T: " << timeMaskT_ << ", "
       << "p: " << timeMaskP_ << ", "
       << "mT: " << numTimeMask_ << " )";
  return repr.str();
}
} // namespace w2l