in astra-sim-alibabacloud/astra-sim/workload/Workload.cc [930:1027]
void Workload::iterate_hybrid_parallel_DLRM() {
assert(index >= 0);
assert(index < SIZE);
check_for_sim_end();
if (current_state == LoopState::Forward_Pass) {
if (!layers[index]->is_weight_grad_comm_finished_blocking()) {
return;
}
if (delay_loaded == false) {
counter = layers[index]->get_fwd_pass_compute();
delay_loaded = true;
}
if (counter > 0) {
generator->try_register_event(
this, EventType::Workload_Wait, NULL, counter);
return;
}
if (!collective_issued &&
layers[index]->fwd_pass_comm_type == ComType::All_to_All) {
collective_issued = true;
layers[index]->issue_forward_pass_comm(
SchedulingPolicy::HIGHEST, CollectiveBarrier::Non_Blocking);
} else if (index == DLRM_LAST_BOTTOM_LAYER) {
if (!layers[0]->is_fwd_pass_comm_finished_blocking()) {
return;
}
}
index++;
delay_loaded = false;
collective_issued = false;
if (index >= SIZE) {
current_state = LoopState::Weight_Gradient;
index--;
}
if (generator->id == 0) {
std::cout << "*************************layer changed to: " << index
<< std::endl;
}
generator->register_event(this, EventType::General, NULL, 1);
return;
} else if (current_state == LoopState::Weight_Gradient) {
if (delay_loaded == false) {
counter = layers[index]->get_weight_grad_compute();
delay_loaded = true;
}
if (counter > 0) {
generator->try_register_event(
this, EventType::Workload_Wait, NULL, counter);
return;
}
if (!collective_issued) {
collective_issued = true;
layers[index]->issue_weight_grad_comm(
SchedulingPolicy::None, CollectiveBarrier::Non_Blocking);
}
if (parallelismPolicy == ParallelismPolicy::DLRM &&
!layers[index]->is_input_grad_comm_finished_blocking()) {
return;
}
if (index == 0) {
if (generator->id == 0) {
std::cout << "pass: " << pass_counter
<< " finished at time: " << Sys::boostedTick() << std::endl;
}
pass_counter++;
current_state = LoopState::Forward_Pass;
} else {
current_state = LoopState::Input_Gradient;
}
delay_loaded = false;
collective_issued = false;
generator->register_event(this, EventType::General, NULL, 1);
} else if (current_state == LoopState::Input_Gradient) {
if (delay_loaded == false) {
counter = layers[index]->get_input_grad_compute();
delay_loaded = true;
}
if (counter > 0) {
generator->try_register_event(
this, EventType::Workload_Wait, NULL, counter);
return;
}
if (index == DLRM_LAST_BOTTOM_LAYER + 1) {
layers[0]->issue_input_grad_comm(
SchedulingPolicy::HIGHEST, CollectiveBarrier::Non_Blocking);
}
index--;
if (generator->id == 0) {
std::cout << "*************************layer changed to: " << index
<< " in ig" << std::endl;
}
current_state = LoopState::Weight_Gradient;
collective_issued = false;
delay_loaded = false;
generator->register_event(this, EventType::General, NULL, 1);
}
}