astra-sim-alibabacloud/astra-sim/system/fast-backend/FastBackEnd.hh (138 lines of code) (raw):
/*
*Copyright (c) 2024, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.
*/
#ifndef __FASTBACKEND_HH__
#define __FASTBACKEND_HH__
#include <cassert>
#include <iostream>
#include <map>
#include <tuple>
#include "astra-sim/system/AstraNetworkAPI.hh"
namespace AstraSim {
class FastBackEnd;
class WrapperData : public MetaData {
public:
enum class Type { FastSendRecv, DetailedSend, DetailedRecv, Undefined };
Type type;
void (*msg_handler)(void* fun_arg);
void* fun_arg;
WrapperData(Type type, void (*msg_handler)(void* fun_arg), void* fun_arg);
};
class WrapperRelayData : public WrapperData {
public:
int partner_node;
uint64_t comm_size;
double creation_time;
FastBackEnd* fast_backend;
WrapperRelayData(
FastBackEnd* fast_backend,
Type type,
double creation_time,
int partner_node,
uint64_t comm_size,
void (*msg_handler)(void* fun_arg),
void* fun_arg);
};
class InflightPairsMap {
public:
InflightPairsMap() = default;
/**
* Insert a pair to the inflight map.
* @param src
* @param dest
* @param tag
* @param communicationSize
* @param simulationType
*/
void insert(
int src,
int dest,
int tag,
int communicationSize,
WrapperData::Type simulationType);
/**
* Remove a pair from the inflight map.
* @param src
* @param dest
* @param tag
* @param communicationSize
* @return <true, WrapperData::Type> if matching pair exists, return <false,
* Undefined> if not.
*/
std::pair<bool, WrapperData::Type> pop(
int src,
int dest,
int tag,
int communicationSize);
/**
* Print out all residing pairs in the inflight map.
*/
void print();
private:
// key: <src, dest, tag, communicationSize>
std::map<std::tuple<int, int, int, int>, WrapperData::Type> inflightPairs;
};
class DynamicLatencyTable {
public:
DynamicLatencyTable() = default;
/**
* Insert a new latency data to the (src, dest) table.
*
* @param nodesPair (src, dest) pair
* @param communicationSize communication size in bytes
* @param latency
*/
void insertLatencyData(
std::pair<int, int> nodesPair,
int communicationSize,
double latency);
/**
* lookup whether there exists a exact matching latency value in the table.
*
* @param nodesPair (src, dest) pair for lookup
* @param communicationSize communication size (in bytes) for lookup
* @return <true, latency> is exists, <false, -1> if not.
*/
std::pair<bool, double> lookupLatency(
std::pair<int, int> nodesPair,
int communicationSize);
/**
* Linearly interpolate/extrapolate two nearest points and predict latency.
*
* @param nodesPair (src, dest) pair for prediction
* @param communicationSize communication size (in bytes) for prediction.
* @return predicted traffic latency
*/
double predictLatency(std::pair<int, int> nodesPair, int communicationSize);
/**
* Returns a boolean whether there's enough number of data points available
* for prediction.
* @param nodesPair (src, dest) pair for test
* @return true if there are at least 2 data points, false if not.
*/
bool canPredictLatency(std::pair<int, int> nodesPair);
/**
* Print latency data table.
*/
void print();
private:
std::map<std::pair<int, int>, std::map<int, double>> latencyTables;
std::map<std::pair<int, int>, int> latencyDataCountTable;
/**
* Insert new table for a new (src, dest) pair.
* @param nodesPair (src, dest) pair
*/
void insertNewTableForNode(std::pair<int, int> nodesPair);
};
class FastBackEnd : public AstraNetworkAPI {
public:
AstraNetworkAPI* wrapped_backend;
static void handleEvent(void* arg);
int sim_comm_size(sim_comm comm, int* size);
// int sim_comm_get_rank();
// int sim_comm_set_rank(sim_comm comm, int rank);
int sim_finish();
double sim_time_resolution();
int sim_init(AstraMemoryAPI* MEM);
timespec_t sim_get_time();
void sim_schedule(
timespec_t delta,
void (*fun_ptr)(void* fun_arg),
void* fun_arg);
int sim_send(
void* buffer,
uint64_t count,
int type,
int dst,
int tag,
sim_request* request,
void (*msg_handler)(void* fun_arg),
void* fun_arg);
int sim_recv(
void* buffer,
uint64_t count,
int type,
int src,
int tag,
sim_request* request,
void (*msg_handler)(void* fun_arg),
void* fun_arg);
FastBackEnd(int rank, AstraNetworkAPI* wrapped_backend);
void update_table_recv(
double start,
double finished,
int src,
uint64_t comm_size);
void update_table_send(
double start,
double finished,
int dst,
uint64_t comm_size);
int relay_send_request(
void* buffer,
uint64_t count,
int type,
int dst,
int tag,
sim_request* request,
void (*msg_handler)(void* fun_arg),
void* fun_arg);
int relay_recv_request(
void* buffer,
uint64_t count,
int type,
int src,
int tag,
sim_request* request,
void (*msg_handler)(void* fun_arg),
void* fun_arg);
int fast_send_recv_request(
double delay,
void (*msg_handler)(void* fun_arg),
void* fun_arg);
private:
static InflightPairsMap inflightPairsMap;
static DynamicLatencyTable dynamicLatencyTable;
};
} // namespace AstraSim
#endif