// astra-sim-alibabacloud/astra-sim/system/collective/DoubleBinaryTreeAllReduce.cc
/******************************************************************************
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*******************************************************************************/
#include "DoubleBinaryTreeAllReduce.hh"
#include "astra-sim/system/PacketBundle.hh"
#include "astra-sim/system/RecvPacketEventHadndlerData.hh"
namespace AstraSim {
// Sets up the all-reduce state machine for node `id` of the given binary tree.
//
// Parameters:
//   id         - node identifier; also used to query `tree` for this node's
//                parent, children and node type.
//   layer_num  - forwarded unchanged to the Algorithm base constructor.
//   tree       - topology that is queried for this node's neighbours; it is
//                dereferenced unconditionally, so it must not be null.
//   data_size  - message size used for every transfer; also stored as
//                final_data_size.
//   boost_mode - when true, `enabled` is taken from tree->is_enabled(id);
//                when false, the algorithm is always enabled.
DoubleBinaryTreeAllReduce::DoubleBinaryTreeAllReduce(
    int id,
    int layer_num,
    BinaryTree* tree,
    uint64_t data_size,
    bool boost_mode)
    : Algorithm(layer_num) {
  // `this->` is only needed where a constructor parameter shadows the member.
  this->id = id;
  this->data_size = data_size;
  final_data_size = data_size;
  logicalTopology = tree;

  // Fixed description of this collective and its initial bookkeeping.
  comType = ComType::All_Reduce;
  name = Name::DoubleBinaryTree;
  state = State::Begin;
  reductions = 0;

  // Position of this node inside the tree.
  parent = tree->get_parent_id(id);
  left_child = tree->get_left_child_id(id);
  right_child = tree->get_right_child_id(id);
  type = tree->get_node_type(id);

  // The tree is only consulted about enablement when boost mode is requested.
  enabled = boost_mode ? tree->is_enabled(id) : true;
}
// Advances this node's all-reduce state machine by a single step.
//
// Every call executes at most one transition, selected from the node's role
// in the tree (`type`), the current `state`, and - for intermediate nodes
// only - the triggering `event`:
//
//   Leaf:         send_to_MA the local data -> send to the parent and post a
//                 receive from the parent -> send_to_NPU -> exit().
//   Intermediate: post a receive from each child -> one send_to_NPU per
//                 received child packet -> send to the parent and post a
//                 receive from the parent -> send_to_NPU -> send to both
//                 children and exit().
//   Root:         post a receive from its single child -> send_to_NPU ->
//                 send back to that child and exit().
//
// Calls that match no transition are ignored.
// NOTE(review): presumably each follow-up call is triggered by the completion
// of the bundle or receive issued in the previous step - confirm against the
// PacketBundle / Sys event handling.
//
// Parameters:
//   event - the event that triggered this step; only the intermediate-node
//           transitions inspect it.
//   data  - not used by this implementation.
void DoubleBinaryTreeAllReduce::run(EventType event, CallData* data) {
  // Allocates a PacketBundle of `data_size` for this stream. The caller picks
  // the direction by invoking send_to_MA() or send_to_NPU() on the result.
  // NOTE(review): the two flags are forwarded positionally; presumably they
  // are PacketBundle's "needs processing" and "send back" switches, and the
  // bundle releases itself once delivered - confirm in PacketBundle.hh.
  const auto new_bundle = [this](bool first_flag, bool second_flag) {
    return new PacketBundle(
        stream->owner,
        stream,
        first_flag,
        second_flag,
        data_size,
        MemBus::Transmition::Usual);
  };

  // Issues a send of `data_size` to `peer`. The request's dstRank and the
  // destination handed to front_end_sim_send both come from the same `peer`
  // argument, so the request metadata can never name a different node than
  // the one actually addressed.
  const auto send_to = [this](int peer) {
    sim_request snd_req;
    snd_req.srcRank = stream->owner->id;
    snd_req.dstRank = peer;
    snd_req.tag = stream->stream_num;
    snd_req.reqType = UINT8;
    snd_req.vnet = stream->current_queue_id;
    snd_req.layerNum = layer_num;
    stream->owner->front_end_sim_send(
        0,
        Sys::dummy_data,
        data_size,
        UINT8,
        peer,
        stream->stream_num,
        &snd_req,
        &Sys::handleEvent,
        nullptr);
  };

  // Posts a receive of `data_size` from `peer`. The handler data asks for an
  // EventType::PacketReceived notification on this stream.
  // NOTE(review): ownership of the handler data presumably passes to the
  // receive path - confirm before adding any cleanup here.
  const auto recv_from = [this](int peer) {
    sim_request rcv_req;
    rcv_req.vnet = stream->current_queue_id;
    rcv_req.layerNum = layer_num;
    RecvPacketEventHadndlerData* handler_data =
        new RecvPacketEventHadndlerData(
            stream,
            stream->owner->id,
            EventType::PacketReceived,
            stream->current_queue_id,
            stream->stream_num);
    stream->owner->front_end_sim_recv(
        0,
        Sys::dummy_data,
        data_size,
        UINT8,
        peer,
        stream->stream_num,
        &rcv_req,
        &Sys::handleEvent,
        handler_data);
  };

  const bool is_leaf = type == BinaryTree::Type::Leaf;
  const bool is_intermediate = type == BinaryTree::Type::Intermediate;
  const bool is_root = type == BinaryTree::Type::Root;

  // The branches form one exclusive chain: exactly one (or none) runs per
  // call. In every branch the state is updated only after the outgoing
  // calls have been issued.

  // ------------------------------- Leaf --------------------------------
  if (is_leaf && state == State::Begin) { // leaf.1
    new_bundle(false, false)->send_to_MA();
    state = State::SendingDataToParent;
  } else if (is_leaf && state == State::SendingDataToParent) { // leaf.3
    send_to(parent);
    recv_from(parent);
    state = State::WaitingDataFromParent;
  } else if (is_leaf && state == State::WaitingDataFromParent) { // leaf.4
    new_bundle(false, false)->send_to_NPU();
    state = State::End;
  } else if (is_leaf && state == State::End) { // leaf.5
    exit();
  }
  // ---------------------------- Intermediate ---------------------------
  else if (is_intermediate && state == State::Begin) { // int.1
    recv_from(left_child);
    recv_from(right_child);
    state = State::WaitingForTwoChildData;
  } else if (
      is_intermediate && state == State::WaitingForTwoChildData &&
      event == EventType::PacketReceived) { // int.2
    new_bundle(true, false)->send_to_NPU();
    state = State::WaitingForOneChildData;
  } else if (
      is_intermediate && state == State::WaitingForOneChildData &&
      event == EventType::PacketReceived) { // int.3
    new_bundle(true, true)->send_to_NPU();
    state = State::SendingDataToParent;
  } else if (
      is_intermediate && reductions < 1 &&
      event == EventType::General) { // int.4
    // Swallows the first General event seen by an intermediate node, in
    // whatever state it arrives. This check must stay ahead of int.5 so the
    // send to the parent is deferred to a later event.
    // NOTE(review): presumably the General events are the completions of the
    // two child bundles from int.2 / int.3 - confirm.
    reductions++;
  } else if (
      is_intermediate && state == State::SendingDataToParent) { // int.5
    send_to(parent);
    recv_from(parent);
    state = State::WaitingDataFromParent;
  } else if (
      is_intermediate && state == State::WaitingDataFromParent &&
      event == EventType::PacketReceived) { // int.6
    new_bundle(true, true)->send_to_NPU();
    state = State::SendingDataToChilds;
  } else if (
      is_intermediate && state == State::SendingDataToChilds) { // int.7
    send_to(left_child);
    send_to(right_child);
    exit();
  }
  // -------------------------------- Root -------------------------------
  else if (is_root && state == State::Begin) { // root.1
    // The root uses a single child: the left one when its id is
    // non-negative, otherwise the right one.
    const int only_child_id = left_child >= 0 ? left_child : right_child;
    recv_from(only_child_id);
    state = State::WaitingForOneChildData;
  } else if (is_root && state == State::WaitingForOneChildData) { // root.2
    new_bundle(true, true)->send_to_NPU();
    state = State::SendingDataToChilds;
  } else if (is_root && state == State::SendingDataToChilds) { // root.3
    const int only_child_id = left_child >= 0 ? left_child : right_child;
    send_to(only_child_id);
    exit();
  }
}
} // namespace AstraSim