void CongestionControlLocalEnv::loop()

in congestion_control/CongestionControlLocalEnv.cpp [48:108]


void CongestionControlLocalEnv::loop() {
  Action action;
  bool done = true;
  uint32_t episode_step = 0;
  float episode_return = 0.0;
  std::unique_lock<std::mutex> lock(mutex_);

  // Initialize LSTM core state with zeros
  auto core_state = at::ivalue::Tuple::create(
      {torch::zeros({1, kLSTMHiddenSize}, at::kFloat),
       torch::zeros({1, kLSTMHiddenSize}, at::kFloat)});

  while (!shutdown_) {
    cv_.wait(lock, [&]() -> bool { return (observationReady_ || shutdown_); });
    if (shutdown_) {
      break;
    }

    done = (episode_step == 0);
    episode_return += reward_;
    VLOG(2) << "Episode step = " << episode_step
            << ", total return = " << episode_return;

    // env_inputs: (obs, reward, done)
    auto reward_tensor = torch::from_blob(&reward_, {1}, at::kFloat);
    auto done_tensor = torch::from_blob(&done, {1}, at::kBool);
    auto env_inputs = at::ivalue::Tuple::create({tensor_.reshape({1, -1}),
                                                 std::move(reward_tensor),
                                                 std::move(done_tensor)});

    // inputs: (last_action, (obs, reward, done), core_state)
    auto last_action_tensor =
        torch::from_blob(&action.cwndAction, {1}, at::kLong);
    quic::utils::vector<torch::IValue> inputs{std::move(last_action_tensor),
                                              std::move(env_inputs),
                                              std::move(core_state)};
    const auto &outputs = module_.forward(inputs).toTuple();

    // output: (action, core_state)
    const auto &action_tensor = outputs->elements()[0].toTensor();
    core_state = outputs->elements()[1].toTuple();

    action.cwndAction = *action_tensor.data_ptr<long>();

    // If there is an ongoing shutdown, it is important not to trigger the action
    // because `onAction()` calls `runImmediatelyOrRunInEventBaseThreadAndWait()`
    // and this method will hang forever during shutdown, preventing the thread from
    // exiting cleanly.
    if (!shutdown_) {
      onAction(action);
    } else {
      LOG(INFO) << "Skipping action due to shutdown in progress";
    }

    episode_step++;
    observationReady_ = false; // Back to waiting
  }

  LOG(INFO) << "Inference loop terminating after " << episode_step
            << " steps, total return = " << episode_return;
}