std::vector<Variable> CPCCriterion::forward()

in recipes/joint_training_vox_populi/cpc/CPCCriterion.cpp [175:220]


std::vector<Variable> CPCCriterion::forward(
    const std::vector<Variable>& inputs) {
  // enc_out, context = C T N 1
  const auto& samples = inputs[0];
  const auto& context_mask = inputs[1];

  int N = context_mask.dims(2);
  int T = context_mask.dims(1);
  int C = context_mask.dims(0);

  std::vector<Variable> loss;

  auto mask = masked_.array();

  for (int i = 0; i < N; i++) {
    auto batch_mask = af::where(af::flat(mask(af::span, af::span, i))).as(s64);
    auto anchor =
        mutualLinear(1)
            ->forward(context_mask(af::span, batch_mask, i, af::span, true))
            .as(f32);
    auto pos_samples =
        mutualLinear(0)
            ->forward(samples(af::span, batch_mask, i, af::span, true))
            .as(f32);

    anchor = anchor / tileAs(norm(anchor, {0}), anchor.dims());
    pos_samples =
        pos_samples / tileAs(norm(pos_samples, {0}), pos_samples.dims());

    auto neg_samples = getNegativeSamples(pos_samples);
    auto anchor_neg = tileAs(anchor, neg_samples.dims());

    pos_samples = sum(pos_samples * anchor, {0}) / temperature_;
    neg_samples = sum(neg_samples * anchor_neg, {0}) / temperature_;
    auto all_samples = concatenate({pos_samples, neg_samples}, 3);

    auto max_samples = Variable(af::max(all_samples.array(), 3), false);
    auto sum_samples =
        sum(exp(all_samples - tileAs(max_samples, all_samples)), {3});
    loss.push_back(
        sum((max_samples + log(sum_samples) - pos_samples), {1}) /
        anchor.dims(1));
  }

  return {reorder(concatenate(loss, 2), 2, 0, 1, 3)};
}