in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [1345:1525]
// Issues the weight-gradient collective configured for this layer
// (weight_grad_comm_type, weight_grad_comm_size,
// weight_grad_comm_involved_dimensions) through `generator`.
//
// pref_scheduling: scheduling policy forwarded to the generator.
// barrier:         stored in wg_barrier. When it is Blocking and no collective
//                  is actually issued (ANALYTI build, ComType::None, or all
//                  involved dimensions disabled), the workload is resumed
//                  immediately via workload->call(EventType::General, ...).
//
// In ANALYTI builds no collective is generated; only debug log lines are
// written. Otherwise collective_counter is incremented for the collective and
// rolled back whenever none ends up being issued. Without PHY_MTP, the
// resulting DataSet is registered in weight_grad_datasets and this layer is
// set as its notifier for EventType::Wight_Grad_Comm_Finished; with PHY_MTP,
// that event and `this` are passed to the generator call instead.
void Layer::issue_weight_grad_comm(
    SchedulingPolicy pref_scheduling,
    CollectiveBarrier barrier) {
  wg_barrier = barrier;

  // Resumes the workload right away for blocking barriers when there is no
  // outstanding collective to wait on.
  auto release_blocking_barrier = [&]() {
    if (barrier == CollectiveBarrier::Blocking) {
      workload->call(EventType::General, nullptr);
    }
  };

#ifdef ANALYTI
  MockNcclLog* NcclLog = MockNcclLog::getInstance();
  if (generator->id == 0) {
    NcclLog->writeLog(
        NcclLogLevel::DEBUG,
        "weight grad collective for layer %s is analytical ",
        id.c_str());
    NcclLog->writeLog(
        NcclLogLevel::DEBUG,
        "weight grad collective for layer-id %d is analytical ",
        layer_num);
  }
  release_blocking_barrier();
  return;
#endif

  // Handles a generated dataset whose `active` flag is false: logs (rank 0
  // only), undoes the collective_counter increment, frees the dataset and
  // releases a blocking barrier. Returns true when `ds` was discarded, in
  // which case the caller must return immediately.
  auto discard_if_inactive = [&](DataSet* ds) -> bool {
    if (ds->active) {
      return false;
    }
    if (generator->id == 0) {
      std::cout
          << "info: all dims disabled, no weight grad collective for layer: "
          << id << std::endl;
    }
    collective_counter--;
    delete ds;
    release_blocking_barrier();
    return true;
  };

  DataSet* wg = nullptr;
  collective_counter++;
  if (weight_grad_comm_type == ComType::All_Reduce) {
#ifdef PHY_MTP
    wg = generator->generate_all_reduce(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Wight_Grad_Comm_Finished,
        this);
#else
    wg = generator->generate_all_reduce(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
#endif
    if (discard_if_inactive(wg)) {
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-reduce weight grad collective issued for layer: "
                << id << " with size: " << weight_grad_comm_size << ",";
      print_involved_dimensions(weight_grad_comm_involved_dimensions);
    }
  } else if (weight_grad_comm_type == ComType::All_to_All) {
#ifdef PHY_MTP
    wg = generator->generate_all_to_all(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Wight_Grad_Comm_Finished,
        this);
#else
    wg = generator->generate_all_to_all(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
#endif
    if (discard_if_inactive(wg)) {
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-to-all weight grad collective issued for layer: "
                << id << " with size: " << weight_grad_comm_size << ",";
      print_involved_dimensions(weight_grad_comm_involved_dimensions);
    }
  } else if (weight_grad_comm_type == ComType::All_Gather) {
    if (generator->id == 0) {
      std::cout << "Layer issue wg all gather at tick: " << Sys::boostedTick()
                << std::endl;
    }
#ifdef PHY_MTP
    wg = generator->generate_all_gather(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Wight_Grad_Comm_Finished,
        this);
#else
    wg = generator->generate_all_gather(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
#endif
    if (discard_if_inactive(wg)) {
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-gather weight grad collective issued for layer: "
                << id << ",";
      print_involved_dimensions(weight_grad_comm_involved_dimensions);
    }
  } else if (weight_grad_comm_type == ComType::Reduce_Scatter) {
#ifdef PHY_MTP
    wg = generator->generate_reduce_scatter(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Wight_Grad_Comm_Finished,
        this);
#else
    wg = generator->generate_reduce_scatter(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
#endif
    if (discard_if_inactive(wg)) {
      return;
    }
    if (generator->id == 0) {
      std::cout
          << "info: reduce-scatter weight grad collective issued for layer: "
          << id << ",";
      print_involved_dimensions(weight_grad_comm_involved_dimensions);
    }
  } else if (weight_grad_comm_type == ComType::None) {
    collective_counter--;
    if (generator->id == 0) {
      std::cout << "info: no weight grad collective for layer: " << id
                << std::endl;
    }
    release_blocking_barrier();
    return;
  } else {
    Sys::sys_panic("no known collective operation! ");
    // Defensive: `wg` is still null here. Whether sys_panic terminates the
    // process is not visible from this file, so never fall through to the
    // dereferences below; also undo the counter increment.
    collective_counter--;
    return;
  }
#ifndef PHY_MTP
  // Track the in-flight dataset and get notified when it finishes.
  weight_grad_datasets[wg->my_id] = wg;
  wg->set_notifier(this, EventType::Wight_Grad_Comm_Finished);
#endif
}