Variable TransformerCPC::selfAttention()

in recipes/joint_training_vox_populi/cpc/TransformerCPC.cpp [104:143]


Variable TransformerCPC::selfAttention(const std::vector<Variable>& input) {
  // input layout: [previous step (optional)], input, padMask
  auto encoderInput = input.at(input.size() - 2);
  // when a previous state is passed, input[0] has shape C x T_prev x B
  int n = input[0].dims(1), bsz = input[0].dims(2);
  double pDrop = train_ ? pDropout_ : 0.0;

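  // queries are computed from the current encoder input only; keys and values
  // attend over the previous state (if any) concatenated with the input along
  // time (inputWithState is everything except the trailing pad mask)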
  auto q = transpose((*wq_)(encoderInput));
  std::vector<fl::Variable> inputWithState(input.begin(), input.end() - 1);
  auto k = transpose((*wk_)(concatenate(inputWithState, 1)));
  auto v = transpose((*wv_)(concatenate(inputWithState, 1)));

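  // scaled dot-product attention: divide queries by sqrt of the per-head
  // dimension (q.dims(1) / nHeads_)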
  q = q / std::sqrt(float(q.dims(1) / nHeads_));

  Variable mask, posEmb;
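  // when bptt_ > 0, the positional embedding in params_[0] is tiled along the
  // third dimension so every head of every batch element shares the same table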
  if (bptt_ > 0) {
    posEmb =
        tile(params_[0].as(encoderInput.type()), af::dim4(1, 1, nHeads_ * bsz));
  }
  if (useMask_ && encoderInput.dims(1) > 1) {
    // causal mask on future positions; if a previous state is used, n is the
    // previous-state length
    mask = getMask(n, input.size() == 3);
  }

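  // when a previous state is prepended to the keys/values, the current
  // queries are shifted by its length n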
  int offset = (input.size() == 2) ? 0 : n;

  // padding mask, laid out as time x batch
  fl::Variable padMask;
  if (!input.back().isempty()) {
    auto padMaskArr = input.back().array();
    padMaskArr =
        af::resize(padMaskArr, encoderInput.dims(1), encoderInput.dims(2));
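    // af::log maps the 1/0 mask to 0/-inf, an additive bias that removes
    // padded frames from the attention softmax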
    padMask = fl::Variable(af::log(padMaskArr), false);
  }
  auto result = multiheadAttention(
      q, k, v, posEmb, mask, padMask, nHeads_, pDrop, offset);
  result = (*wf_)(transpose(result));

  return result;
}
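
For orientation, below is a minimal caller-side sketch of the input layout this function expects, assembled the way the code above reads it: the optional previous step first, then the encoder input, then the pad mask last. The tensor shapes, variable names, and the helper function are illustrative assumptions, not code from the recipe; in the actual recipe this vector is assembled inside the layer's forward pass.

#include <vector>
#include "flashlight/fl/flashlight.h" // flashlight umbrella header; exact path may differ by version

// Hypothetical sketch (not from the recipe): build the argument vector
// consumed by selfAttention().
std::vector<fl::Variable> buildSelfAttentionInput(
    const fl::Variable& x,         // assumed C x T x B encoder features
    const fl::Variable& padMask,   // assumed T x B mask, 1 = valid frame, 0 = padding (may be empty)
    const fl::Variable& prevState) // assumed C x T_prev x B cached state, or empty
{
  std::vector<fl::Variable> input;
  if (!prevState.isempty()) {
    input.push_back(prevState); // optional previous step goes first
  }
  input.push_back(x);       // encoderInput = input[input.size() - 2]
  input.push_back(padMask); // pad mask is always the last element
  return input;             // passed to selfAttention(input)
}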