in astra-sim-alibabacloud/astra-sim/workload/Workload.cc [309:398]
void Workload::iterate_hybrid_parallel_customized() {
assert(index >= 0);
assert(index < SIZE);
check_for_sim_end();
if (current_state == LoopState::Forward_Pass) {
if (!layers[index]->is_weight_grad_comm_finished_blocking()) {
return;
}
if (delay_loaded == false) {
counter = layers[index]->get_fwd_pass_compute();
delay_loaded = true;
}
if (counter > 0) {
generator->try_register_event(
this, EventType::Workload_Wait, NULL, counter);
return;
}
if (!collective_issued) {
collective_issued = true;
layers[index]->issue_forward_pass_comm(
SchedulingPolicy::None, CollectiveBarrier::Blocking);
return;
}
index++;
delay_loaded = false;
collective_issued = false;
if (index >= SIZE) {
current_state = LoopState::Input_Gradient;
index--;
}
generator->register_event(this, EventType::General, NULL, 1);
return;
} else if (current_state == LoopState::Weight_Gradient) {
if (delay_loaded == false) {
counter = layers[index]->get_weight_grad_compute();
delay_loaded = true;
}
if (counter > 0) {
generator->try_register_event(
this, EventType::Workload_Wait, NULL, counter);
return;
}
if (!collective_issued) {
collective_issued = true;
layers[index]->issue_weight_grad_comm(
SchedulingPolicy::FIFO, CollectiveBarrier::Non_Blocking);
}
if (!layers[index]->is_input_grad_comm_finished_blocking()) {
return;
}
collective_issued = false;
delay_loaded = false;
if (index >= 0) {
index--;
}
if (index == -1) {
index = 0;
if (generator->id == 0) {
std::cout << "pass: " << pass_counter
<< " finished at time: " << Sys::boostedTick() << std::endl;
}
pass_counter++;
current_state = LoopState::Forward_Pass;
} else {
current_state = LoopState::Input_Gradient;
}
generator->register_event(this, EventType::General, NULL, 1);
return;
} else if (current_state == LoopState::Input_Gradient) {
if (delay_loaded == false) {
counter = layers[index]->get_input_grad_compute();
delay_loaded = true;
}
if (counter > 0) {
generator->try_register_event(
this, EventType::Workload_Wait, NULL, counter);
return;
}
if (!collective_issued && index > 0) {
collective_issued = true;
layers[index]->issue_input_grad_comm(
SchedulingPolicy::LIFO, CollectiveBarrier::Non_Blocking);
}
collective_issued = false;
delay_loaded = false;
current_state = LoopState::Weight_Gradient;
generator->register_event(this, EventType::General, NULL, 1);
return;
}
}