recipes/joint_training_vox_populi/cpc/TransformerCPC.h (57 lines of code) (raw):

/* * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include "flashlight/fl/nn/modules/Container.h" #include "flashlight/fl/nn/modules/LayerNorm.h" #include "flashlight/fl/nn/modules/Linear.h" #include "flashlight/fl/nn/modules/Module.h" using namespace fl; namespace w2l { namespace cpc { /** * A module which implements a Transformer. * * For details, see [Vaswani et al * (2017)](https://arxiv.org/abs/1706.03762). * * This module also supports layer drop regularization, as introduced in * [Fan et al (2019)](https://arxiv.org/abs/1909.11556). * * Forward takes {previous step[optionally], input, padMask} * previous step is used in for the decoder phase, previous output with size * CxT'xBx1 Input dimension at forward is assumed to be CxTxBx1, where C is the * number of features, T the sequence length and B the batch size. * padMask is with T''xB sizes (T'' will be af::resize to the input size) * padMask should be empty if "previous step" is provided (in the decoder phase) * padMask is expected to have "1" on the normal positions and "0" on the padded * positions * * @param modelDim input embedding dimension * @param headDim dimension of each head * @param mlpDim dimension of the feed-forward layers * @param nHeads number of heads * @param bptt size for learnt relative positional embedding matrix (2 * bptt - * 1) * nHeads * @param pDropout dropout probability * @param pLayerdrop layer dropout probability * @param useMask mask or not future in the computations * if true then don't use future (for example for autoregressive language models * or for decoder part in the encoder-decoder transformer models) * @param preLN apply layer normalization before or after residual connection */ class TransformerCPC : public Container { public: TransformerCPC( int32_t modelDim, int32_t headDim, int32_t mlpDim, int32_t nHeads, int32_t bptt, float pDropout, float pLayerdrop, bool useMask = false, bool preLN = false, double layerNormEps = 1e-5); std::vector<Variable> forward(const std::vector<Variable>& input) override; std::string prettyString() const override; private: int32_t nHeads_; int32_t bptt_; double pDropout_; double pLayerdrop_; bool useMask_; bool preLN_; double layerNormEps_; std::shared_ptr<Linear> w1_, w2_, wq_, wk_, wv_, wf_; std::shared_ptr<LayerNorm> norm1_, norm2_; Variable mlp(const Variable& input); Variable getMask(int32_t n, bool cache = false); Variable selfAttention(const std::vector<Variable>& input); FL_SAVE_LOAD_WITH_BASE( Container, w1_, w2_, wq_, wk_, wv_, wf_, norm1_, norm2_, nHeads_, pDropout_, pLayerdrop_, bptt_, useMask_, preLN_) TransformerCPC(); }; } // namespace cpc } // namespace w2l CEREAL_REGISTER_TYPE(w2l::cpc::TransformerCPC);