astra-sim-alibabacloud/astra-sim/system/collective/AllToAll.cc (90 lines of code) (raw):
/******************************************************************************
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*******************************************************************************/
#include "AllToAll.hh"
namespace AstraSim {
AllToAll::AllToAll(
ComType type,
int window,
int id,
int layer_num,
RingTopology* allToAllTopology,
uint64_t data_size,
RingTopology::Direction direction,
InjectionPolicy injection_policy,
bool boost_mode)
: Ring(
type,
id,
layer_num,
allToAllTopology,
data_size,
direction,
injection_policy,
boost_mode) {
this->name = Name::AllToAll;
this->enabled = true;
this->middle_point = nodes_in_ring - 1;
if (boost_mode) {
this->enabled = allToAllTopology->is_enabled();
}
if (window == -1) {
parallel_reduce = nodes_in_ring - 1;
} else {
parallel_reduce = (int)std::min(window, nodes_in_ring - 1);
}
if (type == ComType::All_to_All) {
this->stream_count = nodes_in_ring - 1;
}
}
int AllToAll::get_non_zero_latency_packets() {
if (((RingTopology*)logicalTopology)->dimension !=
RingTopology::Dimension::Local) {
return parallel_reduce * 1;
} else {
return (nodes_in_ring - 1) * parallel_reduce * 1;
}
}
void AllToAll::process_max_count() {
if (remained_packets_per_max_count > 0)
remained_packets_per_max_count--;
if (remained_packets_per_max_count == 0) {
max_count--;
release_packets();
remained_packets_per_max_count = 1;
current_receiver = ((RingTopology*)logicalTopology)
->get_receiver_node(current_receiver, direction);
if (current_receiver == id) {
current_receiver = ((RingTopology*)logicalTopology)
->get_receiver_node(current_receiver, direction);
}
current_sender = ((RingTopology*)logicalTopology)
->get_sender_node(current_sender, direction);
if (current_sender == id) {
current_sender = ((RingTopology*)logicalTopology)
->get_sender_node(current_sender, direction);
}
}
}
void AllToAll::run(EventType event, CallData* data) {
if (event == EventType::General) {
free_packets += 1;
if (comType == ComType::All_Reduce && stream_count <= middle_point) {
if (total_packets_received < middle_point) {
return;
}
for (int i = 0; i < parallel_reduce; i++) {
ready();
}
iteratable();
} else {
ready();
iteratable();
}
} else if (event == EventType::PacketReceived) {
total_packets_received++;
insert_packet(nullptr);
} else if (event == EventType::StreamInit) {
for (int i = 0; i < parallel_reduce; i++) {
insert_packet(nullptr);
}
}
}
} // namespace AstraSim