astra-sim-alibabacloud/astra-sim/system/fast-backend/FastBackEnd.cc (422 lines of code) (raw):

/* *Copyright (c) 2024, Alibaba Group; *Licensed under the Apache License, Version 2.0 (the "License"); *you may not use this file except in compliance with the License. *You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 *Unless required by applicable law or agreed to in writing, software *distributed under the License is distributed on an "AS IS" BASIS, *WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *See the License for the specific language governing permissions and *limitations under the License. */ #include "FastBackEnd.hh" namespace AstraSim { WrapperData::WrapperData( WrapperData::Type type, void (*msg_handler)(void* fun_arg), void* fun_arg) { this->type = type; this->msg_handler = msg_handler; this->fun_arg = fun_arg; } WrapperRelayData::WrapperRelayData( FastBackEnd* fast_backend, WrapperData::Type type, double creation_time, int partner_node, uint64_t comm_size, void (*msg_handler)(void* fun_arg), void* fun_arg) : WrapperData(type, msg_handler, fun_arg) { this->creation_time = creation_time; this->partner_node = partner_node; this->comm_size = comm_size; this->fast_backend = fast_backend; this->type = type; this->msg_handler = msg_handler; this->fun_arg = fun_arg; } void InflightPairsMap::insert( int src, int dest, int tag, int communicationSize, WrapperData::Type simulationType) { auto insertionResult = inflightPairs.emplace( std::make_tuple(src, dest, tag, communicationSize), simulationType); assert(insertionResult.second); } std::pair<bool, WrapperData::Type> InflightPairsMap::pop( int src, int dest, int tag, int communicationSize) { auto searchResult = inflightPairs.find(std::make_tuple(src, dest, tag, communicationSize)); if (searchResult == inflightPairs.end()) { return std::make_pair(false, WrapperData::Type::Undefined); } auto simulationType = searchResult->second; inflightPairs.erase(searchResult); return std::make_pair(true, simulationType); } void InflightPairsMap::print() { for (const auto& pair : inflightPairs) { std::cout << "src: " << std::get<0>(pair.first) << ", dest: " << std::get<1>(pair.first) << ", tag: " << std::get<2>(pair.first) << ", communicationSize: " << std::get<3>(pair.first) << std::endl; } } void DynamicLatencyTable::insertLatencyData( std::pair<int, int> nodesPair, int communicationSize, double latency) { auto latencyTable = latencyTables.find(nodesPair); if (latencyTable == latencyTables.end()) { insertNewTableForNode(nodesPair); latencyTable = latencyTables.find(nodesPair); } auto latencyInsertionResult = latencyTable->second.emplace(communicationSize, latency); if (latencyInsertionResult.second) { latencyDataCountTable[nodesPair]++; } } std::pair<bool, double> DynamicLatencyTable::lookupLatency( std::pair<int, int> nodesPair, int communicationSize) { auto latencyTable = latencyTables.find(nodesPair); if (latencyTable == latencyTables.end()) { return std::make_pair(false, -1); } auto latency = latencyTable->second.find(communicationSize); if (latency == latencyTable->second.end()) { return std::make_pair(false, -1); } return std::make_pair(true, latency->second); } double DynamicLatencyTable::predictLatency( std::pair<int, int> nodesPair, int communicationSize) { assert(canPredictLatency(nodesPair)); assert(!lookupLatency(nodesPair, communicationSize).first); auto latencyTable = latencyTables[nodesPair]; auto smallerPoint = latencyTable.upper_bound(communicationSize); auto largerPoint = latencyTable.lower_bound(communicationSize); if (smallerPoint == latencyTable.begin()) { smallerPoint = latencyTable.begin(); largerPoint = std::next(smallerPoint); } else if (largerPoint == latencyTable.end()) { largerPoint = std::prev(latencyTable.end()); smallerPoint = std::prev(largerPoint); } else { smallerPoint = std::prev(smallerPoint); } auto x1 = smallerPoint->first; auto y1 = smallerPoint->second; auto x2 = largerPoint->first; auto y2 = largerPoint->second; auto slope = (double)(y2 - y1) / (x2 - x1); auto predictedLatency = (int)(slope * (communicationSize - x1) + y1); assert(predictedLatency > 0); return predictedLatency; } bool DynamicLatencyTable::canPredictLatency(std::pair<int, int> nodesPair) { auto latencyDataCount = latencyDataCountTable.find(nodesPair); if (latencyDataCount == latencyDataCountTable.end()) { return false; } return latencyDataCount->second >= 2; }; void DynamicLatencyTable::print() { for (const auto& latencyTable : latencyTables) { std::cout << "src: " << latencyTable.first.first << ", dest: " << latencyTable.first.second << ", datapoints: " << latencyDataCountTable[latencyTable.first] << std::endl; for (const auto& latencyPair : latencyTable.second) { std::cout << "\t- commSize: " << latencyPair.first << " - latency: " << latencyPair.second << std::endl; } } } void DynamicLatencyTable::insertNewTableForNode(std::pair<int, int> nodesPair) { auto latencyTable = latencyTables.find(nodesPair); auto latencyDataCount = latencyDataCountTable.find(nodesPair); assert( (latencyTable == latencyTables.end()) && (latencyDataCount == latencyDataCountTable.end())); auto latencyTableInsertionResult = latencyTables.emplace(nodesPair, std::map<int, double>()); auto latencyDataCountInsertionResult = latencyDataCountTable.emplace(nodesPair, 0); assert( latencyTableInsertionResult.second && latencyDataCountInsertionResult.second); } InflightPairsMap FastBackEnd::inflightPairsMap; DynamicLatencyTable FastBackEnd::dynamicLatencyTable; void FastBackEnd::handleEvent(void* arg) { WrapperData* wrapperData = (WrapperData*)arg; if (wrapperData->type == WrapperData::Type::FastSendRecv) { (*(wrapperData->msg_handler))(wrapperData->fun_arg); delete wrapperData; } else if ( ((WrapperRelayData*)wrapperData)->type == WrapperData::Type:: DetailedRecv) { // should be replaced by // type==WrapperData::Type::DetailedRecv WrapperRelayData* wrapperRelayData = (WrapperRelayData*)arg; timespec_t current_time = wrapperRelayData->fast_backend->wrapped_backend->sim_get_time(); wrapperRelayData->fast_backend->update_table_recv( wrapperRelayData->creation_time, current_time.time_val, wrapperRelayData->partner_node, wrapperRelayData->comm_size); (*(wrapperRelayData->msg_handler))(wrapperRelayData->fun_arg); delete wrapperRelayData; } else if ( ((WrapperRelayData*)wrapperData)->type == WrapperData::Type:: DetailedSend) { // should be replaced by // type==WrapperData::Type::DetailedSend WrapperRelayData* wrapperRelayData = (WrapperRelayData*)arg; timespec_t current_time = wrapperRelayData->fast_backend->wrapped_backend->sim_get_time(); wrapperRelayData->fast_backend->update_table_send( wrapperRelayData->creation_time, current_time.time_val, wrapperRelayData->partner_node, wrapperRelayData->comm_size); (*(wrapperRelayData->msg_handler))(wrapperRelayData->fun_arg); delete wrapperRelayData; } else { std::cerr << "Event type undefined!" << std::endl; } } FastBackEnd::FastBackEnd(int rank, AstraNetworkAPI* wrapped_backend) : AstraNetworkAPI(rank) { this->wrapped_backend = wrapped_backend; } double FastBackEnd::sim_time_resolution() { return wrapped_backend->sim_time_resolution(); } int FastBackEnd::sim_finish() { wrapped_backend->sim_finish(); delete this; return 1; } // int FastBackEnd::sim_comm_get_rank(sim_comm comm, int *size) { // return wrapped_backend->sim_comm_get_rank(comm,size); //} // int FastBackEnd::sim_comm_set_rank(sim_comm comm, int rank) { // return wrapped_backend->sim_comm_set_rank(comm,rank); //} int FastBackEnd::sim_comm_size(sim_comm comm, int* size) { return wrapped_backend->sim_comm_size(comm, size); } timespec_t FastBackEnd::sim_get_time() { return wrapped_backend->sim_get_time(); } int FastBackEnd::sim_init(AstraMemoryAPI* MEM) { return wrapped_backend->sim_init(MEM); } void FastBackEnd::sim_schedule( timespec_t delta, void (*fun_ptr)(void* fun_arg), void* fun_arg) { wrapped_backend->sim_schedule(delta, fun_ptr, fun_arg); return; } void FastBackEnd::update_table_send( double start, double finished, int dst, uint64_t comm_size) { auto latency = finished - start; auto src = sim_comm_get_rank(); dynamicLatencyTable.insertLatencyData( std::make_pair(src, dst), comm_size, latency); } void FastBackEnd::update_table_recv( double start, double finished, int src, uint64_t comm_size) { auto latency = finished - start; auto dst = sim_comm_get_rank(); dynamicLatencyTable.insertLatencyData( std::make_pair(src, dst), comm_size, latency); } int FastBackEnd::relay_recv_request( void* buffer, uint64_t count, int type, int src, int tag, sim_request* request, void (*msg_handler)(void* fun_arg), void* fun_arg) { timespec_t current_time = wrapped_backend->sim_get_time(); WrapperRelayData* wrapperData = new WrapperRelayData( this, WrapperData::Type::DetailedRecv, current_time.time_val, src, count, msg_handler, fun_arg); return wrapped_backend->sim_recv( buffer, count, type, src, tag, request, &FastBackEnd::handleEvent, wrapperData); } int FastBackEnd::relay_send_request( void* buffer, uint64_t count, int type, int dst, int tag, sim_request* request, void (*msg_handler)(void* fun_arg), void* fun_arg) { timespec_t current_time = wrapped_backend->sim_get_time(); WrapperRelayData* wrapperData = new WrapperRelayData( this, WrapperData::Type::DetailedSend, current_time.time_val, dst, count, msg_handler, fun_arg); return wrapped_backend->sim_send( buffer, count, type, dst, tag, request, &FastBackEnd::handleEvent, wrapperData); } int FastBackEnd::fast_send_recv_request( double delay, void (*msg_handler)(void* fun_arg), void* fun_arg) { timespec_t delta; delta.time_res = time_type_e::NS; delta.time_val = delay; WrapperData* wrapperData = new WrapperData(WrapperData::Type::FastSendRecv, msg_handler, fun_arg); wrapped_backend->sim_schedule( delta, &FastBackEnd::handleEvent, (void*)wrapperData); return 1; } int FastBackEnd::sim_send( void* buffer, uint64_t count, int type, int dst, int tag, sim_request* request, void (*msg_handler)(void* fun_arg), void* fun_arg) { auto src = sim_comm_get_rank(); auto srcDestPair = std::make_pair(src, dst); auto inflightPair = inflightPairsMap.pop(src, dst, tag, count); if (inflightPair.first) { // inflight pair exists. Use that method. switch (inflightPair.second) { case WrapperData::Type::FastSendRecv: { // Use fast send. either lookup latency or predicted one is used. auto lookupResult = dynamicLatencyTable.lookupLatency(srcDestPair, count); if (lookupResult.first) { // lookup result exists return fast_send_recv_request( lookupResult.second, msg_handler, fun_arg); } // lookup doesn't exist - use prediction auto predictedLatency = dynamicLatencyTable.predictLatency(srcDestPair, count); return fast_send_recv_request(predictedLatency, msg_handler, fun_arg); } case WrapperData::Type::DetailedRecv: { // relay send request to ns3 return relay_send_request( buffer, count, type, dst, tag, request, msg_handler, fun_arg); } default: { // should not fall here std::cerr << "sim_send inflight pair error" << std::endl; exit(-1); } } } // inflight pair doesn't exist -- initiate one auto lookupResult = dynamicLatencyTable.lookupLatency(srcDestPair, count); if (lookupResult.first) { // lookup result exists; use this latency for fast simulation inflightPairsMap.insert( src, dst, tag, count, WrapperData::Type::FastSendRecv); return fast_send_recv_request(lookupResult.second, msg_handler, fun_arg); } // lookup doesn't exist if (dynamicLatencyTable.canPredictLatency(srcDestPair)) { // prediction available // for 10%, don't use prediction and relay the request. if ((rand() % 100) < 10) { inflightPairsMap.insert( src, dst, tag, count, WrapperData::Type::DetailedSend); return relay_send_request( buffer, count, type, dst, tag, request, msg_handler, fun_arg); } // for 90% cases, use prediction result. inflightPairsMap.insert( src, dst, tag, count, WrapperData::Type::FastSendRecv); auto predictedLatency = dynamicLatencyTable.predictLatency(srcDestPair, count); return fast_send_recv_request(predictedLatency, msg_handler, fun_arg); } // prediction not available -- should relay to ns3 inflightPairsMap.insert( src, dst, tag, count, WrapperData::Type::DetailedSend); return relay_send_request( buffer, count, type, dst, tag, request, msg_handler, fun_arg); } int FastBackEnd::sim_recv( void* buffer, uint64_t count, int type, int src, int tag, sim_request* request, void (*msg_handler)(void* fun_arg), void* fun_arg) { auto dst = sim_comm_get_rank(); auto srcDestPair = std::make_pair(src, dst); auto inflightPair = inflightPairsMap.pop(src, dst, tag, count); if (inflightPair.first) { // inflight pair exists. Use that method. switch (inflightPair.second) { case WrapperData::Type::FastSendRecv: { // Use fast recv. either lookup latency or predicted one is used. auto lookupResult = dynamicLatencyTable.lookupLatency(srcDestPair, count); if (lookupResult.first) { // lookup result exists. use this latency return fast_send_recv_request( lookupResult.second, msg_handler, fun_arg); } // lookup doesn't exist. use prediction. auto predictedLatency = dynamicLatencyTable.predictLatency(srcDestPair, count); return fast_send_recv_request(predictedLatency, msg_handler, fun_arg); } case WrapperData::Type::DetailedSend: { // relay to ns3 return relay_recv_request( buffer, count, type, src, tag, request, msg_handler, fun_arg); } default: { // should not fall here std::cerr << "sim_recv inflight pair error" << std::endl; exit(-1); } } } // inflight pair doesn't exist -- initiate one auto lookupResult = dynamicLatencyTable.lookupLatency(srcDestPair, count); if (lookupResult.first) { // lookup result exists; use fast simulation inflightPairsMap.insert( src, dst, tag, count, WrapperData::Type::FastSendRecv); return fast_send_recv_request(lookupResult.second, msg_handler, fun_arg); } // lookup doesn't exist if (dynamicLatencyTable.canPredictLatency(srcDestPair)) { // prediction available // for 10%, don't use prediction and relay the request. if ((rand() % 100) < 10) { inflightPairsMap.insert( src, dst, tag, count, WrapperData::Type::DetailedRecv); return relay_recv_request( buffer, count, type, src, tag, request, msg_handler, fun_arg); } // for 90% cases, use prediction result. inflightPairsMap.insert( src, dst, tag, count, WrapperData::Type::FastSendRecv); auto predictedLatency = dynamicLatencyTable.predictLatency(srcDestPair, count); return fast_send_recv_request(predictedLatency, msg_handler, fun_arg); } // cannot predict -- relay to ns3 inflightPairsMap.insert( src, dst, tag, count, WrapperData::Type::DetailedRecv); return relay_recv_request( buffer, count, type, src, tag, request, msg_handler, fun_arg); } } // namespace AstraSim