recipes/joint_training_vox_populi/cpc/CPCSpecAugment.h (51 lines of code) (raw):
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <random>
#include "flashlight/fl/nn/nn.h"
namespace w2l {
/**
* CPCSpecAugment: an implementation of "SpecAugment: A Simple Data
* Augmentation Method for Automatic Speech Recognition"
* - https://arxiv.org/pdf/1904.08779.pdf
*
* We assume the time axis is the 0th dimension and the frequency axis is the
* 1st dimension of the input array.
*
* Example policies           tWarpW  fMaskF  nFMask  tMaskT  tMaskP  nTMask
* LibriSpeech basic (LB)       80      27      1      100     1.0      1
* LibriSpeech double (LD)      80      27      2      100     1.0      2
* Switchboard mild (SM)        40      15      2       70     0.2      2
* Switchboard strong (SS)      40      27      2       70     0.2      2
**/
class CPCSpecAugment : public fl::UnaryModule {
public:
// Strategy selector stored in maskStrategy_ and serialized with the module.
// NOTE(review): presumably ZERO fills masked regions with zeros and
// GLOBAL_MEAN fills them with the mean of the whole input - confirm against
// the implementation, which is not part of this header.
// The explicit integer values are written out with the module state, so they
// should not be renumbered.
enum class MaskingStrategy {
ZERO = 0,
GLOBAL_MEAN = 1,
// TODO - add support for mean along time, freq axes
};
// Builds the module from a policy. The parameter names follow the columns of
// the "Example policies" table in the class comment:
//   tWarpW    - time warp parameter (warping is marked as not supported below)
//   fMaskF    - frequency mask parameter
//   nFMask    - number of frequency masks
//   tMaskT    - time mask parameter
//   tMaskP    - time mask proportion/probability parameter
//   nTMask    - number of time masks
//   mStrategy - masking strategy, defaults to MaskingStrategy::ZERO
// NOTE(review): presumably each argument initialises the similarly named
// private member (timeWarpW_, freqMaskF_, ...); the definition is not visible
// here - confirm.
CPCSpecAugment(
int tWarpW,
int fMaskF,
int nFMask,
int tMaskT,
float tMaskP,
int nTMask,
MaskingStrategy mStrategy = MaskingStrategy::ZERO);
// Overrides the single-input forward of fl::UnaryModule and returns a new
// fl::Variable. The augmentation logic itself is defined outside this header.
fl::Variable forward(const fl::Variable& input) override;
// NOTE(review): presumably stores `input` in mask_emb_ for later use when
// masking - confirm against the implementation.
void setMaskEmbedding(const fl::Variable& input);
// Flashlight serialization declaration: names the base class followed by the
// members that take part in save/load. eng_ and mask_emb_ are NOT in this
// list. The order of the fields here should be treated as part of the
// on-disk format.
FL_SAVE_LOAD_WITH_BASE(
fl::UnaryModule,
timeWarpW_,
freqMaskF_,
numFreqMask_,
timeMaskT_,
timeMaskP_,
numTimeMask_,
maskStrategy_)
// Overrides the base-class string description of this module.
std::string prettyString() const override;
private:
// Time Warping - NOT SUPPORTED CURRENTLY
// Use timeWarpW_ = 0 to disable this
int timeWarpW_;
// Frequency Masking
// Use freqMaskF_ = 0 to disable this
int freqMaskF_;
// Companion count for frequency masking (cf. the nFMask policy column).
int numFreqMask_;
// Time Masking
// Use timeMaskT_ = 0 to disable this
int timeMaskT_;
// Companion settings for time masking (cf. the tMaskP and nTMask policy
// columns).
float timeMaskP_;
int numTimeMask_;
// Mersenne Twister engine, always constructed with the fixed seed 0 and not
// part of the serialized state.
// NOTE(review): presumably the randomness source for generateRandomInt -
// confirm.
std::mt19937 eng_{0};
MaskingStrategy maskStrategy_;
// Not part of the serialized state (absent from FL_SAVE_LOAD_WITH_BASE).
// NOTE(review): presumably set through setMaskEmbedding - confirm.
fl::Variable mask_emb_;
// Internal masking helper taking the input, a mask embedding, a mask
// probability, a mask length and a dimension index.
// NOTE(review): presumably masks `input` along `dim` using `mask_emb` as the
// fill value; exact semantics live in the implementation - confirm.
fl::Variable maskFunction(
const fl::Variable& input,
const fl::Variable& mask_emb,
double mask_prob,
int mask_length,
int dim);
// Returns a random integer derived from `low` and `high`.
// NOTE(review): whether the bounds are inclusive is not visible from this
// header - confirm before relying on it.
int generateRandomInt(int low, int high);
// Private default constructor.
// NOTE(review): presumably exists only so the serialization layer can create
// an instance before loading its fields - confirm.
CPCSpecAugment() = default;
};
} // namespace w2l
// Registers w2l::CPCSpecAugment with cereal's polymorphic type registry so the
// module can be saved and loaded through a base-class pointer.
CEREAL_REGISTER_TYPE(w2l::CPCSpecAugment)