recipes/joint_training_vox_populi/cpc/CPCCriterion.h (71 lines of code) (raw):
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <memory>
//#include "flashlight/common/FlashlightUtils.h"
#include "flashlight/fl/contrib/modules/modules.h"
#include "flashlight/pkg/speech/criterion/Defines.h"
#include "flashlight/pkg/speech/criterion/SequenceCriterion.h"
namespace w2l {
// String tag "cpc" naming the CPC criterion.
// NOTE(review): presumably used to select this criterion by name; the
// consumers are not visible in this header - confirm against callers.
constexpr char const* kCPCCriterion{"cpc"};

// Currently empty: this header declares no implementation-detail helpers.
namespace detail {}
// Declaration only; the definition is not part of this header.
//
// NOTE(review): judging by the name and parameter list, this presumably
// transfers state for the first `n_layers` layers between `net0` and `net`.
// The direction of the transfer and how `n_layers` is interpreted cannot be
// determined from this header - confirm against the definition.
//
// Both networks are received as std::shared_ptr by value.
void PartialLoading(
int n_layers,
std::shared_ptr<fl::Sequential> net0,
std::shared_ptr<fl::Sequential> net);
class CPCCriterion : public fl::pkg::speech::SequenceCriterion {
public:
CPCCriterion(
int nEncoder,
int nContext,
int nMutual,
int nOffset,
int nUnits,
int nPieces,
int nNegative,
int nBuffer,
float temperature);
std::vector<fl::Variable> forward(
const std::vector<fl::Variable>& inputs) override;
af::array viterbiPath(
const af::array& input,
const af::array& inputSize = af::array()) override;
std::string prettyString() const override;
fl::Variable
getMask(const fl::Variable& input, float mask_prob, int mask_length);
float numMask() {
return sum(masked_, {1, 2}).scalar<float>();
}
fl::Variable getMaskEmbedding() {
return params_[0];
}
private:
af::array getRandomIntegers(int N);
fl::Variable getNegativeSamples(const fl::Variable& inp);
std::shared_ptr<fl::Linear> mutualLinear(int k) const {
return std::static_pointer_cast<fl::Linear>(module(k));
}
int nEncoder_;
int nContext_;
int nMutual_;
int nOffset_;
int nUnits_;
int nPieces_;
int nNegative_;
int nBuffer_;
float temperature_;
fl::Variable masked_;
af::array not_masked_;
FL_SAVE_LOAD_WITH_BASE(
SequenceCriterion,
nEncoder_,
nContext_,
nMutual_,
nOffset_,
nUnits_,
nPieces_,
nNegative_,
nBuffer_,
temperature_)
CPCCriterion() = default;
};
} // namespace w2l
// Register w2l::CPCCriterion with cereal's polymorphic type registry and set
// its cereal class version to 3. Both macros are invoked at global scope,
// after the w2l namespace has been closed.
CEREAL_REGISTER_TYPE(w2l::CPCCriterion)
CEREAL_CLASS_VERSION(w2l::CPCCriterion, 3)