quic::utils::vector<NetworkState> CongestionControlEnv::stateSummary()

in congestion_control/CongestionControlEnv.cpp [141:180]


/// Summarizes a collection of network states into a fixed set of statistics.
///
/// The states are converted to a tensor via `NetworkState::toTensor` (using
/// the member `summaryTensor_` as its output), and each statistic is computed
/// column-wise by reducing over dim 0.
///
/// @param states  Network states to summarize.
///                NOTE(review): an empty vector appears unsupported — the
///                std/min/max reductions over a zero-length dimension are not
///                well-defined; confirm callers always pass at least one state.
/// @return Five summary rows, in this order: Sum, Mean, Std, Min, Max.
///         In the Sum row, the fields listed in `invalidSumFields` are zeroed.
quic::utils::vector<NetworkState> CongestionControlEnv::stateSummary(
    const quic::utils::vector<NetworkState> &states) {
  const int dim = 0;
  const bool keepdim = true;
  // Apply Bessel's correction to the stddev only when it is defined (more
  // than one sample); otherwise the unbiased estimate would be NaN.
  const bool unbiased = (states.size() > 1);

  NetworkState::toTensor(states, summaryTensor_);
  const auto sum = torch::sum(summaryTensor_, dim, keepdim);
  // torch::std_mean returns the pair (std, mean) in that order; name them
  // explicitly so the output ordering below is unambiguous.
  const auto stdMean =
      torch::std_mean(summaryTensor_, dim, unbiased, keepdim);
  const auto &stddev = std::get<0>(stdMean);
  const auto &mean = std::get<1>(stdMean);
  const auto min = torch::amin(summaryTensor_, dim, keepdim);
  const auto max = torch::amax(summaryTensor_, dim, keepdim);
  // If these statistics are modified / re-ordered, make sure to also update
  // the corresponding `OFFSET_*` constants in state.py, as well as `keys`
  // below.
  const auto summary = torch::cat({sum, mean, stddev, min, max}, dim);
  auto summaryStates = NetworkState::fromTensor(summary);

  // Certain stats for some fields don't make sense, such as the sum over
  // RTTs from ACKs. Zero them out in the Sum row (index 0).
  static const quic::utils::vector<Field> invalidSumFields = {
      Field::RTT_MIN, Field::RTT_STANDING, Field::LRTT,
      Field::SRTT,    Field::RTT_VAR,      Field::DELAY,
      Field::CWND,    Field::IN_FLIGHT,    Field::WRITABLE,
  };
  for (const Field field : invalidSumFields) {
    summaryStates[0][field] = 0.0;
  }

  // Row labels; must match the order used in the torch::cat call above.
  static const quic::utils::vector<std::string> keys = {
      "Sum", "Mean", "Std", "Min", "Max",
  };
  VLOG(2) << "State summary: ";
  for (size_t i = 0; i < summaryStates.size(); ++i) {
    VLOG(2) << keys[i] << ": " << summaryStates[i];
  }

  return summaryStates;
}