in recipes/joint_training_vox_populi/cpc/TransformerCPC.cpp [104:143]
Variable TransformerCPC::selfAttention(const std::vector<Variable>& input) {
  // Inputs: previous step's state (optional), current input, pad mask.
  auto encoderInput = input.at(input.size() - 2);
  // When a previous state is passed in, input[0] has shape C x T_prev x B.
  int n = input[0].dims(1), bsz = input[0].dims(2);
  double pDrop = train_ ? pDropout_ : 0.0;
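  // Queries come from the current input only; keys and values are computed
  // over the previous state (if any) concatenated with the current input
  // along the time axis.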
  auto q = transpose((*wq_)(encoderInput));
  std::vector<fl::Variable> inputWithState(input.begin(), input.end() - 1);
  auto k = transpose((*wk_)(concatenate(inputWithState, 1)));
  auto v = transpose((*wv_)(concatenate(inputWithState, 1)));
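  // Scale queries by 1 / sqrt(headDim): q.dims(1) is the model dim C, so
  // C / nHeads_ is the per-head dimension.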
  q = q / std::sqrt(float(q.dims(1) / nHeads_));
  Variable mask, posEmb;
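  // With bptt_ > 0, params_[0] holds the learned positional embedding; tile
  // it across nHeads_ * bsz so every head of every batch element shares it.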
  if (bptt_ > 0) {
    posEmb =
        tile(params_[0].as(encoderInput.type()), af::dim4(1, 1, nHeads_ * bsz));
  }
  if (useMask_ && encoderInput.dims(1) > 1) {
    // Mask out future frames; when a previous state is passed in, n is the
    // length of that previous state.
    mask = getMask(n, input.size() == 3);
  }
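  // Offset query positions by the previous state's length so the positional
  // bookkeeping lines up with the concatenated key/value sequence.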
  int offset = (input.size() == 2) ? 0 : n;

  // The pad mask is laid out as time x batch.
  fl::Variable padMask;
  if (!input.back().isempty()) {
    auto padMaskArr = input.back().array();
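    // Stretch the pad mask to the encoder input's time and batch dims; its
    // length can differ from the raw input's (e.g. after downsampling).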
    padMaskArr =
        af::resize(padMaskArr, encoderInput.dims(1), encoderInput.dims(2));
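    // Keep the mask in log space: log(1) = 0 leaves a frame untouched, while
    // log(0) = -inf suppresses it when added to the attention scores.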
    padMask = fl::Variable(af::log(padMaskArr), false);
  }
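  // Scaled dot-product attention over all nHeads_ heads at once.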
  auto result = multiheadAttention(
      q, k, v, posEmb, mask, padMask, nHeads_, pDrop, offset);
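  // Transpose back to feature-major layout and apply the output projection.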
  result = (*wf_)(transpose(result));
  return result;
}