in astra-sim-alibabacloud/astra-sim/system/scheduling/OfflineGreedy.cc [80:301]
std::vector<int> OfflineGreedy::get_chunk_scheduling(
long long chunk_id,
uint64_t& remaining_data_size,
uint64_t recommended_chunk_size,
std::vector<bool>& dimensions_involved,
InterDimensionScheduling inter_dim_scheduling,
ComType comm_type) {
if (chunk_schedule.find(chunk_id) != chunk_schedule.end()) {
schedule_consumer[chunk_id]++;
if (schedule_consumer[chunk_id] == sys->all_generators.size()) {
std::vector<int> res = chunk_schedule[chunk_id];
remaining_data_size -= global_chunk_size[chunk_id];
chunk_schedule.erase(chunk_id);
schedule_consumer.erase(chunk_id);
global_chunk_size.erase(chunk_id);
return res;
}
remaining_data_size -= global_chunk_size[chunk_id];
return chunk_schedule[chunk_id];
}
if (sys->id != 0) {
return sys->all_generators[0]->offline_greedy->get_chunk_scheduling(
chunk_id,
remaining_data_size,
recommended_chunk_size,
dimensions_involved,
inter_dim_scheduling,
comm_type);
} else {
if (comm_type == ComType::All_Reduce) {
comm_type = ComType::Reduce_Scatter;
}
std::sort(dim_elapsed_time.begin(), dim_elapsed_time.end());
if (comm_type == ComType::All_Gather) {
std::reverse(dim_elapsed_time.begin(), dim_elapsed_time.end());
}
std::vector<int> result;
uint64_t chunk_size =
recommended_chunk_size; //*(dim_BW[dim_elapsed_time.front().dim_num]/dim_BW[0]);
bool chunk_size_calculated = false;
if (inter_dim_scheduling == InterDimensionScheduling::OfflineGreedy) {
global_chunk_size[chunk_id] = std::min(remaining_data_size, chunk_size);
remaining_data_size -= std::min(remaining_data_size, chunk_size);
}
int dim_elapsed_time_pointer = -1;
for (auto& dim : dim_elapsed_time) {
dim_elapsed_time_pointer++;
if (!dimensions_involved[dim.dim_num] || dim_size[dim.dim_num] == 1) {
result.push_back(dim.dim_num);
continue;
} else if (
inter_dim_scheduling == InterDimensionScheduling::OfflineGreedyFlex &&
!chunk_size_calculated) {
chunk_size_calculated = true;
if (comm_type == ComType::Reduce_Scatter) {
double load_difference =
fabs(dim_elapsed_time.back().elapsed_time - dim.elapsed_time);
chunk_size = get_chunk_size_from_elapsed_time(
load_difference, dim, ComType::Reduce_Scatter);
} else {
int lastIndex = dim_elapsed_time.size() - 1;
while (!dimensions_involved[dim_elapsed_time[lastIndex].dim_num] ||
dim_size[dim_elapsed_time[lastIndex].dim_num] == 1) {
lastIndex--;
}
double load_difference =
fabs(dim_elapsed_time[lastIndex].elapsed_time - dim.elapsed_time);
chunk_size = get_chunk_size_from_elapsed_time(
load_difference,
dim_elapsed_time[lastIndex],
ComType::All_Gather);
lastIndex--;
while (dim_elapsed_time_pointer <= lastIndex) {
if (dimensions_involved[dim_elapsed_time[lastIndex].dim_num] &&
dim_size[dim_elapsed_time[lastIndex].dim_num] > 1) {
chunk_size /= dim_size[dim_elapsed_time[lastIndex].dim_num];
}
lastIndex--;
}
}
if (chunk_size < (recommended_chunk_size)) {
result.resize(dim_elapsed_time.size());
std::iota(std::begin(result), std::end(result), 0);
global_chunk_size[chunk_id] =
std::min(remaining_data_size, recommended_chunk_size);
chunk_size = std::min(remaining_data_size, recommended_chunk_size);
remaining_data_size -=
std::min(remaining_data_size, recommended_chunk_size);
chunk_schedule[chunk_id] = result;
schedule_consumer[chunk_id] = 1;
std::vector<DimElapsedTime> myReordered;
myReordered.resize(dim_elapsed_time.size(), dim_elapsed_time[0]);
for (int myDim = 0; myDim < dim_elapsed_time.size(); myDim++) {
for (int searchDim = 0; searchDim < dim_elapsed_time.size();
searchDim++) {
if (dim_elapsed_time[searchDim].dim_num == myDim) {
myReordered[myDim] = dim_elapsed_time[searchDim];
break;
}
}
}
dim_elapsed_time = myReordered;
if (comm_type == ComType::All_Gather) {
std::reverse(dim_elapsed_time.begin(), dim_elapsed_time.end());
}
for (int myDim = 0; myDim < dim_elapsed_time.size(); myDim++) {
if (!dimensions_involved[myDim] || dim_size[myDim] == 1) {
result.push_back(myDim);
continue;
}
if (comm_type == ComType::Reduce_Scatter) {
dim_elapsed_time[myDim].elapsed_time +=
((((double)chunk_size) / 1048576) *
(((double)(dim_size[myDim] - 1)) / (dim_size[myDim]))) /
(dim_BW[myDim] / dim_BW[0]);
chunk_size /= dim_size[myDim];
} else {
dim_elapsed_time[myDim].elapsed_time +=
((((double)chunk_size) / 1048576) *
(((double)(dim_size[myDim] - 1)))) /
(dim_BW[myDim] / dim_BW[0]);
chunk_size *= dim_size[myDim];
}
}
return result;
} else {
global_chunk_size[chunk_id] =
std::min(remaining_data_size, chunk_size);
remaining_data_size -= std::min(remaining_data_size, chunk_size);
}
} else if (
inter_dim_scheduling == InterDimensionScheduling::OfflineGreedy &&
!chunk_size_calculated) {
chunk_size_calculated = true;
uint64_t diff_size = 0;
if (comm_type == ComType::Reduce_Scatter) {
double load_difference =
fabs(dim_elapsed_time.back().elapsed_time - dim.elapsed_time);
diff_size = get_chunk_size_from_elapsed_time(
load_difference, dim, ComType::Reduce_Scatter);
} else {
int lastIndex = dim_elapsed_time.size() - 1;
while (!dimensions_involved[dim_elapsed_time[lastIndex].dim_num] ||
dim_size[dim_elapsed_time[lastIndex].dim_num] == 1) {
lastIndex--;
}
double load_difference =
fabs(dim_elapsed_time[lastIndex].elapsed_time - dim.elapsed_time);
diff_size = get_chunk_size_from_elapsed_time(
load_difference,
dim_elapsed_time[lastIndex],
ComType::All_Gather);
lastIndex--;
while (dim_elapsed_time_pointer <= lastIndex) {
if (dimensions_involved[dim_elapsed_time[lastIndex].dim_num] &&
dim_size[dim_elapsed_time[lastIndex].dim_num] > 1) {
diff_size /= dim_size[dim_elapsed_time[lastIndex].dim_num];
}
lastIndex--;
}
}
if (diff_size < (recommended_chunk_size / 16)) {
result.resize(dim_elapsed_time.size());
std::iota(std::begin(result), std::end(result), 0);
chunk_schedule[chunk_id] = result;
schedule_consumer[chunk_id] = 1;
std::vector<DimElapsedTime> myReordered;
myReordered.resize(dim_elapsed_time.size(), dim_elapsed_time[0]);
for (int myDim = 0; myDim < dim_elapsed_time.size(); myDim++) {
for (int searchDim = 0; searchDim < dim_elapsed_time.size();
searchDim++) {
if (dim_elapsed_time[searchDim].dim_num == myDim) {
myReordered[myDim] = dim_elapsed_time[searchDim];
break;
}
}
}
dim_elapsed_time = myReordered;
if (comm_type == ComType::All_Gather) {
std::reverse(dim_elapsed_time.begin(), dim_elapsed_time.end());
}
for (int myDim = 0; myDim < dim_elapsed_time.size(); myDim++) {
if (!dimensions_involved[myDim] || dim_size[myDim] == 1) {
// result.push_back(myDim);
continue;
}
if (comm_type == ComType::Reduce_Scatter) {
dim_elapsed_time[myDim].elapsed_time +=
((((double)chunk_size) / 1048576) *
(((double)(dim_size[myDim] - 1)) / (dim_size[myDim]))) /
(dim_BW[myDim] / dim_BW[0]);
chunk_size /= dim_size[myDim];
} else {
dim_elapsed_time[myDim].elapsed_time +=
((((double)chunk_size) / 1048576) *
(((double)(dim_size[myDim] - 1)))) /
(dim_BW[myDim] / dim_BW[0]);
chunk_size *= dim_size[myDim];
}
}
return result;
}
}
result.push_back(dim.dim_num);
if (comm_type == ComType::Reduce_Scatter) {
dim.elapsed_time += ((((double)chunk_size) / 1048576) *
(((double)(dim_size[dim.dim_num] - 1)) /
(dim_size[dim.dim_num]))) /
(dim_BW[dim.dim_num] / dim_BW[0]);
chunk_size /= dim_size[dim.dim_num];
} else {
dim.elapsed_time += ((((double)chunk_size) / 1048576) *
(((double)(dim_size[dim.dim_num] - 1)))) /
(dim_BW[dim.dim_num] / dim_BW[0]);
chunk_size *= dim_size[dim.dim_num];
}
}
chunk_schedule[chunk_id] = result;
schedule_consumer[chunk_id] = 1;
return result;
}
}