void Layer::issue_forward_pass_comm()

in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [983:1162]


/**
 * Issue the forward-pass collective communication of this layer.
 *
 * The requested barrier is always recorded in `fwd_barrier`.
 *
 * - When built with ANALYTI, no collective is generated: the call is only
 *   logged (on generator id 0) and, for a blocking barrier, the workload is
 *   resumed immediately.
 * - Otherwise `collective_counter` is incremented and a dataset is requested
 *   from `generator` according to `fwd_pass_comm_type` (all-reduce,
 *   all-to-all, all-gather or reduce-scatter).
 *   - If the type is `ComType::None`, or the generated dataset is not active,
 *     the counter increment is undone, an inactive dataset is deleted, and for
 *     a blocking barrier the workload is resumed immediately.
 *   - If the dataset is active and PHY_MTP is not defined, it is stored in
 *     `fwd_pass_datasets` under its id and this layer is registered as its
 *     notifier for `EventType::Fwd_Comm_Finished`. With PHY_MTP the event and
 *     this layer are handed to the generator call instead.
 * - Any other communication type ends in `Sys::sys_panic`.
 *
 * @param pref_scheduling Scheduling policy forwarded to the generator.
 * @param barrier         Whether the caller blocks on this collective.
 */
void Layer::issue_forward_pass_comm(
    SchedulingPolicy pref_scheduling,
    CollectiveBarrier barrier) {
  MockNcclLog* NcclLog = MockNcclLog::getInstance();

  // A blocking caller waits for the collective to finish; when no collective
  // is actually in flight the workload has to be resumed right here.
  auto resume_workload_if_blocking = [this, barrier]() {
    if (barrier == CollectiveBarrier::Blocking) {
      workload->call(EventType::General, nullptr);
    }
  };

  #ifdef ANALYTI
    fwd_barrier = barrier;
    if (generator->id == 0){
      NcclLog->writeLog(
          NcclLogLevel::DEBUG,
          "forward pass for layer %s is analytical ",
          id.c_str());
      NcclLog->writeLog(
          NcclLogLevel::DEBUG,
          "forward pass for layer-id %d is analytical ",
          layer_num);
    }
    resume_workload_if_blocking();
    return;
  #endif

  fwd_barrier = barrier;
  collective_counter++;

  // Step 1: ask the generator for the dataset matching the configured
  // collective type and remember which "issued" message belongs to it.
  DataSet* fp = nullptr;
  const char* issued_msg = nullptr;
  if (fwd_pass_comm_type == ComType::All_Reduce) {
    #ifdef PHY_MTP
    fp = generator->generate_all_reduce(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Fwd_Comm_Finished,
        this);
    #else
    fp = generator->generate_all_reduce(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    issued_msg = "info: all-reduce forward pass collective issued for layer: ";
  } else if (fwd_pass_comm_type == ComType::All_to_All) {
    #ifdef PHY_MTP
    fp = generator->generate_all_to_all(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Fwd_Comm_Finished,
        this);
    #else
    fp = generator->generate_all_to_all(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    issued_msg = "info: all-to-all forward pass collective issued for layer: ";
  } else if (fwd_pass_comm_type == ComType::All_Gather) {
    #ifdef PHY_MTP
    fp = generator->generate_all_gather(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Fwd_Comm_Finished,
        this);
    #else
    fp = generator->generate_all_gather(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    issued_msg = "info: all-gather forward pass collective issued for layer: ";
  } else if (fwd_pass_comm_type == ComType::Reduce_Scatter) {
    #ifdef PHY_MTP
    fp = generator->generate_reduce_scatter(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Fwd_Comm_Finished,
        this);
    #else
    fp = generator->generate_reduce_scatter(
        fwd_pass_comm_size,
        fwd_pass_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    issued_msg =
        "info: reduce-scatter forward pass collective issued for layer: ";
  } else if (fwd_pass_comm_type == ComType::None) {
    // Nothing to communicate: undo the increment made above.
    collective_counter--;
    if (generator->id == 0) {
      std::cout << "info: no forward pass collective for layer: " << id
                << std::endl;
    }
    resume_workload_if_blocking();
    return;
  } else {
    Sys::sys_panic("no known collective operation! ");
    // sys_panic is not expected to return; if it ever does, bail out rather
    // than dereference the still-null dataset pointer below.
    return;
  }

  // Step 2: a dataset that is not active means no collective was issued.
  if (!fp->active) {
    if (generator->id == 0) {
      std::cout
          << "info: all dims disabled, no forward pass collective for layer: "
          << id << std::endl;
    }
    collective_counter--;
    delete fp;
    resume_workload_if_blocking();
    return;
  }

  if (generator->id == 0) {
    std::cout << issued_msg << id << ",";
    print_involved_dimensions(fwd_pass_comm_involved_dimensions);
  }

  // Step 3: without PHY_MTP the layer tracks the dataset itself and asks to be
  // notified on completion; with PHY_MTP the event and this layer were
  // already passed to the generator call.
  #ifndef PHY_MTP
  fwd_pass_datasets[fp->my_id] = fp;
  fp->set_notifier(this, EventType::Fwd_Comm_Finished);
  #endif
  NcclLog->writeLog(NcclLogLevel::DEBUG,"Fwd_Comm_Finished set_notifier success");
}