Args AutotuneStrategy::ask()

in src/autotune.cc [126:188]


// Proposes the hyperparameter configuration for the next autotune trial.
//
// The first call returns bestArgs_ unchanged, so the starting configuration
// is evaluated as-is. Every later call starts from a copy of bestArgs_ and
// resamples each hyperparameter the user did not set explicitly
// (args.isManual(name) == false) through updateArgGauss(), passing the
// normalized progress t = min(1, elapsed / maxDuration_).
// NOTE(review): updateArgGauss is assumed to draw a Gaussian perturbation
// around its first argument, clamped to [min, max], with a spread that moves
// from the "start" sigma toward the "end" sigma as t grows. Confirm this
// against its definition.
//
// elapsed: time already spent autotuning, in the same unit as maxDuration_.
// returns: the candidate Args to evaluate next.
// Side effect: increments trials_ on every call.
Args AutotuneStrategy::ask(double elapsed) {
  const double t = std::min(1.0, elapsed / maxDuration_);
  trials_++;

  // First trial: evaluate the starting configuration without perturbation.
  if (trials_ == 1) {
    return bestArgs_;
  }

  Args args = bestArgs_;

  if (!args.isManual("epoch")) {
    args.epoch = updateArgGauss(args.epoch, 1, 100, 2.8, 2.5, t, false, rng_);
  }
  if (!args.isManual("lr")) {
    args.lr = updateArgGauss(args.lr, 0.01, 5.0, 1.9, 1.0, t, false, rng_);
  }
  if (!args.isManual("dim")) {
    args.dim = updateArgGauss(args.dim, 1, 1000, 1.4, 0.3, t, false, rng_);
  }
  if (!args.isManual("wordNgrams")) {
    args.wordNgrams =
        updateArgGauss(args.wordNgrams, 1, 5, 4.3, 2.4, t, true, rng_);
  }
  if (!args.isManual("dsub")) {
    // Search over the exponent (bounds 1..4) so that dsub is always a power
    // of two.
    int dsubExponent =
        updateArgGauss(bestDsubExponent_, 1, 4, 2.0, 1.0, t, true, rng_);
    args.dsub = (1 << dsubExponent);
  }
  if (!args.isManual("minn")) {
    // Search over an index into minnChoices_ so that minn only takes values
    // from that list.
    // NOTE(review): assumes minnChoices_ is non-empty; otherwise size() - 1
    // underflows before the cast.
    int minnIndex = updateArgGauss(
        bestMinnIndex_,
        0,
        static_cast<int>(minnChoices_.size() - 1),
        4.0,
        1.4,
        t,
        true,
        rng_);
    args.minn = minnChoices_[minnIndex];
  }
  if (!args.isManual("maxn")) {
    // maxn follows minn (either the value sampled above or the user's manual
    // value): it stays 0 when minn is 0, otherwise it is minn + 3.
    if (args.minn == 0) {
      args.maxn = 0;
    } else {
      args.maxn = args.minn + 3;
    }
  }
  if (!args.isManual("bucket")) {
    int nonZeroBucket = updateArgGauss(
        bestNonzeroBucket_, 10000, 10000000, 2.0, 1.5, t, false, rng_);
    args.bucket = nonZeroBucket;
  } else {
    args.bucket = originalBucket_;
  }
  // With no word n-grams (wordNgrams <= 1) and no char n-grams (maxn == 0),
  // bucket is forced to 0; presumably no hashed n-gram features are produced
  // in that configuration.
  if (args.wordNgrams <= 1 && args.maxn == 0) {
    args.bucket = 0;
  }
  if (!args.isManual("loss")) {
    args.loss = loss_name::softmax;
  }

  return args;
}