recipes/joint_training_vox_populi/cpc/CPCCriterion.cpp (177 lines of code) (raw):

/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "CPCCriterion.h"

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <queue>
#include <stdexcept>
#include <string>
#include <vector>

#include "flashlight/pkg/speech/criterion/CriterionUtils.h"

using namespace fl;

namespace w2l {

/**
 * Appends the leading modules of `net0` to `net`.
 *
 * @param n_layers number of leading modules to copy. A negative value is
 *   interpreted relative to the end: `modules.size() + n_layers`.
 * @param net0 source network whose modules are read.
 * @param net destination network the modules are added to.
 * @throws std::out_of_range if more modules are requested than `net0` holds.
 *
 * If the resolved count is still negative, nothing is added.
 */
void PartialLoading(
    int n_layers,
    std::shared_ptr<fl::Sequential> net0,
    std::shared_ptr<fl::Sequential> net) {
  auto modules_0 = net0->modules();
  const int n_available = static_cast<int>(modules_0.size());
  if (n_layers < 0) {
    n_layers = n_available + n_layers;
  }
  // Indexing past the end of modules_0 would be undefined behaviour.
  if (n_layers > n_available) {
    throw std::out_of_range(
        "PartialLoading: requested " + std::to_string(n_layers) +
        " layers but the source network only has " +
        std::to_string(n_available));
  }
  for (int i = 0; i < n_layers; i++) {
    net->add(modules_0[i]);
  }
}

/**
 * Stores the configuration and creates the learnable state:
 *  - params_[0]: a (nEncoder x 1) vector drawn uniformly from [-1, 1]; it is
 *    tiled over the input in getMask() as the replacement ("mask") embedding.
 *  - module 0: Linear(nEncoder -> nMutual)
 *  - module 1: Linear(nContext -> nMutual)
 */
CPCCriterion::CPCCriterion(
    int nEncoder,
    int nContext,
    int nMutual,
    int nOffset,
    int nUnits,
    int nPieces,
    int nNegative,
    int nBuffer,
    float temperature)
    : nEncoder_(nEncoder),
      nContext_(nContext),
      nMutual_(nMutual),
      nOffset_(nOffset),
      nUnits_(nUnits),
      nPieces_(nPieces),
      nNegative_(nNegative),
      nBuffer_(nBuffer),
      temperature_(temperature) {
  params_.push_back(uniform(nEncoder_, 1, -1, 1));
  // linear layers for computing mutual information between
  // encoder and context features
  add(std::make_shared<Linear>(nEncoder_, nMutual_));
  add(std::make_shared<Linear>(nContext_, nMutual_));
}

/**
 * Returns `N` random s64 values obtained by scaling af::randu(N) by 100000
 * and casting to s64.
 */
af::array CPCCriterion::getRandomIntegers(int N) {
  auto rnd = af::randu(N) * 100000;
  return rnd.as(s64);
}

/**
 * Shifts `inp` along dimension 1 by `delta` without wrap-around.
 *
 * The input is zero-padded by |delta| on both sides of dimension 1, shifted
 * with af::shift, and then cropped back to the original length, so the
 * positions that would have wrapped around are zero instead.
 */
af::array shift_non_circular(const af::array& inp, int delta) {
  const int T = inp.dims(1);
  const int abs_delta = std::abs(delta);
  auto pad_inp = af::pad(
      inp,
      af::dim4(0, abs_delta, 0, 0),
      af::dim4(0, abs_delta, 0, 0),
      AF_PAD_ZERO);
  auto out = af::shift(pad_inp, 0, delta);
  out = out(af::span, af::seq(abs_delta, T + abs_delta - 1));
  return out;
}

/**
 * Applies a mask on `input` using the supplied masking parameters and
 * internally stores the mask in `masked_`.
 *
 * Steps:
 *  1. For every batch element draw int(mask_prob * T) random start positions.
 *  2. Grow each start into a span by adding copies of the start mask shifted
 *     by +1, -1, +2, -2, ... (mask_length - 1 shifts in total), clamped to 1.
 *  3. Restrict every batch element to the same number of masked positions
 *     (the minimum count across the batch) by keeping a random subset.
 *  4. Replace masked positions of `input` with the tiled params_[0].
 *
 * @param input Variable indexed as (dim0, T, N); dim 1 and dim 2 are read as
 *   the time and batch sizes.
 * @param mask_prob fraction of T used as the number of start positions.
 * @param mask_length number of terms (start + shifts) summed into each span.
 * @return `input * (1 - mask) + embedding * mask`.
 *
 * If no position ends up masked, `masked_` is all zeros and the returned
 * value equals the blend with a zero mask.
 */
Variable
CPCCriterion::getMask(const Variable& input, float mask_prob, int mask_length) {
  const int T = input.dims(1);
  const int N = input.dims(2);

  const int numMask = static_cast<int>(mask_prob * T);
  af::array midMask = af::constant(0., af::dim4(1, T, N), f32);
  // With no start positions requested there is nothing to draw or assign.
  if (numMask > 0) {
    for (int i = 0; i < N; i++) {
      auto startIdx = getRandomIntegers(numMask) % T;
      midMask(af::span, startIdx, i) = 1.;
    }
  }

  af::array sumMask = midMask * 1.;
  int sign = 1;
  for (int i = 1; i < mask_length; i++) {
    // Integer division truncates toward zero: deltas are +1, -1, +2, -2, ...
    const int delta = sign * (i + 1) / 2;
    sumMask = sumMask + shift_non_circular(midMask, delta);
    sign *= -1;
  }
  sumMask = af::min(sumMask, 1.);
  auto mask = af::moddims(sumMask, af::dim4(T, N));

  // restrict masking by min len across batches
  const int minLen = af::min<int>(af::sum(mask, 0));
  auto maskMinLen = af::constant(0., af::dim4(T, N), f32);
  // af::seq(0, -1) means "0..end" in ArrayFire, so minLen == 0 must not reach
  // the selection below: it would keep every index instead of none.
  if (minLen > 0) {
    for (int i = 0; i < N; i++) {
      auto maskIdx = af::where(mask(af::span, i));
      // Random permutation of the masked positions via sorting random keys.
      auto tmp = af::randu(maskIdx.dims(0));
      af::array val, idx;
      af::sort(val, idx, tmp);
      idx = idx(af::seq(0, minLen - 1));
      maskIdx = maskIdx(idx);
      maskMinLen(maskIdx, i) = 1.;
    }
  }

  mask = af::moddims(maskMinLen, af::dim4(1, T, N));
  masked_ = Variable(mask, false);
  Variable totalMask = tileAs(Variable(mask, false), input.dims());
  auto maskEmbedding = tileAs(params_[0], input.dims());
  return input * (1 - totalMask) + maskEmbedding * totalMask;
}

/**
 * Draws negative samples for every (time, batch) position of `inp`.
 *
 * @param inp Variable indexed as (C, T, N).
 * @return Variable of dims (C, T, N, nNeg) with nNeg = min(T, nNegative_),
 *   gathered from `inp` at random time indices.
 */
Variable CPCCriterion::getNegativeSamples(const Variable& inp) {
  const int C = inp.dims(0);
  const int T = inp.dims(1);
  const int N = inp.dims(2);
  const int nNeg = (T > nNegative_) ? nNegative_ : T;
  // The exclusion window is fixed to 1 here; nBuffer_ is not used.
  const int nBuff = 1;

  // exclude current position with window
  auto time_idx = af::range(af::dim4(T, N, nNeg), 0, s64);
  auto min_idx = af::min(T, 1 + nBuff + time_idx);
  auto max_idx = af::max(T, T - nBuff + time_idx);
  // NOTE(review): for very small T the window covers every position and
  // mod_idx contains zeros, making the modulo below a division by zero.
  // Confirm callers guarantee enough time steps.
  auto mod_idx = max_idx - min_idx;
  auto rnd_idx =
      af::moddims(getRandomIntegers(T * N * nNeg), af::dim4(T, N, nNeg)) %
      mod_idx;
  time_idx = (min_idx + rnd_idx) % T;

  // current sequence only
  auto batch_idx = af::range(af::dim4(T, N, nNeg), 1, s64);
  batch_idx = batch_idx % N;

  // Flatten (time, batch) into a single column index of the (C, T * N) view.
  auto idx = af::flat(time_idx + batch_idx * T);
  auto out = moddims(inp, af::dim4(C, T * N));
  out = out(af::span, idx);
  out = moddims(out, af::dim4(C, T, N, nNeg));
  return out;
}

/*
 * C = number of channels
 * T = number of time frames
 * N = number of elements in batch
 *
 * inputs[0]: encoder output, inputs[1]: (masked) context, both C T N 1.
 *
 * For every batch element the positions flagged in `masked_` (set by
 * getMask()) are selected, projected through the two linear layers,
 * L2-normalised along dim 0 and scored by dot product / temperature_
 * against the positive sample and the drawn negatives. The per-element loss
 * is logsumexp(all scores) - positive score, summed over the selected
 * positions and divided by their count.
 *
 * Throws std::invalid_argument if fewer than two inputs are supplied.
 */
std::vector<Variable> CPCCriterion::forward(
    const std::vector<Variable>& inputs) {
  if (inputs.size() < 2) {
    throw std::invalid_argument(
        "CPCCriterion::forward expects 2 inputs (encoder output, context)");
  }
  // enc_out, context = C T N 1
  const auto& samples = inputs[0];
  const auto& context_mask = inputs[1];
  const int N = context_mask.dims(2);

  std::vector<Variable> loss;
  auto mask = masked_.array();
  for (int i = 0; i < N; i++) {
    auto batch_mask = af::where(af::flat(mask(af::span, af::span, i))).as(s64);
    auto anchor =
        mutualLinear(1)
            ->forward(context_mask(af::span, batch_mask, i, af::span, true))
            .as(f32);
    auto pos_samples =
        mutualLinear(0)
            ->forward(samples(af::span, batch_mask, i, af::span, true))
            .as(f32);
    anchor = anchor / tileAs(norm(anchor, {0}), anchor.dims());
    pos_samples =
        pos_samples / tileAs(norm(pos_samples, {0}), pos_samples.dims());
    auto neg_samples = getNegativeSamples(pos_samples);
    auto anchor_neg = tileAs(anchor, neg_samples.dims());
    pos_samples = sum(pos_samples * anchor, {0}) / temperature_;
    neg_samples = sum(neg_samples * anchor_neg, {0}) / temperature_;
    auto all_samples = concatenate({pos_samples, neg_samples}, 3);
    // Subtract the (non-differentiated) max before exp for numerical
    // stability of the log-sum-exp.
    auto max_samples = Variable(af::max(all_samples.array(), 3), false);
    auto sum_samples =
        sum(exp(all_samples - tileAs(max_samples, all_samples)), {3});
    loss.push_back(
        sum((max_samples + log(sum_samples) - pos_samples), {1}) /
        anchor.dims(1));
  }
  return {reorder(concatenate(loss, 2), 2, 0, 1, 3)};
}

/**
 * Not supported for this criterion: prints a message and terminates the
 * process with exit status 1.
 */
af::array CPCCriterion::viterbiPath(
    const af::array& input,
    const af::array& /* inputSize */) {
  std::cout << "Should not be here" << std::endl;
  std::exit(1);
  return input * 0;
}

std::string CPCCriterion::prettyString() const {
  return "CPCCriterion";
}

} // namespace w2l