CollectivePhase Sys::generate_collective_phase()

in astra-sim-alibabacloud/astra-sim/system/Sys.cc [1093:1315]


/**
 * Builds the CollectivePhase for one step of a collective operation.
 *
 * The algorithm object placed in the phase is chosen from
 * `collective_implementation->type`:
 *   - Ring / OneRing                       -> Ring
 *   - Direct / OneDirect                   -> AllToAll (always uses
 *                                             InjectionPolicy::Normal and the
 *                                             implementation's
 *                                             direct_collective_window)
 *   - DoubleBinaryTree                     -> DoubleBinaryTreeAllReduce
 *   - HalvingDoubling / OneHalvingDoubling -> HalvingDoubling
 *   - NcclFlowModel                        -> NcclTreeFlowModel, fed with the
 *                                             flow models and channel count of
 *                                             the algorithm reported by
 *                                             get_nccl_Info()
 *
 * @param collective_type            Collective being performed; overridden
 *                                   with ComType::All_Reduce_NVLS when the
 *                                   reported algorithm is NCCL_ALGO_NVLS.
 * @param layer_num                  Forwarded to the algorithm constructor.
 * @param topology                   Logical topology; cast to the concrete
 *                                   topology type each algorithm expects.
 * @param data_size                  Forwarded to the algorithm constructor and
 *                                   to the flow-model lookup.
 * @param queue_id                   Queue id stored in the returned phase.
 * @param direction                  Forwarded to the algorithm constructor.
 * @param injection_policy           Forwarded to the algorithm constructor
 *                                   (ignored by the Direct branch).
 * @param collective_implementation  Selects the algorithm; must not be null.
 * @param boost_mode                 Forwarded to the algorithm constructor.
 * @return The constructed CollectivePhase.
 *
 * Terminates the process with exit(1) when the implementation type is
 * unknown, when no NCCL info is returned, or when the reported NCCL algorithm
 * is not one of RING / TREE / NVLS.
 */
CollectivePhase Sys::generate_collective_phase(
    ComType collective_type,
    int layer_num,
    BasicLogicalTopology* topology,
    uint64_t data_size,
    int queue_id,
    RingTopology::Direction direction,
    InjectionPolicy injection_policy,
    CollectiveImplementation* collective_implementation,
    bool boost_mode) {
  MockNcclLog* NcclLog = MockNcclLog::getInstance();
  const auto impl_type = collective_implementation->type;

  // Debug dump of the generated flow models, shared by the RING and NVLS paths.
  // Sizes are narrowed to int explicitly because the format strings use %d.
  auto log_flow_models =
      [this, NcclLog](
          const std::shared_ptr<MockNccl::FlowModels>& flow_models,
          std::size_t channel_count) {
        NcclLog->writeLog(NcclLogLevel::DEBUG,"rank %d generate FlowModels",id);
        if (flow_models == nullptr) {
          return;
        }
        NcclLog->writeLog(NcclLogLevel::DEBUG,"rank %d NcclMock generate  %d channel and flow model count:  %d",id,static_cast<int>(channel_count),static_cast<int>(flow_models->size()));
        // Iterate by reference: each entry owns containers that would
        // otherwise be copied on every iteration.
        for (const auto& flow : *flow_models) {
          const auto& info = flow.second;
          // Only the first predecessor / parent / child is logged; -1 marks
          // "none".
          const int prev = info.prev.empty() ? -1 : info.prev[0];
          const int parent_flow_id =
              info.parent_flow_id.empty() ? -1 : info.parent_flow_id[0];
          const int child_flow_id =
              info.child_flow_id.empty() ? -1 : info.child_flow_id[0];
          NcclLog->writeLog(NcclLogLevel::DEBUG," %d,  %d,  %d to  %d current_flow_id %d prev rank:  %d parent_flow_id:  %d child_flow_id:  %d chunk_id:  %d flow_size: %lu chunk_count:  %d ",flow.first.first,flow.first.second,info.src,info.dest,info.flow_id,prev,parent_flow_id,child_flow_id,info.chunk_id,info.flow_size,info.chunk_count);
        }
      };

  if (impl_type == CollectiveImplementationType::Ring ||
      impl_type == CollectiveImplementationType::OneRing) {
    CollectivePhase phase(
        this,
        queue_id,
        new Ring(
            collective_type,
            id,
            layer_num,
            (RingTopology*)topology,
            data_size,
            direction,
            injection_policy,
            boost_mode));
    return phase;
  }

  if (impl_type == CollectiveImplementationType::Direct ||
      impl_type == CollectiveImplementationType::OneDirect) {
    CollectivePhase phase(
        this,
        queue_id,
        new AllToAll(
            collective_type,
            ((DirectCollectiveImplementation*)collective_implementation)
                ->direct_collective_window,
            id,
            layer_num,
            (RingTopology*)topology,
            data_size,
            direction,
            InjectionPolicy::Normal,
            boost_mode));
    return phase;
  }

  if (impl_type == CollectiveImplementationType::DoubleBinaryTree) {
    CollectivePhase phase(
        this,
        queue_id,
        new DoubleBinaryTreeAllReduce(
            id, layer_num, (BinaryTree*)topology, data_size, boost_mode));
    return phase;
  }

  if (impl_type == CollectiveImplementationType::HalvingDoubling ||
      impl_type == CollectiveImplementationType::OneHalvingDoubling) {
    CollectivePhase phase(
        this,
        queue_id,
        new HalvingDoubling(
            collective_type,
            id,
            layer_num,
            (RingTopology*)topology,
            data_size,
            boost_mode));
    return phase;
  }

  if (impl_type == CollectiveImplementationType::NcclFlowModel) {
    // Value-initialised so it is never read in an indeterminate state when the
    // workload is in a loop state other than the three handled below.
    // NOTE(review): confirm that the zero-valued strategy is the intended
    // fallback for any other loop state.
    ParallelStrategy comm_ps{};
    if (workload->current_state == Workload::LoopState::Forward_Pass) {
      comm_ps = static_cast<ParallelStrategy>(
          workload->layers[workload->index]->fwd_pass_group_type);
    } else if (workload->current_state == Workload::LoopState::Input_Gradient) {
      comm_ps = static_cast<ParallelStrategy>(
          workload->layers[workload->index]->input_grad_group_type);
    } else if (workload->current_state == Workload::LoopState::Weight_Gradient) {
      comm_ps = static_cast<ParallelStrategy>(
          workload->layers[workload->index]->weight_grad_group_type);
    }

    MockNccl::ncclInfo* nccl_info;
    std::shared_ptr<void> ptr_FlowModels;
    {
      Sys::sysCriticalSection cs;
      nccl_info = get_nccl_Info(comm_ps, data_size, collective_type);
      ptr_FlowModels = generate_flow_model(comm_ps, data_size, collective_type);
      cs.ExitSection();
    }

    if (nccl_info == nullptr) {
      std::cerr
          << "Error: No NCCL info available for collective phase"
          << std::endl;
      exit(1);
    }

    if (nccl_info->algorithm == NCCL_ALGO_RING) {
      std::shared_ptr<MockNccl::FlowModels> RingFlowModels =
          std::static_pointer_cast<MockNccl::FlowModels>(ptr_FlowModels);
      std::map<int, std::map<int, std::vector<int>>> channels;
      {
        Sys::sysCriticalSection cs;
        channels = mock_nccl_comms[comm_ps]->get_rings();
        cs.ExitSection();
      }
      log_flow_models(RingFlowModels, channels.size());
      CollectivePhase phase(
          this,
          queue_id,
          new NcclTreeFlowModel(
              collective_type,
              id,
              layer_num,
              (RingTopology*)topology,
              data_size,
              direction,
              injection_policy,
              boost_mode,
              RingFlowModels,
              channels.size()));
      return phase;
    }

    if (nccl_info->algorithm == NCCL_ALGO_TREE) {
      std::shared_ptr<MockNccl::FlowModels> TreeFlowModels;
      MockNccl::TreeChannels treechannels;
      {
        Sys::sysCriticalSection cs;
        TreeFlowModels =
            std::static_pointer_cast<MockNccl::FlowModels>(ptr_FlowModels);
        treechannels = mock_nccl_comms[comm_ps]->get_treechannels();
        cs.ExitSection();
      }
      CollectivePhase phase(
          this,
          queue_id,
          new NcclTreeFlowModel(
              collective_type,
              id,
              layer_num,
              (RingTopology*)topology,
              data_size,
              direction,
              injection_policy,
              boost_mode,
              TreeFlowModels,
              treechannels.size()));
      return phase;
    }

    if (nccl_info->algorithm == NCCL_ALGO_NVLS) {
      // The NVLS path always runs as the NVLS flavour of all-reduce,
      // regardless of the collective type requested by the caller.
      collective_type = ComType::All_Reduce_NVLS;
      std::shared_ptr<MockNccl::FlowModels> RingFlowModels =
          std::static_pointer_cast<MockNccl::FlowModels>(ptr_FlowModels);
      MockNccl::TreeChannels treechannels;
      {
        Sys::sysCriticalSection cs;
        treechannels = mock_nccl_comms[comm_ps]->get_treechannels();
        cs.ExitSection();
      }
      log_flow_models(RingFlowModels, treechannels.size());
      CollectivePhase phase(
          this,
          queue_id,
          new NcclTreeFlowModel(
              collective_type,
              id,
              layer_num,
              (RingTopology*)topology,
              data_size,
              direction,
              injection_policy,
              boost_mode,
              RingFlowModels,
              treechannels.size()));
      return phase;
    }

    // Previously control fell off the end of the function here (undefined
    // behaviour); fail loudly instead, like the unknown-implementation case.
    std::cerr
        << "Error: No known NCCL algorithm for NcclFlowModel collective phase"
        << std::endl;
    exit(1);
  }

  std::cerr
      << "Error: No known collective implementation for collective phase"
      << std::endl;
  exit(1);
}