in recipes/joint_training_vox_populi/cpc/CPCCriterion.cpp [175:220]
/**
 * Computes a contrastive loss over the masked time steps of every batch
 * element.
 *
 * For each batch index the time steps where `masked_` is non-zero are
 * selected, the context and the encoder samples are projected through
 * `mutualLinear(1)` / `mutualLinear(0)`, normalised along dim 0 and compared
 * by dot product scaled by `1 / temperature_`. The positive similarity is
 * scored against the similarities to the negatives returned by
 * `getNegativeSamples` with a log-sum-exp (softmax cross-entropy style) loss,
 * averaged over the selected time steps.
 *
 * @param inputs `inputs[0]` holds the encoder samples and `inputs[1]` the
 *   context; both are indexed as (C, T, N, 1). No size check is performed
 *   here, so callers must pass at least two elements.
 * @return A single-element vector with the per-batch losses concatenated
 *   along dim 2 and then reordered with (2, 0, 1, 3) -- presumably giving one
 *   loss value per batch element along dim 0; confirm against callers.
 */
std::vector<Variable> CPCCriterion::forward(
    const std::vector<Variable>& inputs) {
  // enc_out, context = C T N 1
  const auto& samples = inputs[0];
  const auto& context_mask = inputs[1];
  const int N = context_mask.dims(2);
  std::vector<Variable> loss;
  auto mask = masked_.array();
  for (int i = 0; i < N; i++) {
    // Indices of the time steps flagged in the mask for batch element i.
    auto batch_mask = af::where(af::flat(mask(af::span, af::span, i))).as(s64);
    auto anchor =
        mutualLinear(1)
            ->forward(context_mask(af::span, batch_mask, i, af::span, true))
            .as(f32);
    auto pos_samples =
        mutualLinear(0)
            ->forward(samples(af::span, batch_mask, i, af::span, true))
            .as(f32);
    // Scale both embeddings by their norm along dim 0 so the dot products
    // below behave like cosine similarities (assuming `norm` defaults to L2).
    anchor = anchor / tileAs(norm(anchor, {0}), anchor.dims());
    pos_samples =
        pos_samples / tileAs(norm(pos_samples, {0}), pos_samples.dims());
    auto neg_samples = getNegativeSamples(pos_samples);
    auto anchor_neg = tileAs(anchor, neg_samples.dims());
    // Temperature-scaled similarities of the anchor to the positive and to
    // each negative.
    auto pos_logits = sum(pos_samples * anchor, {0}) / temperature_;
    auto neg_logits = sum(neg_samples * anchor_neg, {0}) / temperature_;
    auto all_logits = concatenate({pos_logits, neg_logits}, 3);
    // Subtract the per-position maximum before exponentiating to keep the
    // log-sum-exp numerically stable; it is wrapped as a Variable built with
    // `false` (no gradient) since it cancels out analytically.
    auto max_logits = Variable(af::max(all_logits.array(), 3), false);
    auto sum_exp =
        sum(exp(all_logits - tileAs(max_logits, all_logits)), {3});
    // logsumexp(all) - positive, averaged over the selected time steps.
    loss.push_back(
        sum((max_logits + log(sum_exp) - pos_logits), {1}) /
        anchor.dims(1));
  }
  return {reorder(concatenate(loss, 2), 2, 0, 1, 3)};
}