std::vector<QueryMetadata> GenerateQueries()

in benchmarks/rnnt/ootb/inference/loadgen/loadgen.cc [191:330]


template <TestScenario scenario, TestMode mode>
std::vector<QueryMetadata> GenerateQueries(
    const TestSettingsInternal& settings,
    const LoadableSampleSet& loaded_sample_set, SequenceGen* sequence_gen,
    ResponseDelegate* response_delegate) {
  auto tracer =
      MakeScopedTracer([](AsyncTrace& trace) { trace("GenerateQueries"); });

  auto& loaded_samples = loaded_sample_set.set;

  // Generate 2x as many queries as we think we'll need given the expected
  // QPS, in case the SUT is faster than expected.
  // We should exit before issuing all of them.
  // This does not apply to the Server scenario, since its duration depends
  // only on the ideal scheduled time, not the actual issue time.
  const int duration_multiplier = scenario == TestScenario::Server ? 1 : 2;
  std::chrono::microseconds gen_duration =
      duration_multiplier * settings.target_duration;
  size_t min_queries = settings.min_query_count;

  size_t samples_per_query = settings.samples_per_query;
  if (mode == TestMode::AccuracyOnly && scenario == TestScenario::Offline) {
    samples_per_query = loaded_sample_set.sample_distribution_end;
  }

  // We should not exit early in accuracy mode.
  if (mode == TestMode::AccuracyOnly || settings.performance_issue_unique ||
      settings.performance_issue_same) {
    gen_duration = std::chrono::microseconds(0);
    // Integer truncation here is intentional.
  // For MultiStream, the loaded sample set is properly padded.
    // For Offline, we create a 'remainder' query at the end of this function.
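    // Example: 24 loaded samples with samples_per_query = 8 gives
    // min_queries = 3, exactly covering the loaded set once.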
    min_queries = loaded_samples.size() / samples_per_query;
  }

  std::vector<QueryMetadata> queries;

  // Using the std::mt19937 pseudo-random number generator ensures a modicum
  // of cross-platform reproducibility for trace generation.
  std::mt19937 sample_rng(settings.sample_index_rng_seed);
  std::mt19937 schedule_rng(settings.schedule_rng_seed);

  constexpr bool kIsMultiStream = scenario == TestScenario::MultiStream ||
                                  scenario == TestScenario::MultiStreamFree;
  const size_t sample_stride = kIsMultiStream ? samples_per_query : 1;

  auto sample_distribution = SampleDistribution<mode>(
      loaded_sample_set.sample_distribution_end, sample_stride, &sample_rng);
  // Use the same unique sample distribution as in AccuracyOnly mode to
  // choose samples when either the performance_issue_unique or the
  // performance_issue_same flag is set.
  auto sample_distribution_unique = SampleDistribution<TestMode::AccuracyOnly>(
      loaded_sample_set.sample_distribution_end, sample_stride, &sample_rng);

  auto schedule_distribution =
      ScheduleDistribution<scenario>(settings.target_qps);
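  // Note: ScheduleDistribution is specialized per scenario; in LoadGen the
  // Server scenario draws exponentially distributed inter-arrival gaps
  // (modeling Poisson traffic), while the other scenarios return a constant
  // period derived from target_qps.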

  std::vector<QuerySampleIndex> samples(samples_per_query);
  std::chrono::nanoseconds timestamp(0);
  std::chrono::nanoseconds prev_timestamp(0);
  // Choose a single sample to repeat when in performance_issue_same mode.
  QuerySampleIndex same_sample = settings.performance_issue_same_index;

  while (prev_timestamp < gen_duration || queries.size() < min_queries) {
    if (kIsMultiStream) {
      QuerySampleIndex sample_i = settings.performance_issue_unique
                                      ? sample_distribution_unique(sample_rng)
                                      : settings.performance_issue_same
                                            ? same_sample
                                            : sample_distribution(sample_rng);
      for (auto& s : samples) {
        // Select contiguous samples in the MultiStream scenario.
        // This will not overflow, since GenerateLoadableSets adds padding at
        // the end of the loadable sets in the MultiStream scenario.
        // The padding allows the starting samples to be the same for each
        // query as the value of samples_per_query increases.
        s = loaded_samples[sample_i++];
      }
    } else if (scenario == TestScenario::Offline) {
      // For the Offline + Performance scenario, we also want to support
      // contiguous samples. In this scenario the query can be much larger than
      // what fits into memory. We simply repeat loaded_samples N times, plus a
      // remainder to ensure we fill up samples. Note that this eliminates
      // randomization.
      size_t num_loaded_samples = loaded_samples.size();
      size_t num_full_repeats = samples_per_query / num_loaded_samples;
      size_t remainder = samples_per_query % num_loaded_samples;
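      // e.g., samples_per_query = 10 over 4 loaded samples gives
      // num_full_repeats = 2 and remainder = 2.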
      if (settings.performance_issue_same) {
        std::fill(samples.begin(), samples.begin() + num_loaded_samples,
                  loaded_samples[same_sample]);
      } else {
        for (size_t i = 0; i < num_full_repeats; ++i) {
          std::copy(loaded_samples.begin(), loaded_samples.end(),
                    samples.begin() + i * num_loaded_samples);
        }

        std::copy(loaded_samples.begin(), loaded_samples.begin() + remainder,
                  samples.begin() + num_full_repeats * num_loaded_samples);
      }
    } else {
      for (auto& s : samples) {
        s = loaded_samples[settings.performance_issue_unique
                               ? sample_distribution_unique(sample_rng)
                               : settings.performance_issue_same
                                     ? same_sample
                                     : sample_distribution(sample_rng)];
      }
    }
    queries.emplace_back(samples, timestamp, response_delegate, sequence_gen);
    prev_timestamp = timestamp;
    timestamp += schedule_distribution(schedule_rng);
  }

  // See if we need to create a "remainder" query for offline+accuracy to
  // ensure we issue all samples in loaded_samples. Offline doesn't pad
  // loaded_samples like MultiStream does.
  if (scenario == TestScenario::Offline && mode == TestMode::AccuracyOnly) {
    size_t remaining_samples = loaded_samples.size() % samples_per_query;
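    // e.g., 10 loaded samples with samples_per_query = 4 leaves
    // remaining_samples = 2, issued below as one final short query.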
    if (remaining_samples != 0) {
      samples.resize(remaining_samples);
      for (auto& s : samples) {
        s = loaded_samples[sample_distribution(sample_rng)];
      }
      queries.emplace_back(samples, timestamp, response_delegate, sequence_gen);
    }
  }

  LogDetail([count = queries.size(), spq = settings.samples_per_query,
             duration = timestamp.count()](AsyncDetail& detail) {
#if USE_NEW_LOGGING_FORMAT
    MLPERF_LOG(detail, "generated_query_count", count);
    MLPERF_LOG(detail, "generated_samples_per_query", spq);
    MLPERF_LOG(detail, "generated_query_duration", duration);
#else
    detail("GeneratedQueries: ", "queries", count, "samples per query", spq,
           "duration", duration);
#endif
  });

  return queries;
}
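
A minimal sketch of a call site, for orientation only: the template arguments
and the settings/loaded_sample_set/sequence_gen/response_logger names below
are illustrative stand-ins for whatever the surrounding LoadGen driver
actually passes, not the exact call in loadgen.cc.

// Hypothetical instantiation for the Offline + AccuracyOnly combination.
std::vector<QueryMetadata> queries =
    GenerateQueries<TestScenario::Offline, TestMode::AccuracyOnly>(
        settings, loaded_sample_set, &sequence_gen, &response_logger);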