void Layer::issue_input_grad_comm()

in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [1163:1344]


void Layer::issue_input_grad_comm(
    SchedulingPolicy pref_scheduling,
    CollectiveBarrier barrier) {
  MockNcclLog* NcclLog = MockNcclLog::getInstance();
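  // Analytical mode (ANALYTI): no DataSet is generated. Record the barrier,
  // let rank 0 log the collective, and release a blocking caller immediately.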
  #ifdef ANALYTI
  ig_barrier = barrier;
  if (generator->id == 0) {
    NcclLog->writeLog(
        NcclLogLevel::DEBUG,
        "input grad collective for layer %s is analytical ",
        id.c_str());
    NcclLog->writeLog(
        NcclLogLevel::DEBUG,
        "input grad collective for layer-id %d is analytical ",
        layer_num);
  }
    
  if (barrier == CollectiveBarrier::Blocking) {
    workload->call(EventType::General, NULL);
  }
  return;
  #endif
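  // Non-analytical path: generate the input-grad collective and count it as
  // an outstanding collective for this layer.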
  DataSet* ig = NULL;
  ig_barrier = barrier;
  collective_counter++;
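  // Dispatch on input_grad_comm_type. Under PHY_MTP each generate_* call is
  // handed the completion event and this layer as notifier; otherwise the
  // returned DataSet is registered at the end of this function.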
  if (input_grad_comm_type == ComType::All_Reduce) {
    #ifdef PHY_MTP
    ig = generator->generate_all_reduce(
        input_grad_comm_size,
        input_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Input_Grad_Comm_Finished,
        this);
    #else
    ig = generator->generate_all_reduce(
        input_grad_comm_size,
        input_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
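    // An inactive DataSet means all involved dimensions are disabled: roll
    // back the counter, free the DataSet, and release a blocking caller.
    // The same pattern repeats in the branches below.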
    if (!ig->active) {
      if (generator->id == 0) {
        std::cout
            << "info: all dims disabled, no input grad collective for layer: "
            << id << std::endl;
      }
      collective_counter--;
      delete ig;
      if (barrier == CollectiveBarrier::Blocking) {
        workload->call(EventType::General, NULL);
      }
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-reduce input grad collective issued for layer: "
                << id << ",";
      print_involved_dimensions(input_grad_comm_involved_dimensions);
    }
  } else if (input_grad_comm_type == ComType::All_to_All) {
    #ifdef PHY_MTP
    ig = generator->generate_all_to_all(
        input_grad_comm_size,
        input_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Input_Grad_Comm_Finished,
        this);
    #else
    ig = generator->generate_all_to_all(
        input_grad_comm_size,
        input_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    if (!ig->active) {
      if (generator->id == 0) {
        std::cout
            << "info: all dims disabled, no input grad collective for layer: "
            << id << std::endl;
      }
      collective_counter--;
      delete ig;
      if (barrier == CollectiveBarrier::Blocking) {
        workload->call(EventType::General, NULL);
      }
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-to-all input grad collective issued for layer: "
                << id << ",";
      print_involved_dimensions(input_grad_comm_involved_dimensions);
    }
  } else if (input_grad_comm_type == ComType::All_Gather) {
    #ifdef PHY_MTP
    ig = generator->generate_all_gather(
        input_grad_comm_size,
        input_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Input_Grad_Comm_Finished,
        this);
    #else
    ig = generator->generate_all_gather(
        input_grad_comm_size,
        input_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    if (!ig->active) {
      if (generator->id == 0) {
        std::cout
            << "info: all dims disabled, no input grad collective for layer: "
            << id << std::endl;
      }
      collective_counter--;
      delete ig;
      if (barrier == CollectiveBarrier::Blocking) {
        workload->call(EventType::General, NULL);
      }
      return;
    }
    if (generator->id == 0) {
      std::cout << "info: all-gather input grad collective issued for layer: "
                << id << ",";
      print_involved_dimensions(input_grad_comm_involved_dimensions);
    }
  } else if (input_grad_comm_type == ComType::Reduce_Scatter) {
    #ifdef PHY_MTP
    ig = generator->generate_reduce_scatter(
        input_grad_comm_size,
        input_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num,
        EventType::Input_Grad_Comm_Finished,
        this);
    #else
    ig = generator->generate_reduce_scatter(
        input_grad_comm_size,
        input_grad_comm_involved_dimensions,
        pref_scheduling,
        layer_num);
    #endif
    if (!ig->active) {
      if (generator->id == 0) {
        std::cout
            << "info: all dims disabled, no input grad collective for layer: "
            << id << std::endl;
      }
      collective_counter--;
      delete ig;
      if (barrier == CollectiveBarrier::Blocking) {
        workload->call(EventType::General, NULL);
      }
      return;
    }
    if (generator->id == 0) {
      std::cout
          << "info: reduce-scatter input grad collective issued for layer: "
          << id << ",";
      print_involved_dimensions(input_grad_comm_involved_dimensions);
    }
  } else if (input_grad_comm_type == ComType::None) {
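    // No input-grad collective for this layer: roll back the counter and
    // release a blocking caller right away.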
    collective_counter--;
    if (generator->id == 0) {
      std::cout << "info: no input grad collective for layer: " << id
                << std::endl;
    }
    if (barrier == CollectiveBarrier::Blocking) {
      workload->call(EventType::General, NULL);
    }
    return;
  } else {
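    // Unknown collective type: abort the simulation.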
    std::cout << "no known collective operation! for layer: " << id
              << std::endl;
    Sys::sys_panic("no known collective operation! ");
  }
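  // Without PHY_MTP, track the DataSet here and have it notify this layer
  // when the input-grad collective finishes.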
  #ifndef PHY_MTP
  input_grad_datasets[ig->my_id] = ig;
  ig->set_notifier(this, EventType::Input_Grad_Comm_Finished);
  #endif
}