astra-sim-alibabacloud/astra-sim/system/collective/Ring.cc (281 lines of code) (raw):

/******************************************************************************
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*******************************************************************************/

#include "Ring.hh"
#include "astra-sim/system/PacketBundle.hh"
#include "astra-sim/system/RecvPacketEventHadndlerData.hh"

namespace AstraSim {

/// Builds the ring algorithm state for one node.
///
/// @param type             Collective being executed; selects stream_count,
///                         max_count, msg_size and final_data_size below.
/// @param id               Identifier stored in this->id and used as the
///                         source rank of outgoing requests.
/// @param layer_num        Forwarded to the Algorithm base constructor.
/// @param ring_topology    Topology queried for the ring size and for this
///                         node's sender / receiver; must not be null (it is
///                         dereferenced unconditionally).
/// @param data_size        Input size the per-message size is derived from.
/// @param direction        Ring direction passed to the topology lookups.
/// @param injection_policy Only consulted for All_to_All, where Aggressive
///                         raises parallel_reduce to nodes_in_ring - 1.
/// @param boost_mode       When true, `enabled` is taken from
///                         ring_topology->is_enabled(); otherwise the
///                         algorithm is always enabled.
///
/// NOTE(review): for a ComType outside the four handled cases, msg_size and
/// final_data_size are not assigned here - confirm they are initialised
/// elsewhere.
Ring::Ring(
    ComType type,
    int id,
    int layer_num,
    RingTopology* ring_topology,
    uint64_t data_size,
    RingTopology::Direction direction,
    InjectionPolicy injection_policy,
    bool boost_mode)
    : Algorithm(layer_num) {
  // Several parameters shadow members of the same name, hence `this->`.
  this->comType = type;
  this->id = id;
  this->logicalTopology = ring_topology;
  this->data_size = data_size;
  this->direction = direction;
  this->nodes_in_ring = ring_topology->get_nodes_in_ring();
  this->current_receiver = ring_topology->get_receiver_node(id, direction);
  this->current_sender = ring_topology->get_sender_node(id, direction);
  this->parallel_reduce = 1; // each ring has xx channels
  this->injection_policy = injection_policy;

  // Counters and flags all start from a clean slate.
  this->total_packets_sent = 0;
  this->total_packets_received = 0;
  this->free_packets = 0;
  this->zero_latency_packets = 0;
  this->non_zero_latency_packets = 0;
  this->toggle = false;
  this->name = Name::Ring;

  this->enabled = true;
  if (boost_mode) {
    this->enabled = ring_topology->is_enabled();
  }

  const bool is_local_dimension =
      ring_topology->dimension == RingTopology::Dimension::Local;
  transmition = is_local_dimension ? MemBus::Transmition::Fast
                                   : MemBus::Transmition::Usual;

  remained_packets_per_message = 1;
  remained_packets_per_max_count = 1;

  // Per-collective sizing. Divisions are kept inside the cases that need
  // them so no extra arithmetic is performed for the other collectives.
  switch (type) {
    case ComType::All_Reduce:
      stream_count = 2 * (nodes_in_ring - 1);
      max_count = nodes_in_ring - 1;
      this->final_data_size = data_size;
      this->msg_size = data_size / nodes_in_ring;
      break;
    case ComType::All_to_All:
      stream_count = ((nodes_in_ring - 1) * nodes_in_ring) / 2;
      // Every policy other than Aggressive keeps a single packet in flight.
      this->parallel_reduce =
          (injection_policy == InjectionPolicy::Aggressive)
          ? nodes_in_ring - 1
          : 1;
      max_count = 0;
      this->final_data_size = data_size;
      this->msg_size = data_size / nodes_in_ring;
      break;
    case ComType::All_Gather:
      stream_count = nodes_in_ring - 1;
      max_count = 0;
      this->final_data_size = data_size * nodes_in_ring;
      this->msg_size = data_size;
      break;
    case ComType::Reduce_Scatter:
      stream_count = nodes_in_ring - 1;
      max_count = nodes_in_ring - 1;
      this->final_data_size = data_size / nodes_in_ring;
      this->msg_size = data_size / nodes_in_ring;
      break;
    default:
      stream_count = nodes_in_ring - 1;
      max_count = nodes_in_ring - 1;
      break;
  }
}

/// @return (nodes_in_ring - 1) * parallel_reduce, the value insert_packet()
///         loads into non_zero_latency_packets when it refills its counters.
int Ring::get_non_zero_latency_packets() {
  return (nodes_in_ring - 1) * parallel_reduce;
}

/// Event entry point.
///
/// - General:        one more free packet slot, then try to send and check
///                   for completion.
/// - PacketReceived: count the reception and inject one packet.
/// - StreamInit:     inject parallel_reduce packets.
/// Any other event is ignored. `data` is not used.
void Ring::run(EventType event, CallData* data) {
  switch (event) {
    case EventType::General:
      free_packets += 1;
      ready();
      // May call exit(); nothing may touch members after this line.
      iteratable();
      break;
    case EventType::PacketReceived:
      total_packets_received++;
      insert_packet(nullptr);
      break;
    case EventType::StreamInit:
      for (int injected = 0; injected < parallel_reduce; ++injected) {
        insert_packet(nullptr);
      }
      break;
    default:
      break;
  }
}

/// Registers this algorithm as notifier of every locked packet, hands the
/// locked packets to a freshly created PacketBundle (sent to MA when
/// NPU_to_MA is set, to NPU otherwise) and empties the locked list.
///
/// NOTE(review): the bundle is not deleted here; presumably PacketBundle
/// manages its own lifetime - confirm.
void Ring::release_packets() {
  for (auto* pending : locked_packets) {
    pending->set_notifier(this);
  }

  // Read the routing flag before building the bundle, as the original did.
  const bool towards_ma = NPU_to_MA;
  PacketBundle* bundle = new PacketBundle(
      stream->owner,
      stream,
      locked_packets,
      processed,
      send_back,
      msg_size,
      transmition);
  if (towards_ma) {
    bundle->send_to_MA();
  } else {
    bundle->send_to_NPU();
  }

  locked_packets.clear();
}

/// Accounts for one sent packet: consumes the current message and, once it
/// is finished, one stream step. When no steps remain and the stream is not
/// already Dead, the stream is moved to the Zombie state.
void Ring::process_stream_count() {
  if (remained_packets_per_message > 0) {
    --remained_packets_per_message;
  }
  if (remained_packets_per_message != 0) {
    return; // current message still has packets outstanding
  }

  if (stream_count > 0) {
    --stream_count;
    if (stream_count > 0) {
      remained_packets_per_message = 1; // arm the next step
      return;
    }
  }

  if (stream_count == 0 && stream->state != StreamState::Dead) {
    stream->changeState(StreamState::Zombie);
  }
}

/// Counts down the packets of the current batch; when the batch is complete
/// it decrements max_count, releases the locked packets and re-arms the
/// batch counter to 1.
void Ring::process_max_count() {
  if (remained_packets_per_max_count > 0) {
    --remained_packets_per_max_count;
  }
  if (remained_packets_per_max_count != 0) {
    return;
  }

  --max_count;
  release_packets();
  remained_packets_per_max_count = 1;
}

/// Book-keeping after a packet has been handed to the network: advances the
/// stream counters, drops the packet at the head of the queue and consumes
/// one free slot.
void Ring::reduce() {
  process_stream_count();
  packets.pop_front();
  --free_packets;
  ++total_packets_sent;
}

/// @return false (after calling exit()) once no stream steps remain and all
///         parallel_reduce slots are free again; true otherwise.
bool Ring::iteratable() {
  const bool finished = (stream_count == 0) && (free_packets == parallel_reduce);
  if (!finished) {
    return true;
  }
  // No member access after exit(), matching the original control flow.
  exit();
  return false;
}

/// Queues one packet from current_sender to current_receiver and locks it
/// for the next bundle.
///
/// When both latency counters are exhausted they are refilled
/// (parallel_reduce zero-latency packets, get_non_zero_latency_packets()
/// non-zero-latency ones) and `toggle` is flipped. Zero-latency packets are
/// consumed first. Does nothing when the algorithm is disabled.
///
/// @param sender Stored in the new packet's `sender` field.
void Ring::insert_packet(Callable* sender) {
  if (!enabled) {
    return;
  }

  if (zero_latency_packets == 0 && non_zero_latency_packets == 0) {
    zero_latency_packets = parallel_reduce;
    non_zero_latency_packets = get_non_zero_latency_packets();
    toggle = !toggle;
  }

  const bool use_zero_latency = zero_latency_packets > 0;
  if (!use_zero_latency && !(non_zero_latency_packets > 0)) {
    Sys::sys_panic("should not inject nothing!");
    return;
  }

  // Common to both kinds of packet.
  // vnet Must be changed for alltoall topology
  packets.push_back(
      MyPacket(stream->current_queue_id, current_sender, current_receiver));
  MyPacket& queued = packets.back();
  queued.sender = sender;
  locked_packets.push_back(&queued);

  if (use_zero_latency) {
    processed = false;
    send_back = false;
    NPU_to_MA = true;
    process_max_count();
    --zero_latency_packets;
    return;
  }

  // Non-zero-latency packet.
  processed = (comType == ComType::Reduce_Scatter) ||
      (comType == ComType::All_Reduce && toggle);
  // Only the last parallel_reduce packets of a round are not sent back.
  send_back = non_zero_latency_packets > parallel_reduce;
  NPU_to_MA = false;
  std::cout << "id: " << id << " non-zero latency packets at tick: "
            << Sys::boostedTick() << std::endl;
  process_max_count();
  --non_zero_latency_packets;
}

/// Tries to push the packet at the head of the queue onto the network.
///
/// A stream in the Created or Ready state is first moved to Executing. If
/// the algorithm is disabled, the queue is empty, no stream steps remain or
/// no slot is free, nothing is sent. Otherwise a send to the packet's
/// preferred destination and a matching receive from its preferred source
/// are issued, and reduce() updates the counters.
///
/// @return true when a send/receive pair was issued, false otherwise.
bool Ring::ready() {
  const bool not_started = stream->state == StreamState::Created ||
      stream->state == StreamState::Ready;
  if (not_started) {
    stream->changeState(StreamState::Executing);
  }

  if (!enabled || packets.empty() || stream_count == 0 || free_packets == 0) {
    return false;
  }

  // Deliberately a copy: the queue entry is popped by reduce() below.
  const MyPacket head = packets.front();

  sim_request snd_req;
  snd_req.srcRank = id;
  snd_req.dstRank = head.preferred_dest;
  snd_req.tag = stream->stream_num;
  snd_req.reqType = UINT8;
  snd_req.vnet = this->stream->current_queue_id;
  snd_req.layerNum = layer_num;
  stream->owner->front_end_sim_send(
      0,
      Sys::dummy_data,
      msg_size,
      UINT8,
      head.preferred_dest,
      stream->stream_num,
      &snd_req,
      &Sys::handleEvent,
      nullptr);

  // Only vnet and layerNum are filled in for the receive request.
  sim_request rcv_req;
  rcv_req.vnet = this->stream->current_queue_id;
  rcv_req.layerNum = layer_num;
  RecvPacketEventHadndlerData* recv_handler_data =
      new RecvPacketEventHadndlerData(
          stream,
          stream->owner->id,
          EventType::PacketReceived,
          head.preferred_vnet,
          head.stream_num);
  stream->owner->front_end_sim_recv(
      0,
      Sys::dummy_data,
      msg_size,
      UINT8,
      head.preferred_src,
      stream->stream_num,
      &rcv_req,
      &Sys::handleEvent,
      recv_handler_data);

  reduce();
  return true;
}

/// Drops any queued and locked packets and asks the owning system to move
/// the stream on via proceed_to_next_vnet_baseline().
void Ring::exit() {
  packets.clear();
  locked_packets.clear();
  // C-style cast kept as in the original.
  stream->owner->proceed_to_next_vnet_baseline((StreamBaseline*)stream);
}

} // namespace AstraSim