in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [90:252]
// Event handler for the communication collectives issued by this layer.
//
// Handling happens in two stages per collective kind (weight gradient,
// input gradient, forward pass):
//   1. *_Comm_Finished: stamp last_*_finished with Sys::boostedTick() and
//      re-register the same `mdata` with the generator as the matching
//      *_Comm_Finished_After_Delay event, delayed by the per-kind update time.
//   2. *_Comm_Finished_After_Delay: `mdata` is treated as an IntData whose
//      `data` field is the key of the finished DataSet in the per-kind map.
//      The dataset's finish tick is extended by the update time, comm and
//      waiting totals are accumulated, and the DataSet and IntData are freed.
// Any other event is ignored; `mdata` is not inspected in that case.
//
// @param event  Which notification is being delivered.
// @param mdata  Stage 1: opaque payload forwarded unchanged to stage 2.
//               Stage 2 (non-PHY_MTP): an IntData this call takes ownership
//               of and deletes.
void Layer::call(EventType event, CallData* mdata) {
  // ---- Stage 1: raw completion -> schedule the delayed completion. ----
  if (event == EventType::Wight_Grad_Comm_Finished) {
    last_wg_finished = Sys::boostedTick();
    generator->register_event(
        this,
        EventType::Wight_Grad_Comm_Finished_After_Delay,
        mdata,
        weight_grad_update_time);
    return;
  }
  if (event == EventType::Input_Grad_Comm_Finished) {
    last_ig_finished = Sys::boostedTick();
    generator->register_event(
        this,
        EventType::Input_Grad_Comm_Finished_After_Delay,
        mdata,
        input_grad_update_time);
    return;
  }
  if (event == EventType::Fwd_Comm_Finished) {
    last_fwd_finished = Sys::boostedTick();
    generator->register_event(
        this, EventType::Fwd_Comm_Finished_After_Delay, mdata, fwd_update_time);
    return;
  }

  // ---- Stage 2: delayed completion. ----
  const bool is_delayed_finish =
      event == EventType::Wight_Grad_Comm_Finished_After_Delay ||
      event == EventType::Input_Grad_Comm_Finished_After_Delay ||
      event == EventType::Fwd_Comm_Finished_After_Delay;
  if (!is_delayed_finish) {
    // Not an event this handler processes; do not touch `mdata`.
    return;
  }

#ifdef PHY_MTP
  // NOTE(review): on this path the IntData payload is neither read nor
  // deleted (same as before); confirm whether it is owned elsewhere.
  workload->call(EventType::General, nullptr);
  generator->increase_finished_streams(1);
#else
  IntData* intData = static_cast<IntData*>(mdata);
  const int data = intData->data;

  // Shared bookkeeping for one finished collective. The per-kind state
  // (dataset map, wait-start queue, barrier mode, update time, counters) is
  // passed in, so the three collective kinds run identical logic.
  // Statement order matches the original hand-unrolled branches exactly.
  auto finish_collective =
      [&](const char* label,
          decltype(weight_grad_datasets)& datasets,
          decltype(started_waiting_for_weight_grad)& started_waiting,
          CollectiveBarrier barrier,
          decltype(weight_grad_update_time) update_time,
          decltype(total_weight_grad_comm)& total_comm,
          decltype(total_waiting_for_wg_comm)& total_waiting) {
        if (generator->id == 0) {
          std::cout << "***** info: " << label << " for layer: " << id
                    << " is finished************" << std::endl;
        }

        // One lookup instead of one per use. operator[] is kept (rather than
        // find) so a missing key behaves exactly as before.
        auto* dataset = datasets[data];
        dataset->finish_tick += update_time;
        total_comm += dataset->finish_tick - dataset->creation_tick;

        // Waiting time is attributed when either (a) the barrier is Blocking
        // and this was the only outstanding dataset, or (b) a wait-start
        // timestamp was queued. Only then is the workload notified.
        bool notify_workload = false;
        if (datasets.size() == 1 && barrier == CollectiveBarrier::Blocking) {
          total_waiting += dataset->finish_tick - dataset->creation_tick;
          notify_workload = true;
        } else if (!started_waiting.empty()) {
          total_waiting += dataset->finish_tick - started_waiting.front();
          started_waiting.pop_front();
          notify_workload = true;
        }

        update_stream_stats(dataset);
        // Read before the DataSet is freed.
        const int dataset_streams = dataset->total_streams;
        delete dataset;
        datasets.erase(data);
        if (notify_workload) {
          workload->call(EventType::General, nullptr);
        }
        generator->increase_finished_streams(dataset_streams);
        delete intData;
      };

  if (event == EventType::Wight_Grad_Comm_Finished_After_Delay) {
    finish_collective("weight gradient collective",
                      weight_grad_datasets,
                      started_waiting_for_weight_grad,
                      wg_barrier,
                      weight_grad_update_time,
                      total_weight_grad_comm,
                      total_waiting_for_wg_comm);
  } else if (event == EventType::Input_Grad_Comm_Finished_After_Delay) {
    finish_collective("input gradient collective",
                      input_grad_datasets,
                      started_waiting_for_input_grad,
                      ig_barrier,
                      input_grad_update_time,
                      total_input_grad_comm,
                      total_waiting_for_ig_comm);
  } else {  // EventType::Fwd_Comm_Finished_After_Delay
    finish_collective("fwd pass comm collective",
                      fwd_pass_datasets,
                      started_waiting_for_fwd_pass,
                      fwd_barrier,
                      fwd_update_time,
                      total_fwd_comm,
                      total_waiting_for_fwd_comm);
  }
#endif
}