recipes/joint_training_vox_populi/cpc/TransformerCPC.cpp

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "TransformerCPC.h"

#include <cmath>
#include <sstream>
#include <stdexcept>

#include "flashlight/fl/autograd/Functions.h"
#include "flashlight/fl/contrib/modules/Transformer.h"
#include "flashlight/fl/nn/Init.h"
#include "flashlight/fl/nn/Utils.h"

namespace {

// Xavier/Glorot-style uniform initialization for linear weights, scaled by `gain`.
fl::Variable
transformerInitLinear(int32_t inDim, int32_t outDim, float gain = 1.0) {
  // float std = std::sqrt(1.0 / float(inDim));
  float std = gain * std::sqrt(6.0 / (float(inDim) + float(outDim)));
  return fl::uniform(outDim, inDim, -std, std, af::dtype::f32, true);
}

// Uniform bias initialization in [-1/sqrt(inDim), 1/sqrt(inDim)];
// `zero` forces zero-initialized biases.
fl::Variable
transformerInitLinearBias(int32_t inDim, int32_t outDim, bool zero = false) {
  float std = std::sqrt(1.0 / float(inDim));
  if (zero) {
    std = 0;
  }
  return fl::uniform(af::dim4(outDim), -std, std);
}

} // namespace

namespace w2l {
namespace cpc {

TransformerCPC::TransformerCPC(
    int32_t modelDim,
    int32_t headDim,
    int32_t mlpDim,
    int32_t nHeads,
    int32_t bptt,
    float pDropout,
    float pLayerdrop,
    bool useMask,
    bool preLN,
    double layerNormEps)
    : nHeads_(nHeads),
      bptt_(bptt),
      pDropout_(pDropout),
      pLayerdrop_(pLayerdrop),
      useMask_(useMask),
      preLN_(preLN),
      layerNormEps_(layerNormEps),
      w1_(std::make_shared<Linear>(modelDim, mlpDim)),
      w2_(std::make_shared<Linear>(mlpDim, modelDim)),
      wq_(std::make_shared<Linear>(
          transformerInitLinear(modelDim, headDim * nHeads, 0.707),
          transformerInitLinearBias(modelDim, headDim * nHeads))),
      wk_(std::make_shared<Linear>(
          transformerInitLinear(modelDim, headDim * nHeads, 0.707),
          transformerInitLinearBias(modelDim, headDim * nHeads))),
      wv_(std::make_shared<Linear>(
          transformerInitLinear(modelDim, headDim * nHeads, 0.707),
          transformerInitLinearBias(modelDim, headDim * nHeads))),
      wf_(std::make_shared<Linear>(
          transformerInitLinear(headDim * nHeads, modelDim),
          transformerInitLinearBias(headDim * nHeads, modelDim, true))),
      norm1_(
          std::make_shared<LayerNorm>(std::vector<int>({0, 3}), layerNormEps_)),
      norm2_(
          std::make_shared<LayerNorm>(std::vector<int>({0, 3}), layerNormEps_)) {
  if (bptt > 0) {
    // Relative positional embedding table (2 * bptt - 1 positions x headDim).
    params_.push_back(
        uniform(2 * bptt - 1, headDim, -0.1, 0.1, af::dtype::f32, true));
  }

  add(w1_);
  add(w2_);
  add(wq_);
  add(wk_);
  add(wv_);
  add(wf_);
  add(norm1_);
  add(norm2_);
}

Variable TransformerCPC::mlp(const Variable& input) {
  // NB: dropout inside the feed-forward block is currently disabled
  // (hard-coded to 0 below); pDropout is kept for reference.
  float pDropout = train_ ? pDropout_ : 0.0;
  // return (*w2_)(dropout(relu((*w1_)(input)), pDropout));
  return (*w2_)(dropout(relu((*w1_)(input)), 0.0));
}

// Causal (lower-triangular) attention mask in the log domain: 0 for allowed
// positions, -inf for masked ones. With `cache`, prepends an upper-triangular
// block covering the cached previous steps.
Variable TransformerCPC::getMask(int32_t n, bool cache) {
  auto mask = af::lower(af::constant(1.0, n, n), true);
  if (cache) {
    auto maskCache = af::upper(af::constant(1.0, n, n));
    mask = af::join(1, maskCache, mask);
  }
  return Variable(af::log(mask), false);
}

Variable TransformerCPC::selfAttention(const std::vector<Variable>& input) {
  // Inputs: previous step (optional), input, padMask.
  auto encoderInput = input.at(input.size() - 2);
  // In case of a previous state, input[0] has size C x T_prev x B.
  int n = input[0].dims(1), bsz = input[0].dims(2);
  double pDrop = train_ ? pDropout_ : 0.0;

  auto q = transpose((*wq_)(encoderInput));
  std::vector<fl::Variable> inputWithState(input.begin(), input.end() - 1);
  auto k = transpose((*wk_)(concatenate(inputWithState, 1)));
  auto v = transpose((*wv_)(concatenate(inputWithState, 1)));

  // Scale queries by 1 / sqrt(headDim).
  q = q / std::sqrt(float(q.dims(1) / nHeads_));

  Variable mask, posEmb;
  if (bptt_ > 0) {
    posEmb = tile(
        params_[0].as(encoderInput.type()), af::dim4(1, 1, nHeads_ * bsz));
  }
  if (useMask_ && encoderInput.dims(1) > 1) {
    // Mask the future if we use the previous state (then n is previous time).
    mask = getMask(n, input.size() == 3);
  }

  int offset = (input.size() == 2) ? 0 : n;

  // padMask is time x batch.
  fl::Variable padMask;
  if (!input.back().isempty()) {
    auto padMaskArr = input.back().array();
    padMaskArr =
        af::resize(padMaskArr, encoderInput.dims(1), encoderInput.dims(2));
    padMask = fl::Variable(af::log(padMaskArr), false);
  }
  auto result = multiheadAttention(
      q, k, v, posEmb, mask, padMask, nHeads_, pDrop, offset);
  result = (*wf_)(transpose(result));

  return result;
}

std::vector<Variable> TransformerCPC::forward(
    const std::vector<Variable>& input) {
  // Inputs: previous step (optional), input, padMask.
  // padMask should be empty if the previous step is provided.
  // padMask is expected to have "1" on the used positions and "0" on padded
  // positions.
  if (input.size() < 2) {
    throw std::invalid_argument(
        "Invalid inputs for transformer block: there should be at least input and mask");
  }
  auto x = input.at(input.size() - 2);
  if (!input.back().isempty() && x.dims(2) != input.back().dims(1)) {
    throw std::invalid_argument(
        "Invalid inputs for transformer block: input and Mask batch sizes are different");
  }

  // LayerDrop: with probability pLayerdrop_, drop this block's self-attention
  // and MLP contributions during training (scale factor f = 0).
  float f = 1.0;
  if (train_ && (af::randu(1).scalar<float>() < pLayerdrop_)) {
    f = 0.0;
  }
  if (preLN_) {
    auto h = (f * (*norm1_)(selfAttention(input))).as(x.type()) + x;
    return {f * (*norm2_)(mlp(h)).as(h.type()) + h};
  } else {
    auto h = (*norm1_)((f * selfAttention(input)).as(x.type()) + x);
    return {(*norm2_)((f * mlp(h)).as(h.type()) + h)};
  }
}

std::string TransformerCPC::prettyString() const {
  std::ostringstream ss;
  ss << "Transformer (nHeads: " << nHeads_ << "), "
     << "(pDropout: " << pDropout_ << "), "
     << "(pLayerdrop: " << pLayerdrop_ << "), "
     << "(bptt: " << bptt_ << "), "
     << "(useMask: " << useMask_ << "), "
     << "(preLayerNorm: " << preLN_ << ")";
  return ss.str();
}

TransformerCPC::TransformerCPC() {}

} // namespace cpc
} // namespace w2l
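
Usage sketch (not part of the file above): a minimal, hypothetical driver showing how this block could be constructed and run on random features, assuming a working flashlight/ArrayFire build of this recipe. The file name, hyperparameter values, tensor sizes, and the empty pad mask below are illustrative assumptions, not values taken from the repository.

// main_transformer_cpc_example.cpp -- hypothetical usage sketch, not part of
// the recipe; assumes flashlight and ArrayFire headers/libraries are available
// and that TransformerCPC.h is on the include path.
#include <iostream>

#include <arrayfire.h>

#include "TransformerCPC.h"

int main() {
  // Illustrative (assumed) hyperparameters and input sizes.
  const int modelDim = 256, headDim = 32, mlpDim = 1024, nHeads = 8;
  const int bptt = 0; // no relative positional embeddings
  const int T = 50, B = 2; // time steps and batch size

  w2l::cpc::TransformerCPC block(
      modelDim,
      headDim,
      mlpDim,
      nHeads,
      bptt,
      /*pDropout=*/0.1,
      /*pLayerdrop=*/0.0,
      /*useMask=*/false,
      /*preLN=*/true,
      /*layerNormEps=*/1e-5);
  block.eval(); // disable dropout / LayerDrop for this forward pass

  // forward() expects {input, padMask}, with input of size C x T x B;
  // the pad mask may be an empty Variable when nothing is padded.
  fl::Variable x(af::randu(modelDim, T, B, af::dtype::f32), /*calcGrad=*/false);
  fl::Variable padMask; // empty: no padding
  auto out = block.forward({x, padMask});

  std::cout << "output dims: " << out.front().dims(0) << " x "
            << out.front().dims(1) << " x " << out.front().dims(2) << std::endl;
  return 0;
}

Building this would require linking against flashlight and the rest of the recipe sources, so it is intended only as a sketch of the constructor and forward() call pattern defined above.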