in astra-sim-alibabacloud/astra-sim/workload/Workload.cc [242:308]
// Advances the data-parallel training loop by one step of its state machine.
//
// Layer traversal order:
//   * Forward_Pass walks layers 0 .. SIZE-1. After the last layer it switches
//     to Weight_Gradient while staying on layer SIZE-1.
//   * Weight_Gradient issues the layer's weight-gradient collective
//     (non-blocking). On layer 0 it ends the pass: pass_counter is bumped and
//     the loop returns to Forward_Pass. On any other layer it moves to
//     Input_Gradient for the same layer.
//   * Input_Gradient steps back one layer and returns to Weight_Gradient.
//
// Each compute phase is modelled as a delay. On the first visit to a phase,
// its compute time is latched into `counter` and `delay_loaded` is set. While
// `counter` is positive, a Workload_Wait event is requested for `counter` and
// the routine returns. Presumably the wait path zeroes `counter` before
// re-entering; TODO confirm in try_register_event / call().
// When the delay is exhausted, the latch is cleared, the state advances, and
// a General event is registered with delay argument 1 to continue the loop.
void Workload::iterate_data_parallel() {
  assert(index >= 0);
  assert(index < SIZE);
  check_for_sim_end();

  switch (current_state) {
    case LoopState::Forward_Pass:
      // Hold this layer's forward pass until its weight-gradient
      // communication reports finished. NOTE(review): the callee presumably
      // arranges a wake-up when it is not finished yet -- verify.
      if (!layers[index]->is_weight_grad_comm_finished_blocking()) {
        return;
      }
      if (!delay_loaded) {
        counter = layers[index]->get_fwd_pass_compute();
        delay_loaded = true;
      }
      if (counter > 0) {
        generator->try_register_event(
            this, EventType::Workload_Wait, nullptr, counter);
        return;
      }
      delay_loaded = false;
      // Move to the next layer. Past the last layer, stay on SIZE-1 and begin
      // the backward sweep there.
      if (index + 1 < SIZE) {
        ++index;
      } else {
        current_state = LoopState::Weight_Gradient;
      }
      generator->register_event(this, EventType::General, nullptr, 1);
      break;

    case LoopState::Weight_Gradient:
      if (!delay_loaded) {
        counter = layers[index]->get_weight_grad_compute();
        delay_loaded = true;
      }
      if (counter > 0) {
        generator->try_register_event(
            this, EventType::Workload_Wait, nullptr, counter);
        return;
      }
      delay_loaded = false;
      // Launch the gradient collective without a barrier. The loop keeps
      // going, and the next forward pass of this layer waits on it above.
      layers[index]->issue_weight_grad_comm(
          SchedulingPolicy::None, CollectiveBarrier::Non_Blocking);
      if (index != 0) {
        // Layer 0 needs no input gradient, so only non-first layers go here.
        current_state = LoopState::Input_Gradient;
      } else {
        // Backward sweep reached layer 0: one training pass is complete.
        if (generator->id == 0) {
          std::cout << "pass: " << pass_counter
                    << " finished at time: " << Sys::boostedTick() << std::endl;
        }
        ++pass_counter;
        current_state = LoopState::Forward_Pass;
      }
      generator->register_event(this, EventType::General, nullptr, 1);
      break;

    case LoopState::Input_Gradient:
      if (!delay_loaded) {
        counter = layers[index]->get_input_grad_compute();
        delay_loaded = true;
      }
      if (counter > 0) {
        generator->try_register_event(
            this, EventType::Workload_Wait, nullptr, counter);
        return;
      }
      delay_loaded = false;
      // Hand the backward sweep to the previous layer's weight gradient.
      --index;
      current_state = LoopState::Weight_Gradient;
      generator->register_event(this, EventType::General, nullptr, 1);
      break;

    default:
      // Any other state needs no action from this routine.
      break;
  }
}