in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [983:1162]
/**
 * Issues the forward-pass collective communication of this layer.
 *
 * The collective to run is selected by `fwd_pass_comm_type` (all-reduce,
 * all-to-all, all-gather, reduce-scatter, or none). The dataset describing it
 * is obtained from `generator` using `fwd_pass_comm_size` and
 * `fwd_pass_comm_involved_dimensions`.
 *
 * Whenever no collective is left in flight (analytical build, communication
 * type `None`, or a generated dataset that is not active) and `barrier` is
 * `CollectiveBarrier::Blocking`, a `EventType::General` event is delivered to
 * `workload` right away instead.
 *
 * Build-time variants:
 *  - ANALYTI: nothing is generated; the function only records the barrier,
 *    logs, and fires the blocking callback if requested.
 *  - PHY_MTP: the completion event and this layer are handed to the generator
 *    call itself; otherwise the dataset is stored in `fwd_pass_datasets` and
 *    this layer is registered as its notifier here.
 *
 * @param pref_scheduling  Scheduling policy forwarded to the generator.
 * @param barrier          Barrier mode; stored in `fwd_barrier`.
 */
void Layer::issue_forward_pass_comm(
    SchedulingPolicy pref_scheduling,
    CollectiveBarrier barrier) {
  MockNcclLog* NcclLog = MockNcclLog::getInstance();

  // Every path that ends without an outstanding collective must still fire
  // the General event for a blocking barrier; keep that logic in one place.
  const auto notify_if_blocking = [this, barrier]() {
    if (barrier == CollectiveBarrier::Blocking) {
      workload->call(EventType::General, nullptr);
    }
  };

#ifdef ANALYTI
  // Analytical build: no dataset is created for the forward pass.
  fwd_barrier = barrier;
  if (generator->id == 0) {
    NcclLog->writeLog(
        NcclLogLevel::DEBUG,
        "forward pass for layer %s is analytical ",
        id.c_str());
    NcclLog->writeLog(
        NcclLogLevel::DEBUG,
        "forward pass for layer-id %d is analytical ",
        layer_num);
  }
  notify_if_blocking();
  return;
#endif

  DataSet* fp = nullptr;
  // Human-readable name of the selected collective, used in the info message.
  const char* collective_name = nullptr;

  fwd_barrier = barrier;
  // Counted optimistically; rolled back below if nothing ends up in flight.
  collective_counter++;

  if (fwd_pass_comm_type == ComType::All_Reduce) {
    collective_name = "all-reduce";
#ifdef PHY_MTP
    fp = generator->generate_all_reduce(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Fwd_Comm_Finished,
        this);
#else
    fp = generator->generate_all_reduce(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
#endif
  } else if (fwd_pass_comm_type == ComType::All_to_All) {
    collective_name = "all-to-all";
#ifdef PHY_MTP
    fp = generator->generate_all_to_all(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Fwd_Comm_Finished,
        this);
#else
    fp = generator->generate_all_to_all(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
#endif
  } else if (fwd_pass_comm_type == ComType::All_Gather) {
    collective_name = "all-gather";
#ifdef PHY_MTP
    fp = generator->generate_all_gather(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Fwd_Comm_Finished,
        this);
#else
    fp = generator->generate_all_gather(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
#endif
  } else if (fwd_pass_comm_type == ComType::Reduce_Scatter) {
    collective_name = "reduce-scatter";
#ifdef PHY_MTP
    fp = generator->generate_reduce_scatter(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Fwd_Comm_Finished,
        this);
#else
    fp = generator->generate_reduce_scatter(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
#endif
  } else if (fwd_pass_comm_type == ComType::None) {
    // Nothing to communicate for this layer.
    collective_counter--;
    if (generator->id == 0) {
      std::cout << "info: no forward pass collective for layer: " << id
                << std::endl;
    }
    notify_if_blocking();
    return;
  } else {
    Sys::sys_panic("no known collective operation! ");
    // NOTE(review): sys_panic presumably terminates the process - confirm.
    // Return defensively so the null dataset is never dereferenced below.
    return;
  }

  if (!fp->active) {
    // The generator produced a dataset that is not active: undo the counter,
    // drop the dataset and release a blocking caller.
    if (generator->id == 0) {
      std::cout
          << "info: all dims disabled, no forward pass collective for layer: "
          << id << std::endl;
    }
    collective_counter--;
    delete fp;
    notify_if_blocking();
    return;
  }

  // Only the node with id 0 prints the announcement.
  if (generator->id == 0) {
    std::cout << "info: " << collective_name
              << " forward pass collective issued for layer: " << id << ",";
    print_involved_dimensions(fwd_pass_comm_involved_dimensions);
  }

#ifndef PHY_MTP
  // Track the dataset and ask to be notified with Fwd_Comm_Finished. In
  // PHY_MTP builds the event and `this` were already passed to the generator.
  fwd_pass_datasets[fp->my_id] = fp;
  fp->set_notifier(this, EventType::Fwd_Comm_Finished);
#endif
  NcclLog->writeLog(NcclLogLevel::DEBUG,"Fwd_Comm_Finished set_notifier success");
}