void Layer::issue_weight_grad_comm()

in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [1345:1525]


void Layer::issue_weight_grad_comm(
    SchedulingPolicy pref_scheduling,
    CollectiveBarrier barrier) {
  MockNcclLog* NcclLog = MockNcclLog::getInstance();
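  // Analytical (ANALYTI) build: the weight-grad collective is only logged,
  // the barrier is recorded, and blocking callers are re-activated right
  // away; no DataSet is created.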
  #ifdef ANALYTI
  wg_barrier = barrier;
  if (generator->id == 0) {
    NcclLog->writeLog(
        NcclLogLevel::DEBUG,
        "weight grad collective for layer %s is analytical ",
        id.c_str());
    NcclLog->writeLog(
        NcclLogLevel::DEBUG,
        "weight grad collective for layer-id %d is analytical ",
        layer_num);
  }

  if (barrier == CollectiveBarrier::Blocking) {
    workload->call(EventType::General, NULL);
  }
  return;
  #endif
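  // Simulated path: count an outstanding collective for this layer and build
  // a DataSet for it below.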
  DataSet* wg = NULL;
  wg_barrier = barrier;
  collective_counter++;
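  // Dispatch on the configured collective type. In PHY_MTP builds the
  // completion notifier (Wight_Grad_Comm_Finished) is handed to the generator
  // directly; otherwise it is registered on the DataSet at the end of this
  // function.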
  if (weight_grad_comm_type == ComType::All_Reduce) {
    #ifdef PHY_MTP
    wg = generator->generate_all_reduce(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Wight_Grad_Comm_Finished,
        this);
    #else
    wg = generator->generate_all_reduce(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
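    // If every involved dimension is disabled the DataSet is inactive: undo
    // the counter, free it, and unblock blocking callers. The same check is
    // repeated for each collective type below.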
    if (!wg->active) {
      if (generator->id == 0) {
        std::cout
            << "info: all dims disabled, no weight grad collective for layer: "
            << id << std::endl;
      }
      collective_counter--;
      delete wg;
      if (barrier == CollectiveBarrier::Blocking) {
        workload->call(EventType::General, NULL);
      }
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-reduce weight grad collective issued for layer: "
                << id << " with size: " << weight_grad_comm_size << ",";
      print_involved_dimensions(weight_grad_comm_involved_dimensions);
    }
  } else if (weight_grad_comm_type == ComType::All_to_All) {
    #ifdef PHY_MTP
    wg = generator->generate_all_to_all(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Wight_Grad_Comm_Finished,
        this);
    #else
    wg = generator->generate_all_to_all(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    if (!wg->active) {
      if (generator->id == 0) {
        std::cout
            << "info: all dims disabled, no weight grad collective for layer: "
            << id << std::endl;
      }
      collective_counter--;
      delete wg;
      if (barrier == CollectiveBarrier::Blocking) {
        workload->call(EventType::General, NULL);
      }
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-to-all weight grad collective issued for layer: "
                << id << " with size: " << weight_grad_comm_size << ",";
      print_involved_dimensions(weight_grad_comm_involved_dimensions);
    }
  } else if (weight_grad_comm_type == ComType::All_Gather) {
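    // All-gather path; rank 0 additionally logs the tick at which it is issued.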
    if (generator->id == 0) {
      std::cout << "Layer issue wg all gather at tick: " << Sys::boostedTick()
                << std::endl;
    }
    #ifdef PHY_MTP
    wg = generator->generate_all_gather(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Wight_Grad_Comm_Finished,
        this);
    #else
    wg = generator->generate_all_gather(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    if (!wg->active) {
      if (generator->id == 0) {
        std::cout
            << "info: all dims disabled, no weight grad collective for layer: "
            << id << std::endl;
      }
      collective_counter--;
      delete wg;
      if (barrier == CollectiveBarrier::Blocking) {
        workload->call(EventType::General, NULL);
      }
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-gather weight grad collective issued for layer: "
                << id << ",";
      print_involved_dimensions(weight_grad_comm_involved_dimensions);
    }
  } else if (weight_grad_comm_type == ComType::Reduce_Scatter) {
    #ifdef PHY_MTP
    wg = generator->generate_reduce_scatter(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Wight_Grad_Comm_Finished,
        this);
    #else
    wg = generator->generate_reduce_scatter(
        weight_grad_comm_size,
        weight_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    if (!wg->active) {
      if (generator->id == 0) {
        std::cout
            << "info: all dims disabled, no weight grad collective for layer: "
            << id << std::endl;
      }
      collective_counter--;
      delete wg;
      if (barrier == CollectiveBarrier::Blocking) {
        workload->call(EventType::General, NULL);
      }
      return;
    }
    if (generator->id == 0) {
      std::cout
          << "info: reduce-scatter weight grad collective issued for layer: "
          << id << ",";
      print_involved_dimensions(weight_grad_comm_involved_dimensions);
    }
  } else if (weight_grad_comm_type == ComType::None) {
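    // No weight-grad collective configured for this layer: release the slot
    // taken above and unblock blocking callers.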
    collective_counter--;
    if (generator->id == 0) {
      std::cout << "info: no weight grad collective for layer: " << id
                << std::endl;
    }
    if (barrier == CollectiveBarrier::Blocking) {
      workload->call(EventType::General, NULL);
    }
    return;
  } else {
    Sys::sys_panic("no known collective operation! ");
  }
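  // Non-PHY_MTP builds: remember the DataSet and register this layer as the
  // notifier so Wight_Grad_Comm_Finished is delivered when the collective
  // completes.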
  #ifndef PHY_MTP
  weight_grad_datasets[wg->my_id] = wg;
  wg->set_notifier(this, EventType::Wight_Grad_Comm_Finished);
  #endif
}