void Layer::call()

in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [90:252]


/**
 * Handles communication-completion events for this layer.
 *
 * The three "raw" completion events (Wight_Grad_Comm_Finished,
 * Input_Grad_Comm_Finished, Fwd_Comm_Finished) record the current tick in
 * last_wg_finished / last_ig_finished / last_fwd_finished and re-register the
 * same payload as the matching "*_After_Delay" event, delayed by the
 * corresponding update time.
 *
 * The "*_After_Delay" events (non-PHY_MTP builds) finalize the dataset keyed
 * by the IntData payload:
 *  - extend its finish tick by the update time;
 *  - accumulate total communication time and, when the workload was waiting
 *    on it, waiting time;
 *  - update stream stats, delete the dataset and erase it from its map;
 *  - resume the workload (EventType::General) if it was waiting;
 *  - report the dataset's streams as finished.
 * Under PHY_MTP the "*_After_Delay" events just resume the workload and
 * report one finished stream.
 *
 * @param event  The event being delivered. Events not listed above are
 *               ignored.
 * @param mdata  Payload. For the raw events it is forwarded unchanged to
 *               register_event. For the "*_After_Delay" events in non-PHY_MTP
 *               builds it must be an IntData whose `data` field is the dataset
 *               key; it is deleted here on those paths.
 */
void Layer::call(EventType event, CallData* mdata) {
  // Raw completion events: stamp the tick and re-deliver the same payload as
  // the "*_After_Delay" event once the update time has elapsed.
  if (event == EventType::Wight_Grad_Comm_Finished) {
    last_wg_finished = Sys::boostedTick();
    generator->register_event(
        this,
        EventType::Wight_Grad_Comm_Finished_After_Delay,
        mdata,
        weight_grad_update_time);
    return;
  } else if (event == EventType::Input_Grad_Comm_Finished) {
    last_ig_finished = Sys::boostedTick();
    generator->register_event(
        this,
        EventType::Input_Grad_Comm_Finished_After_Delay,
        mdata,
        input_grad_update_time);
    return;
  } else if (event == EventType::Fwd_Comm_Finished) {
    last_fwd_finished = Sys::boostedTick();
    generator->register_event(
        this, EventType::Fwd_Comm_Finished_After_Delay, mdata, fwd_update_time);
    return;
  }
  // mdata is interpreted as IntData only inside the non-PHY_MTP branches that
  // actually use it, so unhandled events and PHY_MTP paths never dereference
  // the payload.
  if (event == EventType::Wight_Grad_Comm_Finished_After_Delay) {
    #ifndef PHY_MTP
    IntData* intData = static_cast<IntData*>(mdata);
    int data = intData->data;
    if (generator->id == 0) {
      std::cout << "***** info: weight gradient collective for layer: " << id
                << " is finished************" << std::endl;
    }
    auto* dataset = weight_grad_datasets[data];
    dataset->finish_tick += weight_grad_update_time;
    total_weight_grad_comm += dataset->finish_tick - dataset->creation_tick;

    // Resume the workload only when it is waiting on this collective.
    // Case 1: this is the only outstanding weight-gradient dataset under a
    //         Blocking barrier.
    // Case 2: a wait start time was recorded for it.
    bool resume_workload = false;
    if (weight_grad_datasets.size() == 1 &&
        wg_barrier == CollectiveBarrier::Blocking) {
      total_waiting_for_wg_comm +=
          dataset->finish_tick - dataset->creation_tick;
      resume_workload = true;
    } else if (!started_waiting_for_weight_grad.empty()) {
      total_waiting_for_wg_comm +=
          dataset->finish_tick - started_waiting_for_weight_grad.front();
      started_waiting_for_weight_grad.pop_front();
      resume_workload = true;
    }
    update_stream_stats(dataset);
    int dataset_streams = dataset->total_streams;
    delete dataset;
    weight_grad_datasets.erase(data);
    if (resume_workload) {
      workload->call(EventType::General, nullptr);
    }
    generator->increase_finished_streams(dataset_streams);
    delete intData;
    #else
    // NOTE(review): mdata is not released on this path, unlike the
    // non-PHY_MTP path — confirm who owns it under PHY_MTP.
    workload->call(EventType::General, nullptr);
    generator->increase_finished_streams(1);
    #endif
    return;
  } else if (event == EventType::Input_Grad_Comm_Finished_After_Delay) {
    #ifndef PHY_MTP
    IntData* intData = static_cast<IntData*>(mdata);
    int data = intData->data;
    if (generator->id == 0) {
      std::cout << "***** info: input gradient collective for layer: " << id
                << " is finished************" << std::endl;
    }
    auto* dataset = input_grad_datasets[data];
    dataset->finish_tick += input_grad_update_time;
    total_input_grad_comm += dataset->finish_tick - dataset->creation_tick;

    // Same resume rule as the weight-gradient case, using the input-gradient
    // barrier and wait list.
    bool resume_workload = false;
    if (input_grad_datasets.size() == 1 &&
        ig_barrier == CollectiveBarrier::Blocking) {
      total_waiting_for_ig_comm +=
          dataset->finish_tick - dataset->creation_tick;
      resume_workload = true;
    } else if (!started_waiting_for_input_grad.empty()) {
      total_waiting_for_ig_comm +=
          dataset->finish_tick - started_waiting_for_input_grad.front();
      started_waiting_for_input_grad.pop_front();
      resume_workload = true;
    }
    update_stream_stats(dataset);
    int dataset_streams = dataset->total_streams;
    delete dataset;
    input_grad_datasets.erase(data);
    if (resume_workload) {
      workload->call(EventType::General, nullptr);
    }
    generator->increase_finished_streams(dataset_streams);
    delete intData;
    #else
    // NOTE(review): mdata is not released on this path, unlike the
    // non-PHY_MTP path — confirm who owns it under PHY_MTP.
    workload->call(EventType::General, nullptr);
    generator->increase_finished_streams(1);
    #endif
    return;
  } else if (event == EventType::Fwd_Comm_Finished_After_Delay) {
    #ifndef PHY_MTP
    IntData* intData = static_cast<IntData*>(mdata);
    int data = intData->data;
    if (generator->id == 0) {
      std::cout << "***** info: fwd pass comm collective for layer: " << id
                << " is finished************" << std::endl;
    }
    auto* dataset = fwd_pass_datasets[data];
    dataset->finish_tick += fwd_update_time;
    total_fwd_comm += dataset->finish_tick - dataset->creation_tick;

    // Same resume rule as the weight-gradient case, using the forward-pass
    // barrier and wait list.
    bool resume_workload = false;
    if (fwd_pass_datasets.size() == 1 &&
        fwd_barrier == CollectiveBarrier::Blocking) {
      total_waiting_for_fwd_comm +=
          dataset->finish_tick - dataset->creation_tick;
      resume_workload = true;
    } else if (!started_waiting_for_fwd_pass.empty()) {
      total_waiting_for_fwd_comm +=
          dataset->finish_tick - started_waiting_for_fwd_pass.front();
      started_waiting_for_fwd_pass.pop_front();
      resume_workload = true;
    }
    update_stream_stats(dataset);
    int dataset_streams = dataset->total_streams;
    delete dataset;
    fwd_pass_datasets.erase(data);
    if (resume_workload) {
      workload->call(EventType::General, nullptr);
    }
    generator->increase_finished_streams(dataset_streams);
    delete intData;
    #else
    // NOTE(review): mdata is not released on this path, unlike the
    // non-PHY_MTP path — confirm who owns it under PHY_MTP.
    workload->call(EventType::General, nullptr);
    generator->increase_finished_streams(1);
    #endif
    return;
  }
}