in congestion_control/CongestionControlLocalEnv.cpp [48:108]
// Inference-thread body: blocks until the producer publishes an observation,
// runs the TorchScript policy module on (last_action, (obs, reward, done),
// lstm_state), applies the selected action via onAction(), and repeats until
// shutdown_ is set.
void CongestionControlLocalEnv::loop() {
Action action;
// NOTE(review): &action.cwndAction is fed to the model as "last action" on the
// first iteration, before any assignment — assumes Action zero-initializes
// cwndAction by default; confirm against Action's definition.
bool done = true;
uint32_t episode_step = 0;
float episode_return = 0.0;
// Lock is held for the whole loop; cv_.wait() releases it while sleeping, so
// the producer can set observationReady_/reward_/tensor_ under the same mutex.
std::unique_lock<std::mutex> lock(mutex_);
// Initialize LSTM core state with zeros
auto core_state = at::ivalue::Tuple::create(
{torch::zeros({1, kLSTMHiddenSize}, at::kFloat),
torch::zeros({1, kLSTMHiddenSize}, at::kFloat)});
while (!shutdown_) {
// Sleep until a fresh observation arrives or shutdown is requested.
cv_.wait(lock, [&]() -> bool { return (observationReady_ || shutdown_); });
if (shutdown_) {
break;
}
// `done` is true only on the very first step, marking the episode start for
// the model (which resets its recurrent state on done=true inputs).
done = (episode_step == 0);
episode_return += reward_;
VLOG(2) << "Episode step = " << episode_step
<< ", total return = " << episode_return;
// env_inputs: (obs, reward, done)
// from_blob wraps existing memory WITHOUT copying: these tensors alias
// reward_ / done and are only valid while those stay alive and unmodified,
// i.e. through the forward() call below.
auto reward_tensor = torch::from_blob(&reward_, {1}, at::kFloat);
auto done_tensor = torch::from_blob(&done, {1}, at::kBool);
auto env_inputs = at::ivalue::Tuple::create({tensor_.reshape({1, -1}),
std::move(reward_tensor),
std::move(done_tensor)});
// inputs: (last_action, (obs, reward, done), core_state)
auto last_action_tensor =
torch::from_blob(&action.cwndAction, {1}, at::kLong);
quic::utils::vector<torch::IValue> inputs{std::move(last_action_tensor),
std::move(env_inputs),
std::move(core_state)};
const auto &outputs = module_.forward(inputs).toTuple();
// output: (action, core_state)
const auto &action_tensor = outputs->elements()[0].toTensor();
// Carry the recurrent state forward into the next iteration.
core_state = outputs->elements()[1].toTuple();
// NOTE(review): data_ptr<long> is paired with at::kLong (int64); on LLP64
// platforms where `long` is 32-bit this mismatches — int64_t would be the
// portable spelling. Verify target platforms.
action.cwndAction = *action_tensor.data_ptr<long>();
// If there is an ongoing shutdown, it is important not to trigger the action
// because `onAction()` calls `runImmediatelyOrRunInEventBaseThreadAndWait()`
// and this method will hang forever during shutdown, preventing the thread from
// exiting cleanly.
if (!shutdown_) {
onAction(action);
} else {
LOG(INFO) << "Skipping action due to shutdown in progress";
}
episode_step++;
observationReady_ = false; // Back to waiting
}
LOG(INFO) << "Inference loop terminating after " << episode_step
<< " steps, total return = " << episode_return;
}